Repository: o8oo8o/GoWebSSH
Branch: main
Commit: 2784fa78fb3e
Files: 997
Total size: 12.8 MB
Directory structure:
gitextract_q8k6q315/
├── .github/
│ └── workflows/
│ └── build-and-release.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── gossh/
│ ├── Makefile
│ ├── app/
│ │ ├── config/
│ │ │ └── config.go
│ │ ├── middleware/
│ │ │ ├── db_check.go
│ │ │ ├── jwt_auth.go
│ │ │ ├── net_filter.go
│ │ │ ├── perm_check.go
│ │ │ └── sys_init.go
│ │ ├── model/
│ │ │ ├── cmd_note.go
│ │ │ ├── datetime.go
│ │ │ ├── db_init.go
│ │ │ ├── login_audit.go
│ │ │ ├── net_filter.go
│ │ │ ├── policy_conf.go
│ │ │ ├── ssh_conf.go
│ │ │ ├── sshd_cert.go
│ │ │ ├── sshd_conf.go
│ │ │ ├── sshd_user.go
│ │ │ └── web_user.go
│ │ ├── service/
│ │ │ ├── cmd_note.go
│ │ │ ├── db_conn.go
│ │ │ ├── login_audit.go
│ │ │ ├── net_filter.go
│ │ │ ├── policy_conf.go
│ │ │ ├── service_init.go
│ │ │ ├── ssh_conf.go
│ │ │ ├── ssh_conn.go
│ │ │ ├── ssh_sftp.go
│ │ │ ├── ssh_status.go
│ │ │ ├── sshd_cert.go
│ │ │ ├── sshd_server.go
│ │ │ ├── sshd_user.go
│ │ │ ├── sys_init.go
│ │ │ └── web_user.go
│ │ └── utils/
│ │ ├── crypto.go
│ │ └── utils.go
│ ├── crypto/
│ │ ├── blowfish/
│ │ │ ├── block.go
│ │ │ ├── cipher.go
│ │ │ └── const.go
│ │ ├── chacha20/
│ │ │ ├── chacha_arm64.go
│ │ │ ├── chacha_arm64.s
│ │ │ ├── chacha_generic.go
│ │ │ ├── chacha_noasm.go
│ │ │ ├── chacha_ppc64le.go
│ │ │ ├── chacha_ppc64le.s
│ │ │ ├── chacha_s390x.go
│ │ │ ├── chacha_s390x.s
│ │ │ └── xor.go
│ │ ├── chacha20poly1305/
│ │ │ ├── chacha20poly1305.go
│ │ │ ├── chacha20poly1305_amd64.go
│ │ │ ├── chacha20poly1305_amd64.s
│ │ │ ├── chacha20poly1305_generic.go
│ │ │ ├── chacha20poly1305_noasm.go
│ │ │ └── xchacha20poly1305.go
│ │ ├── curve25519/
│ │ │ ├── curve25519.go
│ │ │ ├── curve25519_compat.go
│ │ │ ├── curve25519_go120.go
│ │ │ └── internal/
│ │ │ └── field/
│ │ │ ├── fe.go
│ │ │ ├── fe_amd64.go
│ │ │ ├── fe_amd64.s
│ │ │ ├── fe_amd64_noasm.go
│ │ │ ├── fe_arm64.go
│ │ │ ├── fe_arm64.s
│ │ │ ├── fe_arm64_noasm.go
│ │ │ ├── fe_generic.go
│ │ │ ├── sync.checkpoint
│ │ │ └── sync.sh
│ │ ├── internal/
│ │ │ ├── alias/
│ │ │ │ ├── alias.go
│ │ │ │ └── alias_purego.go
│ │ │ ├── poly1305/
│ │ │ │ ├── mac_noasm.go
│ │ │ │ ├── poly1305.go
│ │ │ │ ├── sum_amd64.go
│ │ │ │ ├── sum_amd64.s
│ │ │ │ ├── sum_generic.go
│ │ │ │ ├── sum_ppc64le.go
│ │ │ │ ├── sum_ppc64le.s
│ │ │ │ ├── sum_s390x.go
│ │ │ │ └── sum_s390x.s
│ │ │ └── testenv/
│ │ │ ├── exec.go
│ │ │ ├── testenv_notunix.go
│ │ │ └── testenv_unix.go
│ │ └── ssh/
│ │ ├── agent/
│ │ │ ├── client.go
│ │ │ ├── forward.go
│ │ │ ├── keyring.go
│ │ │ └── server.go
│ │ ├── buffer.go
│ │ ├── certs.go
│ │ ├── channel.go
│ │ ├── cipher.go
│ │ ├── client.go
│ │ ├── client_auth.go
│ │ ├── common.go
│ │ ├── connection.go
│ │ ├── doc.go
│ │ ├── handshake.go
│ │ ├── internal/
│ │ │ └── bcrypt_pbkdf/
│ │ │ └── bcrypt_pbkdf.go
│ │ ├── kex.go
│ │ ├── keys.go
│ │ ├── knownhosts/
│ │ │ └── knownhosts.go
│ │ ├── mac.go
│ │ ├── messages.go
│ │ ├── mux.go
│ │ ├── server.go
│ │ ├── session.go
│ │ ├── ssh_gss.go
│ │ ├── streamlocal.go
│ │ ├── tcpip.go
│ │ ├── terminal/
│ │ │ └── terminal.go
│ │ └── transport.go
│ ├── gin/
│ │ ├── auth.go
│ │ ├── binding/
│ │ │ ├── binding.go
│ │ │ ├── default_validator.go
│ │ │ ├── form.go
│ │ │ ├── form_mapping.go
│ │ │ ├── header.go
│ │ │ ├── json.go
│ │ │ ├── multipart_form_mapping.go
│ │ │ ├── query.go
│ │ │ ├── uri.go
│ │ │ └── xml.go
│ │ ├── context.go
│ │ ├── debug.go
│ │ ├── errors.go
│ │ ├── fs.go
│ │ ├── gin.go
│ │ ├── ginS/
│ │ │ ├── README.md
│ │ │ └── gins.go
│ │ ├── internal/
│ │ │ ├── bytesconv/
│ │ │ │ └── bytesconv.go
│ │ │ └── json/
│ │ │ └── json.go
│ │ ├── jwt/
│ │ │ ├── README.md
│ │ │ ├── claims.go
│ │ │ ├── cmd/
│ │ │ │ └── jwt/
│ │ │ │ ├── README.md
│ │ │ │ └── main.go
│ │ │ ├── ecdsa.go
│ │ │ ├── ecdsa_utils.go
│ │ │ ├── ed25519.go
│ │ │ ├── ed25519_utils.go
│ │ │ ├── errors.go
│ │ │ ├── errors_go1_20.go
│ │ │ ├── errors_go_other.go
│ │ │ ├── hmac.go
│ │ │ ├── map_claims.go
│ │ │ ├── none.go
│ │ │ ├── parser.go
│ │ │ ├── parser_option.go
│ │ │ ├── registered_claims.go
│ │ │ ├── request/
│ │ │ │ ├── doc.go
│ │ │ │ ├── extractor.go
│ │ │ │ ├── oauth2.go
│ │ │ │ └── request.go
│ │ │ ├── rsa.go
│ │ │ ├── rsa_pss.go
│ │ │ ├── rsa_utils.go
│ │ │ ├── signing_method.go
│ │ │ ├── token.go
│ │ │ ├── token_option.go
│ │ │ ├── types.go
│ │ │ └── validator.go
│ │ ├── logger.go
│ │ ├── mode.go
│ │ ├── path.go
│ │ ├── recovery.go
│ │ ├── render/
│ │ │ ├── data.go
│ │ │ ├── html.go
│ │ │ ├── json.go
│ │ │ ├── reader.go
│ │ │ ├── redirect.go
│ │ │ ├── render.go
│ │ │ ├── text.go
│ │ │ └── xml.go
│ │ ├── response_writer.go
│ │ ├── routergroup.go
│ │ ├── sessions/
│ │ │ ├── context/
│ │ │ │ └── context.go
│ │ │ ├── cookie/
│ │ │ │ └── cookie.go
│ │ │ ├── memstore/
│ │ │ │ ├── cache.go
│ │ │ │ ├── memstore.go
│ │ │ │ └── store.go
│ │ │ ├── securecookie/
│ │ │ │ └── securecookie.go
│ │ │ ├── session_options_go1.10.go
│ │ │ ├── session_options_go1.11.go
│ │ │ ├── sessions/
│ │ │ │ ├── cookie.go
│ │ │ │ ├── cookie_go111.go
│ │ │ │ ├── doc.go
│ │ │ │ ├── lex.go
│ │ │ │ ├── options.go
│ │ │ │ ├── options_go111.go
│ │ │ │ ├── sessions.go
│ │ │ │ └── store.go
│ │ │ └── sessions.go
│ │ ├── sse/
│ │ │ ├── sse-decoder.go
│ │ │ ├── sse-encoder.go
│ │ │ └── writer.go
│ │ ├── tree.go
│ │ ├── utils.go
│ │ ├── validator/
│ │ │ ├── baked_in.go
│ │ │ ├── cache.go
│ │ │ ├── country_codes.go
│ │ │ ├── currency_codes.go
│ │ │ ├── errors.go
│ │ │ ├── field_level.go
│ │ │ ├── options.go
│ │ │ ├── postcode_regexes.go
│ │ │ ├── regexes.go
│ │ │ ├── struct_level.go
│ │ │ ├── util.go
│ │ │ ├── validator.go
│ │ │ └── validator_instance.go
│ │ └── version.go
│ ├── go.mod
│ ├── gorm/
│ │ ├── association.go
│ │ ├── callbacks/
│ │ │ ├── associations.go
│ │ │ ├── callbacks.go
│ │ │ ├── callmethod.go
│ │ │ ├── create.go
│ │ │ ├── delete.go
│ │ │ ├── helper.go
│ │ │ ├── interfaces.go
│ │ │ ├── preload.go
│ │ │ ├── query.go
│ │ │ ├── raw.go
│ │ │ ├── row.go
│ │ │ ├── transaction.go
│ │ │ └── update.go
│ │ ├── callbacks.go
│ │ ├── chainable_api.go
│ │ ├── clause/
│ │ │ ├── clause.go
│ │ │ ├── delete.go
│ │ │ ├── expression.go
│ │ │ ├── from.go
│ │ │ ├── group_by.go
│ │ │ ├── insert.go
│ │ │ ├── joins.go
│ │ │ ├── limit.go
│ │ │ ├── locking.go
│ │ │ ├── on_conflict.go
│ │ │ ├── order_by.go
│ │ │ ├── returning.go
│ │ │ ├── select.go
│ │ │ ├── set.go
│ │ │ ├── update.go
│ │ │ ├── values.go
│ │ │ ├── where.go
│ │ │ └── with.go
│ │ ├── driver/
│ │ │ ├── mysql/
│ │ │ │ ├── error_translator.go
│ │ │ │ ├── migrator.go
│ │ │ │ └── mysql.go
│ │ │ └── pgsql/
│ │ │ ├── error_translator.go
│ │ │ ├── migrator.go
│ │ │ └── postgres.go
│ │ ├── errors.go
│ │ ├── finisher_api.go
│ │ ├── gorm.go
│ │ ├── inflection/
│ │ │ └── inflections.go
│ │ ├── interfaces.go
│ │ ├── logger/
│ │ │ ├── logger.go
│ │ │ └── sql.go
│ │ ├── migrator/
│ │ │ ├── column_type.go
│ │ │ ├── index.go
│ │ │ ├── migrator.go
│ │ │ └── table_type.go
│ │ ├── migrator.go
│ │ ├── model.go
│ │ ├── now/
│ │ │ ├── main.go
│ │ │ ├── now.go
│ │ │ └── time.go
│ │ ├── prepare_stmt.go
│ │ ├── scan.go
│ │ ├── schema/
│ │ │ ├── constraint.go
│ │ │ ├── field.go
│ │ │ ├── index.go
│ │ │ ├── interfaces.go
│ │ │ ├── naming.go
│ │ │ ├── pool.go
│ │ │ ├── relationship.go
│ │ │ ├── schema.go
│ │ │ ├── serializer.go
│ │ │ └── utils.go
│ │ ├── soft_delete.go
│ │ ├── statement.go
│ │ └── utils/
│ │ ├── tests/
│ │ │ ├── dummy_dialecter.go
│ │ │ ├── models.go
│ │ │ └── utils.go
│ │ └── utils.go
│ ├── main.go
│ ├── mysql/
│ │ ├── atomic_bool.go
│ │ ├── atomic_bool_go118.go
│ │ ├── auth.go
│ │ ├── buffer.go
│ │ ├── collations.go
│ │ ├── conncheck.go
│ │ ├── conncheck_dummy.go
│ │ ├── connection.go
│ │ ├── connector.go
│ │ ├── const.go
│ │ ├── driver.go
│ │ ├── dsn.go
│ │ ├── errors.go
│ │ ├── fields.go
│ │ ├── infile.go
│ │ ├── nulltime.go
│ │ ├── packets.go
│ │ ├── result.go
│ │ ├── rows.go
│ │ ├── statement.go
│ │ ├── transaction.go
│ │ └── utils.go
│ ├── pgsql/
│ │ ├── array.go
│ │ ├── buf.go
│ │ ├── conn.go
│ │ ├── conn_go115.go
│ │ ├── conn_go18.go
│ │ ├── connector.go
│ │ ├── copy.go
│ │ ├── encode.go
│ │ ├── error.go
│ │ ├── hstore/
│ │ │ └── hstore.go
│ │ ├── krb.go
│ │ ├── notice.go
│ │ ├── notify.go
│ │ ├── oid/
│ │ │ ├── doc.go
│ │ │ ├── gen.go
│ │ │ └── types.go
│ │ ├── rows.go
│ │ ├── scram/
│ │ │ └── scram.go
│ │ ├── ssl.go
│ │ ├── ssl_permissions.go
│ │ ├── ssl_windows.go
│ │ ├── url.go
│ │ ├── user_other.go
│ │ ├── user_posix.go
│ │ ├── user_windows.go
│ │ └── uuid.go
│ ├── pty/
│ │ ├── asm_solaris_amd64.s
│ │ ├── cmd_windows.go
│ │ ├── doc.go
│ │ ├── ioctl.go
│ │ ├── ioctl_bsd.go
│ │ ├── ioctl_inner.go
│ │ ├── ioctl_legacy.go
│ │ ├── ioctl_solaris.go
│ │ ├── ioctl_unsupported.go
│ │ ├── pty_darwin.go
│ │ ├── pty_dragonfly.go
│ │ ├── pty_freebsd.go
│ │ ├── pty_linux.go
│ │ ├── pty_netbsd.go
│ │ ├── pty_openbsd.go
│ │ ├── pty_solaris.go
│ │ ├── pty_unsupported.go
│ │ ├── pty_windows.go
│ │ ├── run.go
│ │ ├── run_unix.go
│ │ ├── run_windows.go
│ │ ├── types.go
│ │ ├── types_dragonfly.go
│ │ ├── types_freebsd.go
│ │ ├── types_netbsd.go
│ │ ├── types_openbsd.go
│ │ ├── winsize.go
│ │ ├── winsize_unix.go
│ │ ├── winsize_windows.go
│ │ ├── ztypes_386.go
│ │ ├── ztypes_amd64.go
│ │ ├── ztypes_arm.go
│ │ ├── ztypes_arm64.go
│ │ ├── ztypes_dragonfly_amd64.go
│ │ ├── ztypes_freebsd_386.go
│ │ ├── ztypes_freebsd_amd64.go
│ │ ├── ztypes_freebsd_arm.go
│ │ ├── ztypes_freebsd_arm64.go
│ │ ├── ztypes_freebsd_ppc64.go
│ │ ├── ztypes_freebsd_riscv64.go
│ │ ├── ztypes_loong64.go
│ │ ├── ztypes_mipsx.go
│ │ ├── ztypes_netbsd_32bit_int.go
│ │ ├── ztypes_openbsd_32bit_int.go
│ │ ├── ztypes_ppc.go
│ │ ├── ztypes_ppc64.go
│ │ ├── ztypes_ppc64le.go
│ │ ├── ztypes_riscvx.go
│ │ ├── ztypes_s390x.go
│ │ └── ztypes_sparcx.go
│ ├── readme.md
│ ├── sftp/
│ │ ├── allocator.go
│ │ ├── attrs.go
│ │ ├── attrs_stubs.go
│ │ ├── attrs_unix.go
│ │ ├── client.go
│ │ ├── conn.go
│ │ ├── debug.go
│ │ ├── fuzz.go
│ │ ├── internal/
│ │ │ ├── encoding/
│ │ │ │ └── ssh/
│ │ │ │ └── filexfer/
│ │ │ │ ├── attrs.go
│ │ │ │ ├── buffer.go
│ │ │ │ ├── extended_packets.go
│ │ │ │ ├── extensions.go
│ │ │ │ ├── filexfer.go
│ │ │ │ ├── fx.go
│ │ │ │ ├── fxp.go
│ │ │ │ ├── handle_packets.go
│ │ │ │ ├── init_packets.go
│ │ │ │ ├── open_packets.go
│ │ │ │ ├── openssh/
│ │ │ │ │ ├── fsync.go
│ │ │ │ │ ├── hardlink.go
│ │ │ │ │ ├── openssh.go
│ │ │ │ │ ├── posix-rename.go
│ │ │ │ │ └── statvfs.go
│ │ │ │ ├── packets.go
│ │ │ │ ├── path_packets.go
│ │ │ │ ├── permissions.go
│ │ │ │ └── response_packets.go
│ │ │ └── sftpfs/
│ │ │ ├── filesystem.go
│ │ │ └── walk.go
│ │ ├── ls_formatting.go
│ │ ├── ls_plan9.go
│ │ ├── ls_stub.go
│ │ ├── ls_unix.go
│ │ ├── match.go
│ │ ├── packet-manager.go
│ │ ├── packet-typing.go
│ │ ├── packet.go
│ │ ├── pool.go
│ │ ├── release.go
│ │ ├── request-attrs.go
│ │ ├── request-errors.go
│ │ ├── request-example.go
│ │ ├── request-interfaces.go
│ │ ├── request-plan9.go
│ │ ├── request-server.go
│ │ ├── request-unix.go
│ │ ├── request.go
│ │ ├── request_windows.go
│ │ ├── server.go
│ │ ├── server_plan9.go
│ │ ├── server_statvfs_darwin.go
│ │ ├── server_statvfs_impl.go
│ │ ├── server_statvfs_linux.go
│ │ ├── server_statvfs_plan9.go
│ │ ├── server_statvfs_stubs.go
│ │ ├── server_unix.go
│ │ ├── server_windows.go
│ │ ├── sftp.go
│ │ ├── stat_plan9.go
│ │ ├── stat_posix.go
│ │ ├── syscall_fixed.go
│ │ └── syscall_good.go
│ ├── sqlite/
│ │ ├── ddlmod.go
│ │ ├── errors.go
│ │ ├── migrator.go
│ │ └── sqlite.go
│ ├── sys/
│ │ ├── cpu/
│ │ │ ├── asm_aix_ppc64.s
│ │ │ ├── byteorder.go
│ │ │ ├── cpu.go
│ │ │ ├── cpu_aix.go
│ │ │ ├── cpu_arm.go
│ │ │ ├── cpu_arm64.go
│ │ │ ├── cpu_arm64.s
│ │ │ ├── cpu_gc_arm64.go
│ │ │ ├── cpu_gc_s390x.go
│ │ │ ├── cpu_gc_x86.go
│ │ │ ├── cpu_gccgo_arm64.go
│ │ │ ├── cpu_gccgo_s390x.go
│ │ │ ├── cpu_gccgo_x86.c
│ │ │ ├── cpu_gccgo_x86.go
│ │ │ ├── cpu_linux.go
│ │ │ ├── cpu_linux_arm.go
│ │ │ ├── cpu_linux_arm64.go
│ │ │ ├── cpu_linux_mips64x.go
│ │ │ ├── cpu_linux_noinit.go
│ │ │ ├── cpu_linux_ppc64x.go
│ │ │ ├── cpu_linux_s390x.go
│ │ │ ├── cpu_loong64.go
│ │ │ ├── cpu_mips64x.go
│ │ │ ├── cpu_mipsx.go
│ │ │ ├── cpu_netbsd_arm64.go
│ │ │ ├── cpu_openbsd_arm64.go
│ │ │ ├── cpu_openbsd_arm64.s
│ │ │ ├── cpu_other_arm.go
│ │ │ ├── cpu_other_arm64.go
│ │ │ ├── cpu_other_mips64x.go
│ │ │ ├── cpu_other_ppc64x.go
│ │ │ ├── cpu_other_riscv64.go
│ │ │ ├── cpu_ppc64x.go
│ │ │ ├── cpu_riscv64.go
│ │ │ ├── cpu_s390x.go
│ │ │ ├── cpu_s390x.s
│ │ │ ├── cpu_wasm.go
│ │ │ ├── cpu_x86.go
│ │ │ ├── cpu_x86.s
│ │ │ ├── cpu_zos.go
│ │ │ ├── cpu_zos_s390x.go
│ │ │ ├── endian_big.go
│ │ │ ├── endian_little.go
│ │ │ ├── hwcap_linux.go
│ │ │ ├── parse.go
│ │ │ ├── proc_cpuinfo_linux.go
│ │ │ ├── runtime_auxv.go
│ │ │ ├── runtime_auxv_go121.go
│ │ │ ├── syscall_aix_gccgo.go
│ │ │ └── syscall_aix_ppc64_gc.go
│ │ ├── execabs/
│ │ │ ├── execabs.go
│ │ │ ├── execabs_go118.go
│ │ │ └── execabs_go119.go
│ │ ├── plan9/
│ │ │ ├── asm.s
│ │ │ ├── asm_plan9_386.s
│ │ │ ├── asm_plan9_amd64.s
│ │ │ ├── asm_plan9_arm.s
│ │ │ ├── const_plan9.go
│ │ │ ├── dir_plan9.go
│ │ │ ├── env_plan9.go
│ │ │ ├── errors_plan9.go
│ │ │ ├── mkall.sh
│ │ │ ├── mkerrors.sh
│ │ │ ├── mksyscall.go
│ │ │ ├── mksysnum_plan9.sh
│ │ │ ├── pwd_go15_plan9.go
│ │ │ ├── pwd_plan9.go
│ │ │ ├── race.go
│ │ │ ├── race0.go
│ │ │ ├── str.go
│ │ │ ├── syscall.go
│ │ │ ├── syscall_plan9.go
│ │ │ ├── zsyscall_plan9_386.go
│ │ │ ├── zsyscall_plan9_amd64.go
│ │ │ ├── zsyscall_plan9_arm.go
│ │ │ └── zsysnum_plan9.go
│ │ ├── unix/
│ │ │ ├── README.md
│ │ │ ├── affinity_linux.go
│ │ │ ├── aliases.go
│ │ │ ├── asm_aix_ppc64.s
│ │ │ ├── asm_bsd_386.s
│ │ │ ├── asm_bsd_amd64.s
│ │ │ ├── asm_bsd_arm.s
│ │ │ ├── asm_bsd_arm64.s
│ │ │ ├── asm_bsd_ppc64.s
│ │ │ ├── asm_bsd_riscv64.s
│ │ │ ├── asm_linux_386.s
│ │ │ ├── asm_linux_amd64.s
│ │ │ ├── asm_linux_arm.s
│ │ │ ├── asm_linux_arm64.s
│ │ │ ├── asm_linux_loong64.s
│ │ │ ├── asm_linux_mips64x.s
│ │ │ ├── asm_linux_mipsx.s
│ │ │ ├── asm_linux_ppc64x.s
│ │ │ ├── asm_linux_riscv64.s
│ │ │ ├── asm_linux_s390x.s
│ │ │ ├── asm_openbsd_mips64.s
│ │ │ ├── asm_solaris_amd64.s
│ │ │ ├── asm_zos_s390x.s
│ │ │ ├── bluetooth_linux.go
│ │ │ ├── cap_freebsd.go
│ │ │ ├── constants.go
│ │ │ ├── dev_aix_ppc.go
│ │ │ ├── dev_aix_ppc64.go
│ │ │ ├── dev_darwin.go
│ │ │ ├── dev_dragonfly.go
│ │ │ ├── dev_freebsd.go
│ │ │ ├── dev_linux.go
│ │ │ ├── dev_netbsd.go
│ │ │ ├── dev_openbsd.go
│ │ │ ├── dev_zos.go
│ │ │ ├── dirent.go
│ │ │ ├── endian_big.go
│ │ │ ├── endian_little.go
│ │ │ ├── env_unix.go
│ │ │ ├── epoll_zos.go
│ │ │ ├── fcntl.go
│ │ │ ├── fcntl_darwin.go
│ │ │ ├── fcntl_linux_32bit.go
│ │ │ ├── fdset.go
│ │ │ ├── fstatfs_zos.go
│ │ │ ├── gccgo.go
│ │ │ ├── gccgo_c.c
│ │ │ ├── gccgo_linux_amd64.go
│ │ │ ├── ifreq_linux.go
│ │ │ ├── internal/
│ │ │ │ └── mkmerge/
│ │ │ │ └── mkmerge.go
│ │ │ ├── ioctl_linux.go
│ │ │ ├── ioctl_signed.go
│ │ │ ├── ioctl_unsigned.go
│ │ │ ├── ioctl_zos.go
│ │ │ ├── linux/
│ │ │ │ ├── Dockerfile
│ │ │ │ ├── mkall.go
│ │ │ │ ├── mksysnum.go
│ │ │ │ └── types.go
│ │ │ ├── mkall.sh
│ │ │ ├── mkasm.go
│ │ │ ├── mkerrors.sh
│ │ │ ├── mkpost.go
│ │ │ ├── mksyscall.go
│ │ │ ├── mksyscall_aix_ppc.go
│ │ │ ├── mksyscall_aix_ppc64.go
│ │ │ ├── mksyscall_solaris.go
│ │ │ ├── mksysctl_openbsd.go
│ │ │ ├── mksysnum.go
│ │ │ ├── mmap_nomremap.go
│ │ │ ├── mremap.go
│ │ │ ├── pagesize_unix.go
│ │ │ ├── pledge_openbsd.go
│ │ │ ├── ptrace_darwin.go
│ │ │ ├── ptrace_ios.go
│ │ │ ├── race.go
│ │ │ ├── race0.go
│ │ │ ├── readdirent_getdents.go
│ │ │ ├── readdirent_getdirentries.go
│ │ │ ├── sockcmsg_dragonfly.go
│ │ │ ├── sockcmsg_linux.go
│ │ │ ├── sockcmsg_unix.go
│ │ │ ├── sockcmsg_unix_other.go
│ │ │ ├── syscall.go
│ │ │ ├── syscall_aix.go
│ │ │ ├── syscall_aix_ppc.go
│ │ │ ├── syscall_aix_ppc64.go
│ │ │ ├── syscall_bsd.go
│ │ │ ├── syscall_darwin.go
│ │ │ ├── syscall_darwin_amd64.go
│ │ │ ├── syscall_darwin_arm64.go
│ │ │ ├── syscall_darwin_libSystem.go
│ │ │ ├── syscall_dragonfly.go
│ │ │ ├── syscall_dragonfly_amd64.go
│ │ │ ├── syscall_freebsd.go
│ │ │ ├── syscall_freebsd_386.go
│ │ │ ├── syscall_freebsd_amd64.go
│ │ │ ├── syscall_freebsd_arm.go
│ │ │ ├── syscall_freebsd_arm64.go
│ │ │ ├── syscall_freebsd_riscv64.go
│ │ │ ├── syscall_hurd.go
│ │ │ ├── syscall_hurd_386.go
│ │ │ ├── syscall_illumos.go
│ │ │ ├── syscall_linux.go
│ │ │ ├── syscall_linux_386.go
│ │ │ ├── syscall_linux_alarm.go
│ │ │ ├── syscall_linux_amd64.go
│ │ │ ├── syscall_linux_amd64_gc.go
│ │ │ ├── syscall_linux_arm.go
│ │ │ ├── syscall_linux_arm64.go
│ │ │ ├── syscall_linux_gc.go
│ │ │ ├── syscall_linux_gc_386.go
│ │ │ ├── syscall_linux_gc_arm.go
│ │ │ ├── syscall_linux_gccgo_386.go
│ │ │ ├── syscall_linux_gccgo_arm.go
│ │ │ ├── syscall_linux_loong64.go
│ │ │ ├── syscall_linux_mips64x.go
│ │ │ ├── syscall_linux_mipsx.go
│ │ │ ├── syscall_linux_ppc.go
│ │ │ ├── syscall_linux_ppc64x.go
│ │ │ ├── syscall_linux_riscv64.go
│ │ │ ├── syscall_linux_s390x.go
│ │ │ ├── syscall_linux_sparc64.go
│ │ │ ├── syscall_netbsd.go
│ │ │ ├── syscall_netbsd_386.go
│ │ │ ├── syscall_netbsd_amd64.go
│ │ │ ├── syscall_netbsd_arm.go
│ │ │ ├── syscall_netbsd_arm64.go
│ │ │ ├── syscall_openbsd.go
│ │ │ ├── syscall_openbsd_386.go
│ │ │ ├── syscall_openbsd_amd64.go
│ │ │ ├── syscall_openbsd_arm.go
│ │ │ ├── syscall_openbsd_arm64.go
│ │ │ ├── syscall_openbsd_libc.go
│ │ │ ├── syscall_openbsd_mips64.go
│ │ │ ├── syscall_openbsd_ppc64.go
│ │ │ ├── syscall_openbsd_riscv64.go
│ │ │ ├── syscall_solaris.go
│ │ │ ├── syscall_solaris_amd64.go
│ │ │ ├── syscall_unix.go
│ │ │ ├── syscall_unix_gc.go
│ │ │ ├── syscall_unix_gc_ppc64x.go
│ │ │ ├── syscall_zos_s390x.go
│ │ │ ├── sysvshm_linux.go
│ │ │ ├── sysvshm_unix.go
│ │ │ ├── sysvshm_unix_other.go
│ │ │ ├── timestruct.go
│ │ │ ├── types_aix.go
│ │ │ ├── types_darwin.go
│ │ │ ├── types_dragonfly.go
│ │ │ ├── types_freebsd.go
│ │ │ ├── types_netbsd.go
│ │ │ ├── types_openbsd.go
│ │ │ ├── types_solaris.go
│ │ │ ├── unveil_openbsd.go
│ │ │ ├── xattr_bsd.go
│ │ │ ├── zerrors_aix_ppc.go
│ │ │ ├── zerrors_aix_ppc64.go
│ │ │ ├── zerrors_darwin_amd64.go
│ │ │ ├── zerrors_darwin_arm64.go
│ │ │ ├── zerrors_dragonfly_amd64.go
│ │ │ ├── zerrors_freebsd_386.go
│ │ │ ├── zerrors_freebsd_amd64.go
│ │ │ ├── zerrors_freebsd_arm.go
│ │ │ ├── zerrors_freebsd_arm64.go
│ │ │ ├── zerrors_freebsd_riscv64.go
│ │ │ ├── zerrors_linux.go
│ │ │ ├── zerrors_linux_386.go
│ │ │ ├── zerrors_linux_amd64.go
│ │ │ ├── zerrors_linux_arm.go
│ │ │ ├── zerrors_linux_arm64.go
│ │ │ ├── zerrors_linux_loong64.go
│ │ │ ├── zerrors_linux_mips.go
│ │ │ ├── zerrors_linux_mips64.go
│ │ │ ├── zerrors_linux_mips64le.go
│ │ │ ├── zerrors_linux_mipsle.go
│ │ │ ├── zerrors_linux_ppc.go
│ │ │ ├── zerrors_linux_ppc64.go
│ │ │ ├── zerrors_linux_ppc64le.go
│ │ │ ├── zerrors_linux_riscv64.go
│ │ │ ├── zerrors_linux_s390x.go
│ │ │ ├── zerrors_linux_sparc64.go
│ │ │ ├── zerrors_netbsd_386.go
│ │ │ ├── zerrors_netbsd_amd64.go
│ │ │ ├── zerrors_netbsd_arm.go
│ │ │ ├── zerrors_netbsd_arm64.go
│ │ │ ├── zerrors_openbsd_386.go
│ │ │ ├── zerrors_openbsd_amd64.go
│ │ │ ├── zerrors_openbsd_arm.go
│ │ │ ├── zerrors_openbsd_arm64.go
│ │ │ ├── zerrors_openbsd_mips64.go
│ │ │ ├── zerrors_openbsd_ppc64.go
│ │ │ ├── zerrors_openbsd_riscv64.go
│ │ │ ├── zerrors_solaris_amd64.go
│ │ │ ├── zerrors_zos_s390x.go
│ │ │ ├── zptrace_armnn_linux.go
│ │ │ ├── zptrace_linux_arm64.go
│ │ │ ├── zptrace_mipsnn_linux.go
│ │ │ ├── zptrace_mipsnnle_linux.go
│ │ │ ├── zptrace_x86_linux.go
│ │ │ ├── zsyscall_aix_ppc.go
│ │ │ ├── zsyscall_aix_ppc64.go
│ │ │ ├── zsyscall_aix_ppc64_gc.go
│ │ │ ├── zsyscall_aix_ppc64_gccgo.go
│ │ │ ├── zsyscall_darwin_amd64.go
│ │ │ ├── zsyscall_darwin_amd64.s
│ │ │ ├── zsyscall_darwin_arm64.go
│ │ │ ├── zsyscall_darwin_arm64.s
│ │ │ ├── zsyscall_dragonfly_amd64.go
│ │ │ ├── zsyscall_freebsd_386.go
│ │ │ ├── zsyscall_freebsd_amd64.go
│ │ │ ├── zsyscall_freebsd_arm.go
│ │ │ ├── zsyscall_freebsd_arm64.go
│ │ │ ├── zsyscall_freebsd_riscv64.go
│ │ │ ├── zsyscall_illumos_amd64.go
│ │ │ ├── zsyscall_linux.go
│ │ │ ├── zsyscall_linux_386.go
│ │ │ ├── zsyscall_linux_amd64.go
│ │ │ ├── zsyscall_linux_arm.go
│ │ │ ├── zsyscall_linux_arm64.go
│ │ │ ├── zsyscall_linux_loong64.go
│ │ │ ├── zsyscall_linux_mips.go
│ │ │ ├── zsyscall_linux_mips64.go
│ │ │ ├── zsyscall_linux_mips64le.go
│ │ │ ├── zsyscall_linux_mipsle.go
│ │ │ ├── zsyscall_linux_ppc.go
│ │ │ ├── zsyscall_linux_ppc64.go
│ │ │ ├── zsyscall_linux_ppc64le.go
│ │ │ ├── zsyscall_linux_riscv64.go
│ │ │ ├── zsyscall_linux_s390x.go
│ │ │ ├── zsyscall_linux_sparc64.go
│ │ │ ├── zsyscall_netbsd_386.go
│ │ │ ├── zsyscall_netbsd_amd64.go
│ │ │ ├── zsyscall_netbsd_arm.go
│ │ │ ├── zsyscall_netbsd_arm64.go
│ │ │ ├── zsyscall_openbsd_386.go
│ │ │ ├── zsyscall_openbsd_386.s
│ │ │ ├── zsyscall_openbsd_amd64.go
│ │ │ ├── zsyscall_openbsd_amd64.s
│ │ │ ├── zsyscall_openbsd_arm.go
│ │ │ ├── zsyscall_openbsd_arm.s
│ │ │ ├── zsyscall_openbsd_arm64.go
│ │ │ ├── zsyscall_openbsd_arm64.s
│ │ │ ├── zsyscall_openbsd_mips64.go
│ │ │ ├── zsyscall_openbsd_mips64.s
│ │ │ ├── zsyscall_openbsd_ppc64.go
│ │ │ ├── zsyscall_openbsd_ppc64.s
│ │ │ ├── zsyscall_openbsd_riscv64.go
│ │ │ ├── zsyscall_openbsd_riscv64.s
│ │ │ ├── zsyscall_solaris_amd64.go
│ │ │ ├── zsyscall_zos_s390x.go
│ │ │ ├── zsysctl_openbsd_386.go
│ │ │ ├── zsysctl_openbsd_amd64.go
│ │ │ ├── zsysctl_openbsd_arm.go
│ │ │ ├── zsysctl_openbsd_arm64.go
│ │ │ ├── zsysctl_openbsd_mips64.go
│ │ │ ├── zsysctl_openbsd_ppc64.go
│ │ │ ├── zsysctl_openbsd_riscv64.go
│ │ │ ├── zsysnum_darwin_amd64.go
│ │ │ ├── zsysnum_darwin_arm64.go
│ │ │ ├── zsysnum_dragonfly_amd64.go
│ │ │ ├── zsysnum_freebsd_386.go
│ │ │ ├── zsysnum_freebsd_amd64.go
│ │ │ ├── zsysnum_freebsd_arm.go
│ │ │ ├── zsysnum_freebsd_arm64.go
│ │ │ ├── zsysnum_freebsd_riscv64.go
│ │ │ ├── zsysnum_linux_386.go
│ │ │ ├── zsysnum_linux_amd64.go
│ │ │ ├── zsysnum_linux_arm.go
│ │ │ ├── zsysnum_linux_arm64.go
│ │ │ ├── zsysnum_linux_loong64.go
│ │ │ ├── zsysnum_linux_mips.go
│ │ │ ├── zsysnum_linux_mips64.go
│ │ │ ├── zsysnum_linux_mips64le.go
│ │ │ ├── zsysnum_linux_mipsle.go
│ │ │ ├── zsysnum_linux_ppc.go
│ │ │ ├── zsysnum_linux_ppc64.go
│ │ │ ├── zsysnum_linux_ppc64le.go
│ │ │ ├── zsysnum_linux_riscv64.go
│ │ │ ├── zsysnum_linux_s390x.go
│ │ │ ├── zsysnum_linux_sparc64.go
│ │ │ ├── zsysnum_netbsd_386.go
│ │ │ ├── zsysnum_netbsd_amd64.go
│ │ │ ├── zsysnum_netbsd_arm.go
│ │ │ ├── zsysnum_netbsd_arm64.go
│ │ │ ├── zsysnum_openbsd_386.go
│ │ │ ├── zsysnum_openbsd_amd64.go
│ │ │ ├── zsysnum_openbsd_arm.go
│ │ │ ├── zsysnum_openbsd_arm64.go
│ │ │ ├── zsysnum_openbsd_mips64.go
│ │ │ ├── zsysnum_openbsd_ppc64.go
│ │ │ ├── zsysnum_openbsd_riscv64.go
│ │ │ ├── zsysnum_zos_s390x.go
│ │ │ ├── ztypes_aix_ppc.go
│ │ │ ├── ztypes_aix_ppc64.go
│ │ │ ├── ztypes_darwin_amd64.go
│ │ │ ├── ztypes_darwin_arm64.go
│ │ │ ├── ztypes_dragonfly_amd64.go
│ │ │ ├── ztypes_freebsd_386.go
│ │ │ ├── ztypes_freebsd_amd64.go
│ │ │ ├── ztypes_freebsd_arm.go
│ │ │ ├── ztypes_freebsd_arm64.go
│ │ │ ├── ztypes_freebsd_riscv64.go
│ │ │ ├── ztypes_linux.go
│ │ │ ├── ztypes_linux_386.go
│ │ │ ├── ztypes_linux_amd64.go
│ │ │ ├── ztypes_linux_arm.go
│ │ │ ├── ztypes_linux_arm64.go
│ │ │ ├── ztypes_linux_loong64.go
│ │ │ ├── ztypes_linux_mips.go
│ │ │ ├── ztypes_linux_mips64.go
│ │ │ ├── ztypes_linux_mips64le.go
│ │ │ ├── ztypes_linux_mipsle.go
│ │ │ ├── ztypes_linux_ppc.go
│ │ │ ├── ztypes_linux_ppc64.go
│ │ │ ├── ztypes_linux_ppc64le.go
│ │ │ ├── ztypes_linux_riscv64.go
│ │ │ ├── ztypes_linux_s390x.go
│ │ │ ├── ztypes_linux_sparc64.go
│ │ │ ├── ztypes_netbsd_386.go
│ │ │ ├── ztypes_netbsd_amd64.go
│ │ │ ├── ztypes_netbsd_arm.go
│ │ │ ├── ztypes_netbsd_arm64.go
│ │ │ ├── ztypes_openbsd_386.go
│ │ │ ├── ztypes_openbsd_amd64.go
│ │ │ ├── ztypes_openbsd_arm.go
│ │ │ ├── ztypes_openbsd_arm64.go
│ │ │ ├── ztypes_openbsd_mips64.go
│ │ │ ├── ztypes_openbsd_ppc64.go
│ │ │ ├── ztypes_openbsd_riscv64.go
│ │ │ ├── ztypes_solaris_amd64.go
│ │ │ └── ztypes_zos_s390x.go
│ │ ├── websocket/
│ │ │ ├── client.go
│ │ │ ├── dial.go
│ │ │ ├── hybi.go
│ │ │ ├── server.go
│ │ │ └── websocket.go
│ │ └── windows/
│ │ ├── aliases.go
│ │ ├── dll_windows.go
│ │ ├── empty.s
│ │ ├── env_windows.go
│ │ ├── eventlog.go
│ │ ├── exec_windows.go
│ │ ├── memory_windows.go
│ │ ├── mkerrors.bash
│ │ ├── mkknownfolderids.bash
│ │ ├── mksyscall.go
│ │ ├── mkwinsyscall/
│ │ │ └── mkwinsyscall.go
│ │ ├── race.go
│ │ ├── race0.go
│ │ ├── registry/
│ │ │ ├── key.go
│ │ │ ├── mksyscall.go
│ │ │ ├── syscall.go
│ │ │ ├── value.go
│ │ │ └── zsyscall_windows.go
│ │ ├── security_windows.go
│ │ ├── service.go
│ │ ├── setupapi_windows.go
│ │ ├── str.go
│ │ ├── svc/
│ │ │ ├── debug/
│ │ │ │ ├── log.go
│ │ │ │ └── service.go
│ │ │ ├── eventlog/
│ │ │ │ ├── install.go
│ │ │ │ └── log.go
│ │ │ ├── mgr/
│ │ │ │ ├── config.go
│ │ │ │ ├── mgr.go
│ │ │ │ ├── recovery.go
│ │ │ │ └── service.go
│ │ │ ├── security.go
│ │ │ └── service.go
│ │ ├── syscall.go
│ │ ├── syscall_windows.go
│ │ ├── types_windows.go
│ │ ├── types_windows_386.go
│ │ ├── types_windows_amd64.go
│ │ ├── types_windows_arm.go
│ │ ├── types_windows_arm64.go
│ │ ├── zerrors_windows.go
│ │ ├── zknownfolderids_windows.go
│ │ └── zsyscall_windows.go
│ ├── term/
│ │ ├── term.go
│ │ ├── term_plan9.go
│ │ ├── term_unix.go
│ │ ├── term_unix_bsd.go
│ │ ├── term_unix_other.go
│ │ ├── term_unsupported.go
│ │ ├── term_windows.go
│ │ └── terminal.go
│ ├── toml/
│ │ ├── README.md
│ │ ├── decode.go
│ │ ├── errors.go
│ │ ├── internal/
│ │ │ ├── characters/
│ │ │ │ ├── ascii.go
│ │ │ │ └── utf8.go
│ │ │ ├── cli/
│ │ │ │ └── cli.go
│ │ │ ├── danger/
│ │ │ │ ├── danger.go
│ │ │ │ └── typeid.go
│ │ │ ├── testsuite/
│ │ │ │ ├── add.go
│ │ │ │ ├── json.go
│ │ │ │ ├── parser.go
│ │ │ │ ├── rm.go
│ │ │ │ └── testsuite.go
│ │ │ └── tracker/
│ │ │ ├── key.go
│ │ │ ├── seen.go
│ │ │ └── tracker.go
│ │ ├── localtime.go
│ │ ├── marshaler.go
│ │ ├── strict.go
│ │ ├── types.go
│ │ ├── unmarshaler.go
│ │ └── unstable/
│ │ ├── ast.go
│ │ ├── builder.go
│ │ ├── doc.go
│ │ ├── kind.go
│ │ ├── parser.go
│ │ ├── scanner.go
│ │ └── unmarshaler.go
│ ├── webroot/
│ │ ├── assets/
│ │ │ ├── About-BpAaEtx1.js
│ │ │ ├── About-CLFeE-IU.css
│ │ │ ├── Manage-BTEK3ejV.css
│ │ │ ├── Manage-D_HsZWaT.js
│ │ │ ├── NotFound-Ts0Rt7lc.js
│ │ │ ├── SysInit-Lho0-Z3X.css
│ │ │ ├── SysInit-UcmoR0-x.js
│ │ │ ├── el-card-fwQOLwdi.css
│ │ │ ├── index-D6GWZ696.css
│ │ │ └── index-jH42E5tC.js
│ │ ├── index.html
│ │ └── readme.md
│ └── websocket/
│ ├── client.go
│ ├── compression.go
│ ├── conn.go
│ ├── doc.go
│ ├── join.go
│ ├── json.go
│ ├── mask.go
│ ├── mask_safe.go
│ ├── prepared.go
│ ├── proxy.go
│ ├── server.go
│ └── util.go
├── postman.json
├── sshd.dockerfile
└── webssh/
├── .gitignore
├── README.md
├── auto-imports.d.ts
├── components.d.ts
├── env.d.ts
├── index.html
├── package.json
├── src/
│ ├── App.vue
│ ├── components/
│ │ ├── Account.vue
│ │ ├── Connect.vue
│ │ ├── LoginAudit.vue
│ │ ├── NetFilter.vue
│ │ ├── SshdCert.vue
│ │ └── SshdUser.vue
│ ├── main.ts
│ ├── router/
│ │ └── index.ts
│ ├── stores/
│ │ └── store.ts
│ └── views/
│ ├── About.vue
│ ├── Empty.vue
│ ├── Home.vue
│ ├── Login.vue
│ ├── Manage.vue
│ ├── NotFound.vue
│ └── SysInit.vue
├── tsconfig.app.json
├── tsconfig.json
├── tsconfig.node.json
└── vite.config.ts
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/build-and-release.yml
================================================
# Release pipeline: on every pushed tag matching v*, cross-compile the gossh
# module for each OS/arch pair in the matrix, zip the binary, and attach the
# archive to the GitHub release named after the tag.
name: Build & Release
on:
  push:
    tags:
      - 'v*'
jobs:
  build-release:
    runs-on: ubuntu-latest
    name: Build and Release for Multiple OS & Arch
    permissions:
      # Required so the release step can create the release and upload assets.
      contents: write
    strategy:
      matrix:
        include:
          # Windows targets (64-bit only)
          - os: windows
            goos: windows
            arch: amd64
            suffix: .exe
          - os: windows
            goos: windows
            arch: arm64
            suffix: .exe
          # Linux targets (64-bit architectures)
          - os: linux
            goos: linux
            arch: amd64
            suffix: ''
          - os: linux
            goos: linux
            arch: arm64
            suffix: ''
          # macOS targets (64-bit only)
          - os: macos
            goos: darwin
            arch: amd64
            suffix: ''
          - os: macos
            goos: darwin
            arch: arm64
            suffix: ''
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.24.5'
      - name: Prepare workspace
        run: |
          # Check that the module directory exists BEFORE listing it: the step
          # runs under `bash -e`, so a failing `ls -la gossh` would abort the
          # script before this explicit error message could be printed.
          if [ ! -d "gossh" ]; then
            echo "❌ Error: gossh directory not found!"
            exit 1
          fi
          echo "📂 Workspace structure:"
          ls -la
          echo "📂 gossh directory content:"
          ls -la gossh
          # Enter the gossh directory to prepare the build
          cd gossh
          # Refresh dependencies (go.sum is git-ignored, so it is regenerated here)
          go mod tidy
          go mod download
      - name: Build for ${{ matrix.os }} (${{ matrix.arch }})
        run: |
          cd gossh
          # Build with CGO disabled for the target OS/arch
          CGO_ENABLED=0 GOOS=${{ matrix.goos }} GOARCH=${{ matrix.arch }} \
          go build -trimpath -ldflags="-s -w" -o GoWebSSH${{ matrix.suffix }}
          # Verify the build output
          if [ -f "GoWebSSH${{ matrix.suffix }}" ]; then
            echo "✅ Build successful! Output file: GoWebSSH${{ matrix.suffix }}"
          else
            echo "❌ Build failed! No output file found in: $(pwd)"
            ls -la
            exit 1
          fi
          # Create the archive in the repository root (-j drops directory paths)
          zip -j ../GoWebSSH-${{ matrix.os }}-${{ matrix.arch }}.zip GoWebSSH${{ matrix.suffix }}
      - name: Create Release and Upload Assets
        uses: softprops/action-gh-release@v2
        with:
          name: ${{ github.ref_name }}
          tag_name: ${{ github.ref_name }}
          files: |
            GoWebSSH-*.zip
================================================
FILE: .gitignore
================================================
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# TypeScript v1 declaration files
typings/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
# parcel-bundler cache (https://parceljs.org/)
.cache
# Next.js build output
.next
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and *not* Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
######################
# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
######################
# ide #
######################
.idea
*.iml
/gossh/go.sum
/gossh/gossh
/gossh/GoWebSSH.db
================================================
FILE: Dockerfile
================================================
# Stage 1: build — compile the gossh module with CGO disabled into /mnt/gowebssh
FROM golang:1.23.4 AS builder
WORKDIR /mnt
COPY gossh ./gossh
RUN cd /mnt/gossh && CGO_ENABLED=0 go build -o /mnt/gowebssh
# Stage 2: runtime — image that only receives the compiled binary from the builder stage
FROM alpine:3.21.0
WORKDIR /root/
# Declared as a volume; the same path is passed to the binary via -WorkDir below
VOLUME /var/lib/webssh
EXPOSE 8899
COPY --from=builder /mnt/gowebssh /bin/gowebssh
CMD ["/bin/gowebssh", "-WorkDir", "/var/lib/webssh"]
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 o8oo8o
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
## GoWebSSH
### 项目介绍:
* **Web版ssh客户端 + (sshd,sftp)服务端实现**
### 概要:
* Golang 1.23 + (Vue3.5 + Vite6) 实现一个Web版单文件的(SSH+SSHD)
* 借助于Golang embed,打包以后只有一个文件,简单高效
* 使用及编译过程,超级简单,绝对保姆级
* 上一版主要本地运行,但是通过部分用户反馈,此项目定位改为服务器运行,所以此版本加入了很多企业场景中的功能
### 联系我:
* **QQ:774309635**
---
### Quick start(大象装进冰箱只需3步):
> 必须使用golang 1.21以上版本
* git clone https://github.com/o8oo8o/WebSSH.git
* cd WebSSH/gossh
* go build
* ./gossh
> 打开链接 http://127.0.0.1:8899/ 开始享用吧,第一次需要初始化
### Docker 方式:
* git clone https://github.com/o8oo8o/WebSSH.git
* cd WebSSH
* docker build -f Dockerfile -t gowebssh:v2 .
* docker run -d --name webssh -p 8899:8899 -v gowebssh:/var/lib/webssh gowebssh:v2
### 打赏我:
* **每一个开源项目的背后,都有一群默默付出、充满激情的开发者。他们用自己的业余时间,不断地优化代码、修复bug、撰写文档,只为让项目变得更好。如果您觉得我的项目对您有所帮助,如果您认可我的努力和付出,那么请考虑给予我一点小小的打赏,够买一瓶啤酒就行🍺,如果能同时打赏啤酒花生那更好🍺🥜,因为所有的代码都是喝完酒撸的。放上收款码的时候我是羞愧的,一个中年男人的最后的尊严和节操竟然没了😂,友情提示:打赏不退,怕被媳妇查到大额支出🥸,如果需要技术支持,需要收费哦**

### 运行环境依赖:
* 需要MySQL8+或PostgreSQL12.2+,或者直接使用内置SQLite数据库
### SSHD服务器功能:
* 可以配置只监听本地端口
* 支持Web配置sshd服务器账号密码及公钥
* 支持密码认证,公钥认证,增强登录过程的安全性
* 通过SFTP或SCP用户可以安全地在本地和远程服务器之间传输文件
### Web客户端主要功能:
* 支持同时连接多个主机,支持重连、清屏功能
* 支持IPv4、IPv6
* 支持SSH证书登录及证书密码
* 支持批量执行命令,可发送到当前终端或所有终端
* 支持命令收藏,方便重复执行命令,批量发送命令到所有会话
* 可以保存主机连接信息
* 支持直接通过Web上传下载文件
* 支持直接通过Web创建目录,删除文件及目录功能
* 支持手动输入路径
* 支持自定义终端字体大小、字体颜色、字体样式
* 支持自定义背景、光标颜色及光标样式
* 已保存的主机信息可直接编辑并连接
* 支持后台管理,强制断开连接
* 支持登录日志审计,方便监控违规操作
* 支持访问控制,在公网场景中有效拦截非法访问
### 为什么这么简单:
* 为了方便您使用,把golang编译的依赖已经整理好了,clone就一起下载了
* 前端已经编译完成,并把编译完成的静态资源拷贝到gossh/webroot目录中
* 可执行文件内嵌静态资源,方便你随心所欲地移动可执行文件
* 因内置sshd服务器,在受限的网络环境依然能通过web访问
### 配置文件:
* 第一次运行会在用户home目录创建一个 .GoWebSSH 目录
* GoWebSSH.toml 可以配置server端口等信息
* cert.pem HTTPS服务器证书文件
* key.key HTTPS服务器私钥文件
### 注意:
* 当程序检测到cert.pem 和 key.key 文件,会使用https协议,否则使用http协议
* 用户只需把证书文件和私钥文件放到 .GoWebSSH 目录就可以了
```shell
openssl genpkey -algorithm RSA -out key.key -pkeyopt rsa_keygen_bits:2048
openssl req -new -x509 -key key.key -out cert.pem -days 365 -subj "/C=CN/ST=bj/L=bj/O=gowebssh/OU=gowebssh/CN=gowebssh.com"
```
### Systemd 方式启动:
```shell
cat > /etc/systemd/system/gowebssh.service << "END"
##################################
[Unit]
Description=GoWebSSH Daemon
After=network.target
Wants=network-online.target
[Service]
Type=simple
User=root
Environment=TERM=xterm
Environment=XDG_SESSION_TYPE=tty
Environment=HOME=/root
PrivateTmp=true
LimitNOFILE=65535
# 执行程序路径
ExecStart=/usr/local/gossh
# auto restart
StartLimitIntervalSec=0
Restart=always
RestartSec=1
[Install]
WantedBy=multi-user.target
##################################
END
systemctl daemon-reload
systemctl start gowebssh.service
systemctl enable gowebssh.service
```
---
### 演示截图:






================================================
FILE: gossh/Makefile
================================================
# Cross-compilation targets for the gossh binary. Every platform target builds
# with CGO disabled and writes WebSSH-<os>-<arch>; "all" builds all of them and
# "clean" removes those artifacts.
.PHONY: clean all
all: linux-amd64 linux-arm64 macos-amd64 macos-arm64 windows-amd64
linux-amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o WebSSH-linux-amd64
linux-arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o WebSSH-linux-arm64
macos-amd64:
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o WebSSH-macos-amd64
macos-arm64:
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o WebSSH-macos-arm64
windows-amd64:
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o WebSSH-windows-amd64.exe
# Native build for the host platform.
WebSSH:
go build -o WebSSH
clean:
rm -f WebSSH-linux-amd64 WebSSH-linux-arm64 WebSSH-macos-amd64 WebSSH-macos-arm64 WebSSH-windows-amd64.exe
================================================
FILE: gossh/app/config/config.go
================================================
package config
import (
"flag"
"gossh/app/utils"
"gossh/toml"
"log/slog"
"os"
"path"
"time"
)
// AppConfig is the application configuration that is serialized to and loaded
// from the TOML config file. DbType/DbDsn select and address the database,
// IsInit records whether the system has been initialized, JwtExpire is the
// lifetime given to issued JWTs, and AesSecret is the key used to encrypt
// stored SSH credentials.
type AppConfig struct {
WebBaseDir string `json:"web_base_dir" toml:"web_base_dir"`
AppName string `json:"app_name" toml:"app_name"`
DbType string `json:"db_type" toml:"db_type"`
DbDsn string `json:"db_dsn" toml:"db_dsn"`
IsInit bool `json:"is_init" toml:"is_init"`
JwtSecret string `json:"jwt_secret" toml:"jwt_secret"`
AesSecret string `json:"aes_secret" toml:"aes_secret"`
JwtExpire time.Duration `json:"jwt_expire" toml:"jwt_expire"`
StatusRefresh time.Duration `json:"status_refresh" toml:"status_refresh"`
ClientCheck time.Duration `json:"client_check" toml:"client_check"`
SessionSecret string `json:"session_secret" toml:"session_secret"`
Address string `json:"address" toml:"address"`
Port string `json:"port" toml:"port"`
CertFile string `json:"cert_file" toml:"cert_file"`
KeyFile string `json:"key_file" toml:"key_file"`
}
// write serializes c as TOML and stores it at confFileFullPath.
//
// The file contains the JWT/AES/session secrets and the database DSN, so it is
// created readable and writable by the owner only.
// It returns the marshal or write error, which is also logged.
func (c *AppConfig) write() error {
	data, err := toml.Marshal(c)
	if err != nil {
		slog.Error("序列化TOML配置文件错误:", "err_msg", err)
		return err
	}
	if err = os.WriteFile(confFileFullPath, data, os.FileMode(0600)); err != nil {
		slog.Error("写入默认配置文件错误:", "err_msg", err.Error())
		return err
	}
	return nil
}
// DefaultConfig is the process-wide configuration. It starts with the defaults
// below (the three secrets are freshly generated random strings) and is
// overwritten by InitConfig / RewriteConfig.
// NOTE(review): AppName reads "GoWebSHH" while the project name used elsewhere
// is "GoWebSSH" — confirm whether this spelling is intended.
var DefaultConfig = AppConfig{
WebBaseDir: "",
AppName: "GoWebSHH",
DbType: "mysql",
DbDsn: "",
IsInit: false,
JwtSecret: utils.RandString(64),
AesSecret: utils.RandString(32),
SessionSecret: utils.RandString(64),
JwtExpire: time.Minute * 120,
StatusRefresh: time.Second * 3,
ClientCheck: time.Second * 15,
Address: "",
Port: "8899",
CertFile: path.Join(WorkDir, "cert.pem"),
KeyFile: path.Join(WorkDir, "key.key"),
}
// UserHomeDir is the current user's home directory. The error returned by
// os.UserHomeDir is discarded, so this is empty when the lookup fails.
var UserHomeDir, _ = os.UserHomeDir()
var projectName = "GoWebSSH"
var confFileName = "GoWebSSH.toml"
// confPath is the "/.GoWebSSH/" path fragment built with the OS path separator.
var confPath = string(os.PathSeparator) + "." + projectName + string(os.PathSeparator)
// WorkDir is the program's default working directory: the .GoWebSSH directory
// under the user's home directory.
var WorkDir = path.Join(UserHomeDir, confPath)
// confFileFullPath is the full path of the TOML config file inside WorkDir.
var confFileFullPath = path.Join(WorkDir, confFileName)
// InitConfig parses the command-line flags, makes sure the work directory and
// the TOML config file exist, and loads the file into DefaultConfig.
//
// Flags:
//   - ConfigDir: when set, WorkDir (and the config, cert and key paths) are
//     placed under <ConfigDir>/.GoWebSSH instead of the user's home directory.
//   - WebBaseDir: value stored in DefaultConfig.WebBaseDir; when it differs
//     from the value in the file, the file is rewritten.
//
// Every failure is logged and the function returns, leaving DefaultConfig with
// whatever was populated so far. A panic is logged and ends the process with
// exit code 255.
func InitConfig() {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("init config failed", "err_msg", err)
			os.Exit(255)
		}
	}()

	var configDir string
	var webBaseDir string
	flag.StringVar(&configDir, "ConfigDir", "", "ConfigDir")
	flag.StringVar(&webBaseDir, "WebBaseDir", "", "WebBaseDir")
	flag.Parse()

	// The default config written on first start must carry the web base dir,
	// not the config directory.
	DefaultConfig.WebBaseDir = webBaseDir
	if configDir != "" {
		WorkDir = path.Join(configDir, confPath)
		confFileFullPath = path.Join(WorkDir, confFileName)
		DefaultConfig.CertFile = path.Join(WorkDir, "cert.pem")
		DefaultConfig.KeyFile = path.Join(WorkDir, "key.key")
	}
	slog.Info("use-config-file", "path", confFileFullPath)

	info, err := os.Stat(WorkDir)
	if os.IsNotExist(err) {
		if err = os.MkdirAll(WorkDir, os.FileMode(0755)); err != nil {
			slog.Error("创建工作目录失败", "err_msg", err.Error())
			return
		}
		info, err = os.Stat(WorkDir)
	}
	if err != nil {
		// info is nil here; report the error instead of dereferencing it.
		slog.Error("stat work dir failed", "dir", WorkDir, "err_msg", err.Error())
		return
	}
	if !info.IsDir() {
		slog.Error("有一个和工作目录同名的文件", "dir", WorkDir)
		return
	}

	// First start: persist the defaults so the file can be read back below.
	if _, err = os.Stat(confFileFullPath); os.IsNotExist(err) {
		if err := DefaultConfig.write(); err != nil {
			return
		}
	}

	data, err := os.ReadFile(confFileFullPath)
	if err != nil {
		slog.Error("读取配置文件错误:", "err_msg", err.Error())
		return
	}
	if err = toml.Unmarshal(data, &DefaultConfig); err != nil {
		slog.Error("TOML解析配置文件错误:", "err_msg", err.Error())
		return
	}

	// The command-line value wins over the stored web base dir.
	if DefaultConfig.WebBaseDir != webBaseDir {
		DefaultConfig.WebBaseDir = webBaseDir
		if err := DefaultConfig.write(); err != nil {
			return
		}
	}
	slog.Debug("DefaultConfig:", "data", DefaultConfig)
}
// RewriteConfig replaces the global DefaultConfig with conf and persists it to
// the config file.
//
// The old file is removed first so the new one is created with owner-only
// permissions (it stores secrets and the database DSN); a missing old file is
// not treated as an error. It returns the marshal, remove or write error.
func RewriteConfig(conf AppConfig) error {
	data, err := toml.Marshal(conf)
	if err != nil {
		slog.Error("序列化TOML配置文件错误:", "err_msg", err)
		return err
	}
	// Overwrite the global configuration.
	DefaultConfig = conf
	if err = os.Remove(confFileFullPath); err != nil && !os.IsNotExist(err) {
		slog.Error("删除旧配置文件错误:", "err_msg", err.Error())
		return err
	}
	if err = os.WriteFile(confFileFullPath, data, os.FileMode(0600)); err != nil {
		slog.Error("写入默认配置文件错误:", "err_msg", err.Error())
		return err
	}
	return nil
}
================================================
FILE: gossh/app/middleware/db_check.go
================================================
package middleware
import (
"gossh/app/config"
"gossh/app/model"
"gossh/gin"
)
// DbCheck returns a middleware that verifies the database connection before a
// request is handled.
//
// While the system is not initialized the check is skipped. Otherwise, when
// there is no connection yet or the probe query fails, it tries to reconnect
// through model.DbMigrate and answers 500 if that fails. The remaining
// handlers are run exactly once.
func DbCheck() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !config.DefaultConfig.IsInit {
			c.Next()
			return
		}
		if model.Db == nil || model.Db.Exec("select 1=1").Error != nil {
			err := model.DbMigrate(config.DefaultConfig.DbType, config.DefaultConfig.DbDsn)
			if err != nil {
				c.Abort()
				c.JSON(500, gin.H{"code": 500, "msg": "数据库连接错误:" + err.Error()})
				return
			}
		}
		c.Next()
	}
}
================================================
FILE: gossh/app/middleware/jwt_auth.go
================================================
package middleware
import (
"errors"
"gossh/app/config"
"gossh/gin"
"gossh/gin/jwt"
"strings"
"time"
)
// JwtSecret is the key used to sign and verify JWTs.
// NOTE(review): this initializer runs at package initialization, so it copies
// DefaultConfig's initial JwtSecret (a random string) — confirm that the
// jwt_secret loaded later by config.InitConfig is applied somewhere, otherwise
// tokens do not survive a restart and the configured secret is ignored.
var JwtSecret = []byte(config.DefaultConfig.JwtSecret)
// JwtClaims is the claim set carried by the tokens issued by GenerateToken.
type JwtClaims struct {
// Id is the user ID.
Id uint
// RegisteredClaims holds the standard registered JWT claims.
jwt.RegisteredClaims
}
// GenerateToken 登录成功后调用,传入SshUser结构体
func GenerateToken(id uint) (string, error) {
// 两个小时有效期
expirationTime := time.Now().Add(config.DefaultConfig.JwtExpire)
claims := &JwtClaims{
Id: id,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
Issuer: "go_web_ssh",
},
}
// 生成Token,指定签名算法和claims
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// 签名
tokenString, err := token.SignedString(JwtSecret)
if err != nil {
return "", err
}
return tokenString, nil
}
// ParseToken parses and validates tokenString against JwtSecret.
// The claims pointer is always returned: it is populated when the token could
// be decoded (e.g. it is merely expired) and left empty when it could not.
func ParseToken(tokenString string) (*JwtClaims, error) {
	keyFunc := func(_ *jwt.Token) (any, error) {
		return JwtSecret, nil
	}
	parsed := &JwtClaims{}
	_, err := jwt.ParseWithClaims(tokenString, parsed, keyFunc)
	return parsed, err
}
// RenewToken issues a fresh token for claims.Id when the token expired less
// than 10 minutes (600 seconds) ago.
// It returns an error when claims carry no expiry or the grace period has
// passed.
func RenewToken(claims *JwtClaims) (string, error) {
	// Without an expiry there is nothing to measure the grace period against;
	// refuse instead of dereferencing a nil pointer.
	if claims == nil || claims.ExpiresAt == nil {
		return "", errors.New("登录已过期")
	}
	if withinLimit(claims.ExpiresAt.Time.Unix(), 600) {
		return GenerateToken(claims.Id)
	}
	return "", errors.New("登录已过期")
}
// withinLimit reports whether fewer than l seconds have elapsed since the Unix
// timestamp s.
func withinLimit(s, l int64) bool {
	elapsed := time.Now().Unix() - s
	return elapsed < l
}
// JWTAuth returns a middleware that authenticates requests by JWT.
//
// The token is read from the Authorization header, falling back to the
// Authorization query parameter (used by WebSocket/SSE requests). A valid
// token stores the user ID under the "uid" context key. An expired token that
// RenewToken accepts is replaced: the new token is sent in the NewToken
// response header and the request continues. Everything else is answered
// with 401.
func JWTAuth() gin.HandlerFunc {
	deny := func(c *gin.Context, msg string) {
		c.Abort()
		c.JSON(401, gin.H{"code": 401, "msg": msg})
	}
	return func(c *gin.Context) {
		token := c.Request.Header.Get("Authorization")
		if len(token) == 0 {
			token = c.Query("Authorization")
		}
		if len(token) == 0 {
			deny(c, "请添加Authorization请求头")
			return
		}

		claims, err := ParseToken(token)
		if err == nil {
			c.Set("uid", claims.Id)
			c.Next()
			return
		}

		if strings.Contains(err.Error(), "token is expired") {
			// An empty result means renewal was refused.
			if renewed, _ := RenewToken(claims); renewed != "" {
				c.Header("NewToken", renewed)
				c.Set("uid", claims.Id)
				c.Next()
				return
			}
		}
		deny(c, "未登录")
	}
}
================================================
FILE: gossh/app/middleware/net_filter.go
================================================
package middleware
import (
"fmt"
"gossh/app/config"
"gossh/app/model"
"gossh/gin"
"log/slog"
"net"
)
// check reports whether ip is allowed by the network policy stored in the
// PolicyConf row with ID 1.
//
// With policy "Y" (allow-list) the address must match one of the active "Y"
// filters; with policy "N" (deny-list) it must match none of the active "N"
// filters. The check fails closed: an unparseable (nil) address, a database
// error, an invalid CIDR entry or an unknown policy value all deny access.
func check(ip net.IP) bool {
	if ip == nil {
		return false
	}
	var policyConf model.PolicyConf
	conf, err := policyConf.FindByID(1)
	if err != nil {
		slog.Error("get policyConf:", "err_msg", err.Error())
		return false
	}
	switch conf.NetPolicy {
	case "Y":
		matched, ok := matchActiveFilter(ip, "Y")
		return ok && matched
	case "N":
		matched, ok := matchActiveFilter(ip, "N")
		return ok && !matched
	}
	return false
}

// matchActiveFilter reports whether ip falls inside any filter returned by
// FindAllPolicy(policy). ok is false when the filters could not be loaded or
// one of them holds an invalid CIDR.
func matchActiveFilter(ip net.IP, policy string) (matched, ok bool) {
	var filter model.NetFilter
	list, err := filter.FindAllPolicy(policy)
	if err != nil {
		slog.Error("FindAllPolicy:", "err_msg", err.Error())
		return false, false
	}
	for _, item := range list {
		_, ipNet, err := net.ParseCIDR(item.Cidr)
		if err != nil {
			slog.Error("net.ParseCIDR:", "err_msg", err.Error())
			return false, false
		}
		if ipNet.Contains(ip) {
			return true, true
		}
	}
	return false, true
}
// NetFilter returns a middleware that rejects requests with 403 when the
// client's remote IP does not pass check. Filtering is inactive until the
// system has been initialized.
func NetFilter() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !config.DefaultConfig.IsInit {
			c.Next()
			return
		}
		remote := net.ParseIP(c.RemoteIP())
		if check(remote) {
			c.Next()
			return
		}
		c.JSON(403, gin.H{"err_msg": fmt.Sprintf("Deny %s", remote.String())})
		c.Abort()
	}
}
================================================
FILE: gossh/app/middleware/perm_check.go
================================================
package middleware
import (
"fmt"
"gossh/app/config"
"gossh/app/model"
"gossh/gin"
"slices"
"strings"
)
// checkRoute lists the "METHOD:route" keys that PremCheck restricts to
// administrators. The commented-out entries are routes that are not
// restricted.
var checkRoute = []string{
// "GET:/app/*filepath",
// "GET:/web_base_dir",
// "HEAD:/app/*filepath",
// "GET:/api/ssh/conn",
// "PATCH:/api/ssh/conn",
// "POST:/api/ssh/exec",
// "POST:/api/ssh/disconnect",
// "POST:/api/ssh/create_session",
// "GET:/api/conn_conf",
// "GET:/api/conn_conf/:id",
// "POST:/api/conn_conf",
// "PUT:/api/conn_conf",
// "DELETE:/api/conn_conf/:id",
// "GET:/api/cmd_note",
// "GET:/api/cmd_note/:id",
// "POST:/api/cmd_note",
// "PUT:/api/cmd_note",
// "DELETE:/api/cmd_note/:id",
// "GET:/api/sftp/download",
// "POST:/api/sftp/create_dir",
// "POST:/api/sftp/list",
// "DELETE:/api/sftp/delete",
// "PUT:/api/sftp/upload",
// "PUT:/api/conn_manage/refresh_conn_time",
// "POST:/api/login",
// "PATCH:/api/user/pwd",
"GET:/api/sshd_cert",
"GET:/api/sshd_cert_text",
"GET:/api/sshd_cert/:id",
"GET:/api/sshd_user",
"GET:/api/sshd_user/:id",
"GET:/api/sys/is_init",
"GET:/api/sys/config",
"GET:/api/conn_manage/online_client",
"GET:/api/policy_conf",
"GET:/api/policy_conf/:id",
"GET:/api/net_filter",
"GET:/api/net_filter/:id",
"GET:/api/user",
"GET:/api/user/:id",
"POST:/api/sshd_user",
"POST:/api/sshd_cert",
"POST:/api/sys/db_conn_check",
"POST:/api/sys/init",
"POST:/api/sys/config",
"POST:/api/login_audit",
"POST:/api/policy_conf",
"POST:/api/net_filter",
"POST:/api/user",
"PUT:/api/sshd_user",
"PUT:/api/sshd_cert",
"PUT:/api/policy_conf",
"PUT:/api/net_filter",
"PUT:/api/user",
"DELETE:/api/sshd_user/:id",
"DELETE:/api/sshd_cert/:id",
"DELETE:/api/policy_conf/:id",
"DELETE:/api/net_filter/:id",
"DELETE:/api/user/:id",
"PATCH:/api/sshd_user/check_name_exists",
"PATCH:/api/sshd_cert/check_name_exists",
"PATCH:/api/user/check_name_exists",
}
// PremCheck returns a middleware that restricts the routes listed in
// checkRoute to administrators.
//
// The matched route pattern (with the configured web base dir prefix removed)
// is combined with the HTTP method into a "METHOD:route" key. For a listed
// key the user identified by the "uid" context value is loaded; a lookup
// error or IsAdmin == "N" aborts the request with a code-403 JSON body.
// The engine parameter is not used.
func PremCheck(engine *gin.Engine) gin.HandlerFunc {
	return func(c *gin.Context) {
		route := c.FullPath()
		if prefix := strings.Trim(config.DefaultConfig.WebBaseDir, " "); prefix != "" {
			// TrimPrefix is a no-op when the route does not start with prefix.
			route = strings.TrimPrefix(route, prefix)
		}
		key := fmt.Sprintf("%s:%s", c.Request.Method, route)
		if !slices.Contains(checkRoute, key) {
			c.Next()
			return
		}

		var users model.WebUser
		user, err := users.FindByID(c.GetUint("uid"))
		if err != nil || user.IsAdmin == "N" {
			c.JSON(200, gin.H{"code": 403, "msg": "权限检查失败,非管理员拒绝操作"})
			c.Abort()
			return
		}
		c.Next()
	}
}
================================================
FILE: gossh/app/middleware/sys_init.go
================================================
package middleware
import (
"gossh/app/config"
"gossh/gin"
)
// SysInit returns a middleware that answers 401 until the system has been
// initialized (config.DefaultConfig.IsInit) and lets requests through
// afterwards.
func SysInit() gin.HandlerFunc {
	return func(c *gin.Context) {
		if config.DefaultConfig.IsInit {
			c.Next()
			return
		}
		// Initialization is still required.
		c.Abort()
		c.JSON(401, gin.H{"code": 401, "msg": "请对系统进行初始化"})
	}
}
================================================
FILE: gossh/app/model/cmd_note.go
================================================
package model
// CmdNote is a saved command owned by the user identified by Uid.
// NOTE(review): gorm separates tag options with ';' — "primaryKey,autoIncrement"
// is probably not parsed as two options; confirm the intended tag.
type CmdNote struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
Uid uint `gorm:"column:uid;not null;default:0" form:"uid" json:"uid"`
CmdName string `gorm:"column:cmd_name;type:text" form:"cmd_name" binding:"required" json:"cmd_name"`
CmdData string `gorm:"column:cmd_data;type:text" form:"cmd_data" binding:"required" json:"cmd_data"`
CreatedAt DateTime `gorm:"created_at" json:"-"`
UpdatedAt DateTime `gorm:"updated_at" json:"-"`
}
// Create inserts cmd as a new row and returns the database error, if any.
func (c CmdNote) Create(cmd *CmdNote) error {
	result := Db.Create(cmd)
	return result.Error
}
// FindByName returns the first command note whose cmd_name equals name.
// The error is the one reported by the query (including "record not found").
func (c CmdNote) FindByName(name string) (CmdNote, error) {
	var cmd CmdNote
	// The name is stored in the cmd_name column; the table has no "name" column.
	err := Db.First(&cmd, "cmd_name = ?", name).Error
	return cmd, err
}
// FindByID returns the command note with the given id that belongs to user uid.
func (c CmdNote) FindByID(id uint, uid uint) (CmdNote, error) {
	var note CmdNote
	result := Db.First(&note, "id = ? AND uid = ?", id, uid)
	return note, result.Error
}
// FindAll returns one page of the command notes owned by uid, most recently
// updated first.
func (c CmdNote) FindAll(offset, limit int, uid uint) ([]CmdNote, error) {
	var notes []CmdNote
	page := Db.Where("uid = ?", uid).Offset(offset).Limit(limit)
	result := page.Order("updated_at desc").Find(&notes)
	return notes, result.Error
}
// UpdateById applies cmd to the note with the given id owned by uid and
// returns the database error, if any.
func (c CmdNote) UpdateById(id, uid uint, cmd *CmdNote) error {
	target := Db.Model(&c).Where("id = ? AND uid = ?", id, uid)
	return target.Updates(cmd).Error
}
// DeleteByID deletes the note with the given id owned by uid, using an
// unscoped delete.
func (c CmdNote) DeleteByID(id, uid uint) error {
	result := Db.Unscoped().Delete(&c, "id = ? AND uid = ?", id, uid)
	return result.Error
}
================================================
FILE: gossh/app/model/datetime.go
================================================
package model
import (
"database/sql/driver"
"fmt"
"time"
)
const (
	// TimeFormat is the layout used for the textual and JSON form of DateTime.
	TimeFormat = "2006-01-02 15:04:05"
)

// DateTime is a time.Time that is rendered as "yyyy-MM-dd HH:mm:ss" in JSON
// and text, and that maps the zero time to SQL NULL in both directions.
type DateTime time.Time

// NewDateTime parses str in TimeFormat, interpreted in the local time zone.
func NewDateTime(str string) (DateTime, error) {
	now, err := time.ParseInLocation(TimeFormat, str, time.Local)
	if err != nil {
		return DateTime{}, fmt.Errorf("can not convert %v to date,must like format:yyyy-MM-dd HH:mm:ss,simple example : %v", str, TimeFormat)
	}
	t := DateTime(now)
	return t, nil
}

// MarshalJSON encodes t as a quoted TimeFormat string.
func (t DateTime) MarshalJSON() ([]byte, error) {
	return []byte(fmt.Sprintf(`"%s"`, time.Time(t).Format(TimeFormat))), nil
}

// UnmarshalJSON decodes a quoted TimeFormat string in the local time zone.
// A JSON null leaves t unchanged, following the encoding/json convention for
// Unmarshaler implementations.
func (t *DateTime) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		return nil
	}
	now, err := time.ParseInLocation(`"`+TimeFormat+`"`, string(b), time.Local)
	if err != nil {
		return fmt.Errorf("can not convert %v to date,must like format:yyyy-MM-dd HH:mm:ss,simple example : %v", string(b), TimeFormat)
	}
	*t = DateTime(now)
	return nil
}

// Value implements driver.Valuer. The zero time is stored as NULL.
func (t DateTime) Value() (driver.Value, error) {
	// IsZero avoids comparing UnixNano values, which are out of range for the
	// zero time.
	if time.Time(t).IsZero() {
		return nil, nil
	}
	return time.Time(t), nil
}

// Scan implements sql.Scanner. NULL becomes the zero DateTime (the inverse of
// Value); a time.Time is taken as is; any other type is an error.
func (t *DateTime) Scan(v any) error {
	switch value := v.(type) {
	case nil:
		*t = DateTime{}
		return nil
	case time.Time:
		*t = DateTime(value)
		return nil
	}
	return fmt.Errorf("can not convert %v to timestamp", v)
}

// ToTime returns t as a time.Time.
func (t *DateTime) ToTime() time.Time {
	return time.Time(*t)
}

// String formats t using TimeFormat.
func (t DateTime) String() string {
	return time.Time(t).Format(TimeFormat)
}
================================================
FILE: gossh/app/model/db_init.go
================================================
package model
import (
"errors"
"fmt"
"gossh/app/config"
"gossh/gorm"
"gossh/gorm/driver/mysql"
"gossh/gorm/driver/pgsql"
"gossh/gorm/logger"
_ "gossh/mysql"
_ "gossh/pgsql"
"gossh/sqlite"
"log/slog"
"os"
"path"
"path/filepath"
)
// Db is the shared database handle used by the model methods. It is nil until
// DbMigrate has connected successfully.
var Db *gorm.DB
// InitDatabase connects to and migrates the configured database at startup.
// It does nothing (beyond a warning) while the system is not initialized, and
// only logs a migration failure.
func InitDatabase() {
	cfg := &config.DefaultConfig
	if !cfg.IsInit {
		slog.Warn("系统未初始化,跳过DbMigrate")
		return
	}
	if err := DbMigrate(cfg.DbType, cfg.DbDsn); err != nil {
		slog.Error("DbMigrate error", "err_msg", err.Error())
	}
}
// GetSqliteDb opens (creating it if necessary) the SQLite database stored at
// dsn relative to config.WorkDir and verifies the connection with a ping.
//
// Returned errors wrap the underlying cause, so callers can inspect it with
// errors.Is / errors.As.
func GetSqliteDb(dsn string) (*gorm.DB, error) {
	dbPath := path.Join(config.WorkDir, dsn)
	// Make sure the directory that holds the database file exists.
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("failed to create database directory: %w", err)
	}
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Warn),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	// Verify the connection.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("failed to get database instance: %w", err)
	}
	if err := sqlDB.Ping(); err != nil {
		// The handle is discarded, so release it; the ping error is the one
		// worth reporting.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}
	return db, nil
}
// DbMigrate connects to the database selected by dbType ("pgsql", "mysql" or
// "sqlite") using dsn, stores the handle in Db and auto-migrates all models.
//
// An unrecognized dbType leaves Db untouched; if Db is still nil an error is
// returned. A panic raised while connecting or migrating is logged and
// reported as an error instead of being returned as success.
func DbMigrate(dbType, dsn string) (err error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("DbMigrate error", "err_msg", r)
			err = fmt.Errorf("DbMigrate panic: %v", r)
		}
	}()

	switch dbType {
	case "pgsql":
		db, openErr := gorm.Open(pgsql.Open(dsn), &gorm.Config{
			Logger: logger.Default.LogMode(logger.Warn),
		})
		if openErr != nil {
			return openErr
		}
		if probeErr := db.Exec("select 1=1;").Error; probeErr != nil {
			return probeErr
		}
		Db = db
	case "mysql":
		db, openErr := gorm.Open(mysql.Open(dsn), &gorm.Config{
			Logger: logger.Default.LogMode(logger.Warn),
		})
		if openErr != nil {
			return openErr
		}
		if probeErr := db.Exec("select 1=1;").Error; probeErr != nil {
			return probeErr
		}
		Db = db
	case "sqlite":
		db, openErr := GetSqliteDb(dsn)
		if openErr != nil {
			return openErr
		}
		Db = db
	}

	if Db == nil {
		return errors.New("请检查数据库链接")
	}
	err = Db.AutoMigrate(
		SshConf{}, WebUser{}, CmdNote{},
		NetFilter{}, PolicyConf{}, LoginAudit{},
		SshdConf{}, SshdUser{}, SshdCert{})
	if err != nil {
		slog.Error("AutoMigrate error:", "err_msg", err.Error())
		return err
	}
	return nil
}
================================================
FILE: gossh/app/model/login_audit.go
================================================
package model
// LoginAudit is one recorded login attempt: who tried (Name/Pwd), from where
// (ClientIp/UserAgent), when (OccurAt) and with what outcome (IsSuccess is
// "Y" or "N", ErrMsg carries the failure text).
// NOTE(review): the Pwd column appears to store the submitted password —
// confirm this is intended for an audit table.
type LoginAudit struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
Name string `gorm:"column:name;not null;size:128" form:"name" binding:"required,min=1,max=128" json:"name"`
Pwd string `gorm:"column:pwd;size:128" form:"pwd" binding:"required,min=1,max=128" json:"pwd"`
ClientIp string `gorm:"column:client_ip;size:128" form:"client_ip" binding:"required,min=1,max=128" json:"client_ip"`
UserAgent string `gorm:"column:user_agent;size:512" form:"user_agent" binding:"required,min=0,max=512" json:"user_agent"`
ErrMsg string `gorm:"column:err_msg;size:64" form:"err_msg" binding:"required,min=1,max=64" json:"err_msg"`
IsSuccess string `gorm:"column:is_success;not null;size:64;default:'N'" form:"is_success" binding:"required,min=1,max=64,oneof=Y N" json:"is_success"`
OccurAt DateTime `gorm:"column:occur_at;not null" json:"occur_at" form:"occur_at" binding:"required"`
CreatedAt DateTime `gorm:"column:created_at" json:"-"`
UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts audit as a new row and returns the database error, if any.
func (c LoginAudit) Create(audit *LoginAudit) error {
	result := Db.Create(audit)
	return result.Error
}
// Search returns one page of login audit records, newest first, together with
// the total number of records matching the filters.
//
// Empty string filters are ignored. name is matched as a substring, isSuccess
// and clientIp exactly. The time range is applied only when both occurBegin
// and occurEnd are non-zero.
func (c LoginAudit) Search(
isSuccess, name, clientIp string,
occurBegin, occurEnd DateTime, offset, limit int,
) ([]LoginAudit, int64, error) {
var list []LoginAudit
var db = Db
if isSuccess != "" {
db = db.Where("is_success = ?", isSuccess)
}
if name != "" {
db = db.Where("name like ?", "%"+name+"%")
}
if clientIp != "" {
db = db.Where("client_ip = ?", clientIp)
}
// "0001-01-01 00:00:00" is how the zero DateTime formats; both bounds must be set.
if occurBegin.String() != "0001-01-01 00:00:00" && occurEnd.String() != "0001-01-01 00:00:00" {
db = db.Where("occur_at between ? AND ?", occurBegin, occurEnd)
}
// Count the matching rows before paging.
var count int64
err := db.Model(&LoginAudit{}).Count(&count).Error
if err != nil {
return list, count, err
}
// NOTE(review): Debug() logs the generated SQL for this query — confirm it is
// meant to stay enabled.
return list, count, db.Debug().Order("occur_at desc").Offset(offset).Limit(limit).Find(&list).Error
}
// FindByID returns the login audit record with the given id.
func (c LoginAudit) FindByID(id uint) (LoginAudit, error) {
	var record LoginAudit
	result := Db.First(&record, "id = ?", id)
	return record, result.Error
}
// DeleteByID deletes the login audit record with the given id, using an
// unscoped delete.
func (c LoginAudit) DeleteByID(id uint) error {
	// LoginAudit has no is_root column, so the row is selected by id only.
	return Db.Unscoped().Delete(&c, "id = ?", id).Error
}
================================================
FILE: gossh/app/model/net_filter.go
================================================
package model
import "time"
// NetFilter is one network access rule: a CIDR range that belongs to the
// allow-list (NetPolicy "Y") or the deny-list (NetPolicy "N"), ordered by
// PolicyNo and considered active until ExpiryAt.
type NetFilter struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
Name string `gorm:"column:name;not null" form:"name" json:"name" binding:"required"`
Cidr string `gorm:"column:cidr;not null" form:"cidr" json:"cidr" binding:"required,cidr"`
NetPolicy string `gorm:"column:net_policy;not null;size:64;default:'Y'" form:"net_policy" binding:"required,min=1,max=64,oneof=Y N" json:"net_policy"`
PolicyNo uint `gorm:"column:policy_no;not null;" form:"policy_no" json:"policy_no" binding:"required,gte=1,lte=65535"`
ExpiryAt DateTime `gorm:"column:expiry_at;not null" json:"expiry_at" form:"expiry_at" binding:"required"`
CreatedAt DateTime `gorm:"column:created_at" json:"-"`
UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts filter as a new row and returns the database error, if any.
func (c NetFilter) Create(filter *NetFilter) error {
	result := Db.Create(filter)
	return result.Error
}
// FindByName returns the first filter whose name equals name.
func (c NetFilter) FindByName(name string) (NetFilter, error) {
	var found NetFilter
	result := Db.First(&found, "name = ?", name)
	return found, result.Error
}
// FindByID returns the filter with the given id.
func (c NetFilter) FindByID(id uint) (NetFilter, error) {
	var found NetFilter
	result := Db.First(&found, "id = ? ", id)
	return found, result.Error
}
// FindAll returns one page of filters ordered by policy number, then expiry,
// then most recent update.
func (c NetFilter) FindAll(offset, limit int) ([]NetFilter, error) {
	var filters []NetFilter
	page := Db.Offset(offset).Limit(limit)
	result := page.Order("policy_no asc, expiry_at, updated_at desc").Find(&filters)
	return filters, result.Error
}
// FindAllPolicy returns every filter with the given net_policy value that has
// not expired yet, ordered by policy number, then expiry, then most recent
// update.
func (c NetFilter) FindAllPolicy(policy string) ([]NetFilter, error) {
	var active []NetFilter
	unexpired := Db.Where("net_policy = ? AND expiry_at > ?", policy, time.Now())
	result := unexpired.Order("policy_no asc, expiry_at, updated_at desc").Find(&active)
	return active, result.Error
}
// UpdateById applies filter to the row with the given id and returns the
// database error, if any.
func (c NetFilter) UpdateById(id uint, filter *NetFilter) error {
	target := Db.Model(&c).Where("id = ?", id)
	return target.Updates(filter).Error
}
// DeleteByID deletes the filter with the given id, using an unscoped delete.
func (c NetFilter) DeleteByID(id uint) error {
	result := Db.Unscoped().Delete(&c, "id = ?", id)
	return result.Error
}
================================================
FILE: gossh/app/model/policy_conf.go
================================================
package model
// PolicyConf stores the active network policy mode: NetPolicy is "Y" or "N"
// (validated by the binding tag).
type PolicyConf struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
NetPolicy string `gorm:"column:net_policy;not null;size:64;default:'Y'" form:"net_policy" binding:"required,min=1,max=64,oneof=Y N" json:"net_policy"`
CreatedAt DateTime `gorm:"column:created_at" json:"-"`
UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts conf as a new row and returns the database error, if any.
func (c PolicyConf) Create(conf *PolicyConf) error {
	result := Db.Create(conf)
	return result.Error
}
// FindByID returns the policy configuration with the given id.
func (c PolicyConf) FindByID(id uint) (PolicyConf, error) {
	var found PolicyConf
	result := Db.First(&found, "id = ? ", id)
	return found, result.Error
}
// FindAll returns one page of policy configurations, most recently updated
// first.
func (c PolicyConf) FindAll(offset, limit int) ([]PolicyConf, error) {
	var confs []PolicyConf
	page := Db.Offset(offset).Limit(limit)
	result := page.Order("updated_at desc").Find(&confs)
	return confs, result.Error
}
// UpdateById applies conf to the row with the given id and returns the
// database error, if any.
func (c PolicyConf) UpdateById(id uint, conf *PolicyConf) error {
	target := Db.Model(&c).Where("id = ?", id)
	return target.Updates(conf).Error
}
// DeleteByID deletes the policy configuration with the given id, using an
// unscoped delete.
func (c PolicyConf) DeleteByID(id uint) error {
	result := Db.Unscoped().Delete(&c, "id = ?", id)
	return result.Error
}
================================================
FILE: gossh/app/model/ssh_conf.go
================================================
package model
import (
"errors"
"gossh/app/config"
"gossh/app/utils"
"gossh/gorm"
)
// SshConf is a saved SSH connection profile owned by the user identified by
// Uid: target address and credentials plus terminal appearance settings.
// Pwd, CertPwd and CertData are encrypted by the BeforeCreate/BeforeUpdate
// hooks and decrypted by the AfterFind hook.
type SshConf struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
Uid uint `gorm:"column:uid;not null;default:0" form:"uid" json:"uid"`
Name string `gorm:"column:name;not null;size:64" form:"name" binding:"required,min=1,max=63" json:"name"`
Address string `gorm:"column:address;size:128" form:"address" binding:"required,min=1,max=128" json:"address"`
User string `gorm:"column:user;size:128" form:"user" binding:"required,min=1,max=128" json:"user"`
Pwd string `gorm:"column:pwd;not null;size:4096;default:''" form:"pwd" binding:"max=4096" json:"pwd"`
AuthType string `gorm:"column:auth_type;not null;size:32;default:'pwd'" form:"auth_type" binding:"required,min=1,max=32,oneof=pwd cert" json:"auth_type"`
NetType string `gorm:"column:net_type;not null;size:32;default:'tcp4'" form:"net_type" binding:"required,min=1,max=32,oneof=tcp4 tcp6" json:"net_type"`
CertData string `gorm:"column:cert_data;type:text" form:"cert_data" json:"cert_data"`
CertPwd string `gorm:"column:cert_pwd;not null;size:128;default:''" form:"cert_pwd" binding:"max=128" json:"cert_pwd"`
Port uint16 `gorm:"column:port;not null;default:22" form:"port" binding:"required,gte=1,lte=65535" json:"port"`
FontSize uint16 `gorm:"column:font_size;not null;default:14" form:"font_size" binding:"required,gte=8,lte=48" json:"font_size"`
Background string `gorm:"column:background;not null;size:128;default:'#000000'" form:"background" binding:"required,hexcolor" json:"background"`
Foreground string `gorm:"column:foreground;not null;size:128;default:'#FFFFFF'" form:"foreground" binding:"required,hexcolor" json:"foreground"`
CursorColor string `gorm:"column:cursor_color;not null;size:128;default:'#FFFFFF'" form:"cursor_color" binding:"required,hexcolor" json:"cursor_color"`
FontFamily string `gorm:"column:font_family;not null;size:128;default:'Courier'" form:"font_family" binding:"min=1,max=128" json:"font_family"`
CursorStyle string `gorm:"column:cursor_style;not null;size:128;default:'block'" form:"cursor_style" binding:"min=1,max=128" json:"cursor_style"`
Shell string `gorm:"column:shell;not null;size:64;default:'sh'" form:"shell" binding:"min=1,max=128" json:"shell"`
PtyType string `gorm:"column:pty_type;not null;size:64;default:'xterm-256color'" form:"pty_type" binding:"min=1,max=128" json:"pty_type"`
InitCmd string `gorm:"column:init_cmd;type:text" form:"init_cmd" json:"init_cmd"`
InitBanner string `gorm:"column:init_banner;type:text" form:"init_banner" json:"init_banner"`
CreatedAt DateTime `gorm:"column:created_at" json:"-"`
UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts conf as a new row and returns the database error, if any.
func (c *SshConf) Create(conf *SshConf) error {
	result := Db.Create(conf)
	return result.Error
}
// FindByID returns the connection profile with the given id that belongs to
// user uid.
func (c *SshConf) FindByID(id uint, uid uint) (SshConf, error) {
	var found SshConf
	result := Db.First(&found, "id = ? AND uid = ?", id, uid)
	return found, result.Error
}
// FindAll returns one page of the connection profiles owned by uid, most
// recently updated first.
func (c *SshConf) FindAll(offset, limit int, uid uint) ([]SshConf, error) {
	var confs []SshConf
	page := Db.Where("uid = ?", uid).Offset(offset).Limit(limit)
	result := page.Order("updated_at desc").Find(&confs)
	return confs, result.Error
}
// UpdateById overwrites every column except id and uid of the profile with the
// given id owned by uid, using the values in conf (zero values included).
func (_ *SshConf) UpdateById(id, uid uint, conf *SshConf) error {
	target := Db.Model(&conf).Where("id = ? AND uid = ?", id, uid)
	columns := target.Select("*").Omit("id", "uid")
	return columns.Updates(conf).Error
}
// DeleteByID deletes the profile with the given id owned by uid, using an
// unscoped delete.
func (c *SshConf) DeleteByID(id, uid uint) error {
	result := Db.Unscoped().Delete(&c, "id = ? AND uid = ?", id, uid)
	return result.Error
}
// BeforeCreate is the gorm hook that encrypts the secret fields of c in place
// before the row is inserted.
func (c *SshConf) BeforeCreate(_ *gorm.DB) error {
	if _, err := c.sshConfEncrypt(); err != nil {
		return err
	}
	return nil
}
// BeforeUpdate is the gorm hook that encrypts the secret fields of c in place
// before the row is updated.
func (c *SshConf) BeforeUpdate(_ *gorm.DB) error {
	if _, err := c.sshConfEncrypt(); err != nil {
		return err
	}
	return nil
}
// AfterFind is the gorm hook that decrypts Pwd, CertPwd and CertData in place
// with config.DefaultConfig.AesSecret after a row has been loaded. It stops at
// the first field that fails and returns an error naming it.
func (c *SshConf) AfterFind(_ *gorm.DB) error {
	secrets := []struct {
		field  *string
		errMsg string
	}{
		{&c.Pwd, "解密密码错误,"},
		{&c.CertPwd, "解密证书密码错误,,"},
		{&c.CertData, "解密密证书数据错误,"},
	}
	for _, s := range secrets {
		plain, err := utils.DecryptString(*s.field, config.DefaultConfig.AesSecret)
		// The field receives the result even when decryption fails.
		*s.field = plain
		if err != nil {
			return errors.New(s.errMsg + err.Error())
		}
	}
	return nil
}
// sshConfEncrypt encrypts Pwd, CertPwd and CertData in place with
// config.DefaultConfig.AesSecret. It always returns c; on the first failing
// field it also returns an error naming that field.
func (c *SshConf) sshConfEncrypt() (*SshConf, error) {
	secrets := []struct {
		field  *string
		errMsg string
	}{
		{&c.Pwd, "加密密码错误,"},
		{&c.CertPwd, "加密证书密码错误,"},
		{&c.CertData, "加密证书数据错误,"},
	}
	for _, s := range secrets {
		sealed, err := utils.EncryptString(*s.field, config.DefaultConfig.AesSecret)
		// The field receives the result even when encryption fails.
		*s.field = sealed
		if err != nil {
			return c, errors.New(s.errMsg + err.Error())
		}
	}
	return c, nil
}
================================================
FILE: gossh/app/model/sshd_cert.go
================================================
package model
import "time"
// SshdCert is a public key registered for the built-in sshd server. Name is
// unique; a key is included by GetAuthorizedKeys while IsEnable is "Y" and
// ExpiryAt lies in the future.
type SshdCert struct {
ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
DescInfo string `gorm:"column:desc_info;size:64" form:"desc_info" binding:"required,min=1,max=64" json:"desc_info"`
Name string `gorm:"column:name;uniqueIndex;not null;size:64" form:"name" binding:"required,min=1,max=63" json:"name"`
PubKey string `gorm:"column:pub_key;type:text;not null;" form:"pub_key" binding:"required,min=16" json:"pub_key"`
IsEnable string `gorm:"column:is_enable;not null;size:64;default:'Y'" form:"is_enable" binding:"required,min=1,max=64,oneof=Y N" json:"is_enable"`
ExpiryAt DateTime `gorm:"column:expiry_at;not null" json:"expiry_at" form:"expiry_at" binding:"required"`
CreatedAt DateTime `gorm:"column:created_at" json:"-"`
UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts user as a new row and returns the database error, if any.
func (c *SshdCert) Create(user *SshdCert) error {
	result := Db.Create(user)
	return result.Error
}
// FindByName queries with Db.Find for a row whose name column equals name and
// returns the loaded value together with the query error.
func (c *SshdCert) FindByName(name string) (SshdCert, error) {
	var cert SshdCert
	result := Db.Find(&cert, "name = ?", name)
	return cert, result.Error
}
// FindByID loads the first row whose id column equals id (Db.First) and
// returns it together with the query error.
func (c *SshdCert) FindByID(id uint) (SshdCert, error) {
	var cert SshdCert
	result := Db.First(&cert, "id = ?", id)
	return cert, result.Error
}
// FindAll returns rows using the given limit and offset, plus the query error.
func (c *SshdCert) FindAll(limit, offset int) ([]SshdCert, error) {
	var certs []SshdCert
	page := Db.Limit(limit).Offset(offset)
	result := page.Find(&certs)
	return certs, result.Error
}
// GetAuthorizedKeys selects pub_key from every row with is_enable = "Y" and
// expiry_at later than the current time, and returns the keys joined into one
// string, each followed by a newline. A query error yields an empty string.
func (c *SshdCert) GetAuthorizedKeys() (string, error) {
	var keys []string
	query := Db.Model(c).
		Select("pub_key").
		Where("is_enable = ? AND expiry_at > ?", "Y", time.Now())
	if err := query.Find(&keys).Error; err != nil {
		return "", err
	}

	joined := ""
	for i := range keys {
		joined += keys[i] + "\n"
	}
	return joined, nil
}
// UpdateById writes every selected column of user ("*" minus id) to the row
// whose id column equals id and returns the resulting error.
func (c *SshdCert) UpdateById(id uint, user *SshdCert) error {
	tx := Db.Model(&c).Where("id = ?", id)
	tx = tx.Select("*").Omit("id")
	return tx.Updates(user).Error
}
// DeleteByID issues an Unscoped delete for the row whose id column equals id
// and returns the resulting error.
func (c *SshdCert) DeleteByID(id uint) error {
	result := Db.Unscoped().Delete(&c, "id = ?", id)
	return result.Error
}
================================================
FILE: gossh/app/model/sshd_conf.go
================================================
package model
import (
"gossh/app/config"
"gossh/app/utils"
"gossh/gorm"
)
// SshdConf is the database model for one sshd listener configuration. The gorm
// tags map each field to its column; the form/json/binding tags drive request
// binding and validation.
// NOTE(review): the ID tag reads `primaryKey,autoIncrement` with a comma while
// every other option in these tags is separated by ';' — confirm gorm parses
// it as intended.
type SshdConf struct {
	ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
	// Name is unique across rows.
	Name string `gorm:"column:name;uniqueIndex;not null;size:64" form:"name" binding:"required,min=1,max=63" json:"name"`
	Host string `gorm:"column:host;size:128" form:"host" binding:"required,min=1,max=128" json:"host"`
	// Port is unique across rows.
	Port uint16 `gorm:"column:port;uniqueIndex;not null" form:"port" binding:"required,gte=1,lte=65535" json:"port"`
	// NOTE(review): the column is declared size:64 but binding allows up to
	// 128 characters — confirm which limit is intended.
	Shell string `gorm:"column:shell;not null;size:64;default:''" form:"shell" binding:"min=1,max=128" json:"shell"`
	// KeyFile is encrypted/decrypted by the BeforeCreate/BeforeUpdate/AfterFind hooks.
	KeyFile string `gorm:"column:key_file;not null;type:text" form:"key_file" json:"key_file"`
	KeySeed string `gorm:"column:key_seed;not null;size:4096;default:''" form:"key_seed" binding:"max=4096" json:"key_seed"`
	// KeepAlive must be within 10..360.
	KeepAlive uint16 `gorm:"column:keep_alive;not null;default:60" form:"keep_alive" binding:"required,gte=10,lte=360" json:"keep_alive"`
	// LoadEnv is "Y" or "N".
	LoadEnv string `gorm:"column:load_env;not null;size:64;default:'Y'" form:"load_env" binding:"required,min=1,max=64,oneof=Y N" json:"load_env"`
	// AuthType is one of "all", "pwd" or "cert".
	AuthType string `gorm:"column:auth_type;not null;size:32;default:'all'" form:"auth_type" binding:"required,min=1,max=32,oneof=all pwd cert" json:"auth_type"`
	// ServerVersion binds from the "server_version" form key. It previously
	// reused the "name" key, so form submissions filled it with Name's value.
	ServerVersion string `gorm:"column:server_version;not null;size:64;default:'SSH-2.0-OpenSSH'" form:"server_version" binding:"required,min=10,max=63" json:"server_version"`
	// CreatedAt and UpdatedAt are excluded from JSON output.
	CreatedAt DateTime `gorm:"column:created_at" json:"-"`
	UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts conf through the package-level Db handle and returns the
// error recorded by that call, if any.
func (c *SshdConf) Create(conf *SshdConf) error {
	result := Db.Create(conf)
	return result.Error
}
// FindByName queries with Db.Find for a row whose name column equals name and
// returns the loaded value together with the query error.
func (c *SshdConf) FindByName(name string) (SshdConf, error) {
	var found SshdConf
	result := Db.Find(&found, "name = ?", name)
	return found, result.Error
}
// FindByID loads the first row whose id column equals id (Db.First) and
// returns it together with the query error.
func (c *SshdConf) FindByID(id uint) (SshdConf, error) {
	var found SshdConf
	result := Db.First(&found, "id = ?", id)
	return found, result.Error
}
// UpdateByName writes every selected column of conf ("*" minus id and name) to
// the row whose name column equals name and returns the resulting error.
func (c *SshdConf) UpdateByName(name string, conf *SshdConf) error {
	tx := Db.Model(&c).Where("name = ?", name)
	tx = tx.Select("*").Omit("id", "name")
	return tx.Updates(conf).Error
}
// DeleteByName issues an Unscoped delete for the row whose name column equals
// name and returns the resulting error.
func (c *SshdConf) DeleteByName(name string) error {
	result := Db.Unscoped().Delete(&c, "name = ?", name)
	return result.Error
}
// BeforeCreate is a gorm hook: it replaces KeyFile with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdConf) BeforeCreate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.KeyFile, config.DefaultConfig.AesSecret)
	c.KeyFile = encrypted
	return err
}
// BeforeUpdate is a gorm hook: it replaces KeyFile with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdConf) BeforeUpdate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.KeyFile, config.DefaultConfig.AesSecret)
	c.KeyFile = encrypted
	return err
}
// AfterFind is a gorm hook: it replaces KeyFile with the result of
// utils.DecryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdConf) AfterFind(_ *gorm.DB) error {
	plain, err := utils.DecryptString(c.KeyFile, config.DefaultConfig.AesSecret)
	c.KeyFile = plain
	return err
}
================================================
FILE: gossh/app/model/sshd_user.go
================================================
package model
import (
"errors"
"fmt"
"gossh/app/config"
"gossh/app/utils"
"gossh/gorm"
"time"
)
// SshdUser is the database model for one sshd login account. The gorm tags map
// each field to its column; the form/json/binding tags drive request binding
// and validation.
// NOTE(review): the ID tag reads `primaryKey,autoIncrement` with a comma while
// every other option in these tags is separated by ';' — confirm gorm parses
// it as intended.
type SshdUser struct {
	ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
	// Name is unique across rows.
	Name string `gorm:"column:name;uniqueIndex;not null;size:64" form:"name" binding:"required,min=1,max=63" json:"name"`
	// Pwd is encrypted by the BeforeCreate/BeforeUpdate hooks and decrypted by AfterFind.
	// NOTE(review): confirm size:64 is large enough for the encrypted form of a 64-character password.
	Pwd string `gorm:"column:pwd;size:64" form:"pwd" binding:"required,min=1,max=64" json:"pwd"`
	DescInfo string `gorm:"column:desc_info;size:64" form:"desc_info" binding:"required,min=1,max=64" json:"desc_info"`
	// IsEnable is "Y" or "N"; FindByNameAndPwd only matches rows with "Y".
	IsEnable string `gorm:"column:is_enable;not null;size:64;default:'Y'" form:"is_enable" binding:"required,min=1,max=64,oneof=Y N" json:"is_enable"`
	WorkDir string `gorm:"column:work_dir;size:4096;default:''" form:"work_dir" binding:"max=4096" json:"work_dir"`
	// ExpiryAt is checked against the current time by FindByNameAndPwd.
	ExpiryAt DateTime `gorm:"column:expiry_at;not null" json:"expiry_at" form:"expiry_at" binding:"required"`
	// CreatedAt and UpdatedAt are excluded from JSON output.
	CreatedAt DateTime `gorm:"column:created_at" json:"-"`
	UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts user through the package-level Db handle and returns the
// error recorded by that call, if any.
func (c *SshdUser) Create(user *SshdUser) error {
	result := Db.Create(user)
	return result.Error
}
// FindByNameAndPwd loads the enabled (is_enable = "Y") user called name and
// checks the supplied password and the expiry time.
//
// It returns the query error when no such row can be loaded, "password error"
// when user.Pwd differs from pwd, and an "expired" error — together with the
// loaded user, as before — when ExpiryAt lies before the current second.
func (c *SshdUser) FindByNameAndPwd(name, pwd string) (SshdUser, error) {
	var user SshdUser
	if err := Db.First(&user, "name = ? AND is_enable = ?", name, "Y").Error; err != nil {
		return SshdUser{}, err
	}
	if user.Pwd != pwd {
		return SshdUser{}, errors.New("password error")
	}
	if user.ExpiryAt.ToTime().Unix() < time.Now().Unix() {
		// fmt.Errorf replaces the errors.New(fmt.Sprintf(...)) round trip.
		return user, fmt.Errorf("sshd user '%s' expired", name)
	}
	return user, nil
}
// FindByName queries with Db.Find for a row whose name column equals name and
// returns the loaded value together with the query error.
func (c *SshdUser) FindByName(name string) (SshdUser, error) {
	var found SshdUser
	result := Db.Find(&found, "name = ?", name)
	return found, result.Error
}
// FindByID loads the first row whose id column equals id (Db.First) and
// returns it together with the query error.
func (c *SshdUser) FindByID(id uint) (SshdUser, error) {
	var found SshdUser
	result := Db.First(&found, "id = ?", id)
	return found, result.Error
}
// FindAll returns rows using the given limit and offset, plus the query error.
func (c *SshdUser) FindAll(limit, offset int) ([]SshdUser, error) {
	var users []SshdUser
	page := Db.Limit(limit).Offset(offset)
	result := page.Find(&users)
	return users, result.Error
}
// UpdateById writes every selected column of user ("*" minus id) to the row
// whose id column equals id and returns the resulting error.
func (c *SshdUser) UpdateById(id uint, user *SshdUser) error {
	tx := Db.Model(&c).Where("id = ?", id)
	tx = tx.Select("*").Omit("id")
	return tx.Updates(user).Error
}
// DeleteByID issues an Unscoped delete for the row whose id column equals id
// and returns the resulting error.
func (c *SshdUser) DeleteByID(id uint) error {
	result := Db.Unscoped().Delete(&c, "id = ?", id)
	return result.Error
}
// BeforeCreate is a gorm hook: it replaces Pwd with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdUser) BeforeCreate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = encrypted
	return err
}
// BeforeUpdate is a gorm hook: it replaces Pwd with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdUser) BeforeUpdate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = encrypted
	return err
}
// AfterFind is a gorm hook: it replaces Pwd with the result of
// utils.DecryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *SshdUser) AfterFind(_ *gorm.DB) error {
	plain, err := utils.DecryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = plain
	return err
}
================================================
FILE: gossh/app/model/web_user.go
================================================
package model
import (
"errors"
"gossh/app/config"
"gossh/app/utils"
"gossh/gorm"
)
// WebUser is the database model for one web console account. The gorm tags map
// each field to its column; the form/json/binding tags drive request binding
// and validation.
// NOTE(review): the ID tag reads `primaryKey,autoIncrement` with a comma while
// every other option in these tags is separated by ';' — confirm gorm parses
// it as intended.
type WebUser struct {
	ID uint `gorm:"column:id;primaryKey,autoIncrement" form:"id" json:"id"`
	// Name is unique across rows.
	Name string `gorm:"column:name;uniqueIndex;not null;size:64" form:"name" binding:"required,min=1,max=63" json:"name"`
	// Pwd is encrypted by the BeforeCreate/BeforeUpdate hooks and decrypted by AfterFind.
	// NOTE(review): confirm size:64 is large enough for the encrypted form of a 64-character password.
	Pwd string `gorm:"column:pwd;size:64" form:"pwd" binding:"required,min=1,max=64" json:"pwd"`
	DescInfo string `gorm:"column:desc_info;size:64" form:"desc_info" binding:"required,min=1,max=64" json:"desc_info"`
	// IsAdmin and IsEnable are "Y" or "N".
	IsAdmin string `gorm:"column:is_admin;not null;size:64;default:'N'" form:"is_admin" binding:"required,min=1,max=64,oneof=Y N" json:"is_admin"`
	IsEnable string `gorm:"column:is_enable;not null;size:64;default:'Y'" form:"is_enable" binding:"required,min=1,max=64,oneof=Y N" json:"is_enable"`
	// IsRoot has no binding validation; FindAll, UpdateById and DeleteByID only touch rows where it is "N".
	IsRoot string `gorm:"column:is_root;not null;size:64;default:'N'" form:"is_root" json:"is_root"`
	ExpiryAt DateTime `gorm:"column:expiry_at;not null" json:"expiry_at" form:"expiry_at" binding:"required"`
	// CreatedAt and UpdatedAt are excluded from JSON output.
	CreatedAt DateTime `gorm:"column:created_at" json:"-"`
	UpdatedAt DateTime `gorm:"column:updated_at" json:"-"`
}
// Create inserts user through the package-level Db handle and returns the
// error recorded by that call, if any.
func (c *WebUser) Create(user *WebUser) error {
	result := Db.Create(user)
	return result.Error
}
// FindByNameAndPwd loads the first user whose name column equals name and
// compares its Pwd with pwd. It returns the query error when the lookup fails
// and "password error" on a mismatch; in both cases the returned user is the
// zero value.
func (c *WebUser) FindByNameAndPwd(name, pwd string) (WebUser, error) {
	var found WebUser
	if lookupErr := Db.First(&found, "name = ?", name).Error; lookupErr != nil {
		return WebUser{}, lookupErr
	}
	if found.Pwd == pwd {
		return found, nil
	}
	return WebUser{}, errors.New("password error")
}
// FindByName queries with Db.Find for a row whose name column equals name and
// returns the loaded value together with the query error.
func (c *WebUser) FindByName(name string) (WebUser, error) {
	var found WebUser
	result := Db.Find(&found, "name = ?", name)
	return found, result.Error
}
// FindByID loads the first row whose id column equals id (Db.First) and
// returns it together with the query error.
func (c *WebUser) FindByID(id uint) (WebUser, error) {
	var found WebUser
	result := Db.First(&found, "id = ?", id)
	return found, result.Error
}
// FindAll returns the rows with is_root = "N" using the given limit and
// offset, plus the query error.
func (c *WebUser) FindAll(limit, offset int) ([]WebUser, error) {
	var users []WebUser
	query := Db.Where("is_root = ?", "N").Limit(limit).Offset(offset)
	result := query.Find(&users)
	return users, result.Error
}
// UpdateById writes every selected column of user ("*" minus id and is_root)
// to the row matching id whose is_root is "N", and returns the resulting error.
func (c *WebUser) UpdateById(id uint, user *WebUser) error {
	tx := Db.Model(&c).Where("id = ? AND is_root = ?", id, "N")
	tx = tx.Select("*").Omit("id", "is_root")
	return tx.Updates(user).Error
}
// UpdatePassword writes every selected column of user ("*" minus id) to the
// row whose id column equals id and returns the resulting error.
// NOTE(review): despite the name, the statement is not limited to the pwd
// column and has no is_root filter — confirm callers pass a fully populated user.
func (c *WebUser) UpdatePassword(id uint, user *WebUser) error {
	tx := Db.Model(&c).Where("id = ?", id)
	tx = tx.Select("*").Omit("id")
	return tx.Updates(user).Error
}
// DeleteByID issues an Unscoped delete for the row matching id whose is_root
// is "N" and returns the resulting error.
func (c *WebUser) DeleteByID(id uint) error {
	result := Db.Unscoped().Delete(&c, "id = ? AND is_root = ?", id, "N")
	return result.Error
}
// BeforeCreate is a gorm hook: it replaces Pwd with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *WebUser) BeforeCreate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = encrypted
	return err
}
// BeforeUpdate is a gorm hook: it replaces Pwd with the result of
// utils.EncryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *WebUser) BeforeUpdate(_ *gorm.DB) error {
	encrypted, err := utils.EncryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = encrypted
	return err
}
// AfterFind is a gorm hook: it replaces Pwd with the result of
// utils.DecryptString under config.DefaultConfig.AesSecret and returns that
// call's error.
func (c *WebUser) AfterFind(_ *gorm.DB) error {
	plain, err := utils.DecryptString(c.Pwd, config.DefaultConfig.AesSecret)
	c.Pwd = plain
	return err
}
================================================
FILE: gossh/app/service/cmd_note.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"strconv"
)
// CmdNoteCreate binds a command note from the request, stamps it with the
// "uid" value from the context and stores it. A bind failure answers code 1, a
// create failure code 2; on success the response is produced by CmdNoteFindAll.
func CmdNoteCreate(c *gin.Context) {
	var note model.CmdNote
	bindErr := c.ShouldBind(&note)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	note.Uid = c.GetUint("uid")
	if createErr := note.Create(&note); createErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": createErr.Error()})
		return
	}
	CmdNoteFindAll(c)
}
// CmdNoteFindByID parses the "id" path parameter and returns the matching note
// for the context's "uid". Parse and lookup failures both answer code 2.
func CmdNoteFindByID(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	noteID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var note model.CmdNote
	found, findErr := note.FindByID(uint(noteID), c.GetUint("uid"))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": found})
}
// CmdNoteFindAll lists the notes of the context's "uid". The "limit" query
// value defaults to 10000 and "offset" to 0; any failure answers code 2.
func CmdNoteFindAll(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	pageSize, limitErr := strconv.Atoi(c.DefaultQuery("limit", "10000"))
	if limitErr != nil {
		fail(limitErr)
		return
	}
	skip, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil {
		fail(offsetErr)
		return
	}

	var note model.CmdNote
	// NOTE(review): arguments are passed as (offset, limit, uid) — confirm this
	// matches the model method's parameter order.
	notes, findErr := note.FindAll(skip, pageSize, c.GetUint("uid"))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": notes})
}
// CmdNoteUpdateById binds a note from the request and updates the row with the
// bound ID for the context's "uid". A bind failure answers code 1, an update
// failure code 2; on success the response is produced by CmdNoteFindAll.
func CmdNoteUpdateById(c *gin.Context) {
	var note model.CmdNote
	bindErr := c.ShouldBind(&note)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if updateErr := note.UpdateById(note.ID, c.GetUint("uid"), &note); updateErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": updateErr.Error()})
		return
	}
	CmdNoteFindAll(c)
}
// CmdNoteDeleteById parses the "id" path parameter and deletes that note for
// the context's "uid". Failures answer code 2; on success the response is
// produced by CmdNoteFindAll.
func CmdNoteDeleteById(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	noteID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var note model.CmdNote
	if delErr := note.DeleteByID(uint(noteID), c.GetUint("uid")); delErr != nil {
		fail(delErr)
		return
	}
	CmdNoteFindAll(c)
}
================================================
FILE: gossh/app/service/db_conn.go
================================================
package service
import (
"gossh/app/config"
"gossh/app/model"
"gossh/gin"
"gossh/gorm"
"gossh/gorm/driver/mysql"
"gossh/gorm/driver/pgsql"
_ "gossh/mysql"
_ "gossh/pgsql"
"log/slog"
)
// DbConnConf is the request payload describing a database connection to test.
type DbConnConf struct {
	// DbDsn is the connection string handed to the selected driver (1..65535 characters).
	DbDsn string `form:"db_dsn" binding:"required,min=1,max=65535" json:"db_dsn"`
	// DbType selects the driver: "sqlite", "pgsql" or "mysql".
	DbType string `form:"db_type" binding:"required,oneof=sqlite pgsql mysql" json:"db_type"`
}
// DbConnCheck tests the database connection described by the request. It
// refuses with code 1 once config.DefaultConfig.IsInit is set, and also
// answers code 1 for bind or connection failures; success answers code 0.
func DbConnCheck(c *gin.Context) {
	if config.DefaultConfig.IsInit {
		c.JSON(200, gin.H{"code": 1, "msg": "系统已经完成初始化配置"})
		return
	}

	var conf DbConnConf
	bindErr := c.ShouldBind(&conf)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if checkErr := DbConnTestCheck(conf); checkErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": checkErr.Error()})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "连接成功"})
}
// DbConnTestCheck opens a connection of the requested type with the given DSN
// and runs a trivial statement against it. It returns the first error from
// opening or executing, and nil for an unrecognised DbType (as before).
//
// The DSN is deliberately not logged because it normally embeds credentials.
// The pgsql and mysql handles are opened here purely for the probe, so their
// connection pools are closed before returning.
func DbConnTestCheck(dbConf DbConnConf) error {
	slog.Info("DB link check", "db_type", dbConf.DbType)

	switch dbConf.DbType {
	case "pgsql":
		db, err := gorm.Open(pgsql.Open(dbConf.DbDsn), &gorm.Config{})
		if err != nil {
			return err
		}
		return dbConnProbeAndClose(db)
	case "mysql":
		db, err := gorm.Open(mysql.Open(dbConf.DbDsn), &gorm.Config{})
		if err != nil {
			return err
		}
		return dbConnProbeAndClose(db)
	case "sqlite":
		// The handle comes from model.GetSqliteDb; whether it is shared is not
		// visible here, so it is left open.
		db, err := model.GetSqliteDb(dbConf.DbDsn)
		if err != nil {
			return err
		}
		return db.Exec("select 1=1;").Error
	}
	return nil
}

// dbConnProbeAndClose runs the probe statement on db and then closes the
// underlying connection pool; the probe error takes precedence.
func dbConnProbeAndClose(db *gorm.DB) error {
	probeErr := db.Exec("select 1=1;").Error
	if sqlDB, err := db.DB(); err == nil {
		if closeErr := sqlDB.Close(); closeErr != nil {
			slog.Error("DB link check close error:", "err_msg", closeErr)
		}
	}
	return probeErr
}
================================================
FILE: gossh/app/service/login_audit.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"strconv"
)
// AuditFindByID parses the "id" path parameter and returns the matching login
// audit record. Parse and lookup failures both answer code 2.
func AuditFindByID(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	auditID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var audit model.LoginAudit
	record, findErr := audit.FindByID(uint(auditID))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": record})
}
// LoginAuditSearch binds the search filters from the request and returns the
// matching login audit rows together with their count. Bind and search
// failures both answer code 1.
//
// Limit is validated only with max=1000, so zero or negative values pass
// binding; any non-positive value falls back to 100 so a negative limit can
// never reach the query.
func LoginAuditSearch(c *gin.Context) {
	type Param struct {
		OccurBegin model.DateTime `json:"occur_begin" form:"occur_begin"`
		OccurEnd   model.DateTime `json:"occur_end" form:"occur_end"`
		Offset     int            `form:"offset" json:"offset" binding:"min=0"`
		Limit      int            `form:"limit" json:"limit" binding:"max=1000"`
		Name       string         `form:"name" binding:"max=64" json:"name"`
		ClientIp   string         `form:"client_ip" binding:"max=128" json:"client_ip"`
		IsSuccess  string         `form:"is_success" binding:"max=1" json:"is_success"`
	}
	var p Param
	if err := c.ShouldBind(&p); err != nil {
		c.JSON(200, gin.H{"code": 1, "msg": err.Error()})
		return
	}
	if p.Limit <= 0 {
		p.Limit = 100
	}
	var audit model.LoginAudit
	data, count, err := audit.Search(p.IsSuccess, p.Name, p.ClientIp, p.OccurBegin, p.OccurEnd, p.Offset, p.Limit)
	if err != nil {
		c.JSON(200, gin.H{"code": 1, "msg": err.Error()})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": data, "count": count})
}
// LoginAuditDeleteById parses the "id" path parameter and deletes that login
// audit record. Parse and delete failures answer code 2. A successful delete
// answers code 0; previously the success path wrote no body at all, unlike
// every other handler in this package.
func LoginAuditDeleteById(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(200, gin.H{"code": 2, "msg": err.Error()})
		return
	}
	var audit model.LoginAudit
	if err = audit.DeleteByID(uint(id)); err != nil {
		c.JSON(200, gin.H{"code": 2, "msg": err.Error()})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok"})
}
================================================
FILE: gossh/app/service/net_filter.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"strconv"
)
// NetFilterCreate binds a net filter rule from the request and stores it. A
// bind failure answers code 1, a create failure code 3; on success the
// response is produced by NetFilterFindAll.
func NetFilterCreate(c *gin.Context) {
	var rule model.NetFilter
	bindErr := c.ShouldBind(&rule)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if createErr := rule.Create(&rule); createErr != nil {
		c.JSON(200, gin.H{"code": 3, "msg": createErr.Error()})
		return
	}
	NetFilterFindAll(c)
}
// NetFilterFindByID parses the "id" path parameter and returns the matching
// net filter rule. Parse and lookup failures both answer code 2.
func NetFilterFindByID(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	ruleID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var rule model.NetFilter
	found, findErr := rule.FindByID(uint(ruleID))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": found})
}
// NetFilterFindAll lists net filter rules. The "limit" query value defaults to
// 10000 and "offset" to 0; any failure answers code 2.
func NetFilterFindAll(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	pageSize, limitErr := strconv.Atoi(c.DefaultQuery("limit", "10000"))
	if limitErr != nil {
		fail(limitErr)
		return
	}
	skip, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil {
		fail(offsetErr)
		return
	}

	var rule model.NetFilter
	// NOTE(review): arguments are passed as (offset, limit) — confirm this
	// matches the model method's parameter order.
	rules, findErr := rule.FindAll(skip, pageSize)
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": rules})
}
// NetFilterUpdateById binds a rule from the request and updates the row with
// the bound ID. A bind failure answers code 1, an update failure code 2; on
// success the response is produced by NetFilterFindAll.
func NetFilterUpdateById(c *gin.Context) {
	var rule model.NetFilter
	bindErr := c.ShouldBind(&rule)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if updateErr := rule.UpdateById(rule.ID, &rule); updateErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": updateErr.Error()})
		return
	}
	NetFilterFindAll(c)
}
// NetFilterDeleteById parses the "id" path parameter and deletes that rule.
// Failures answer code 2; on success the response is produced by
// NetFilterFindAll.
func NetFilterDeleteById(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	ruleID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var rule model.NetFilter
	if delErr := rule.DeleteByID(uint(ruleID)); delErr != nil {
		fail(delErr)
		return
	}
	NetFilterFindAll(c)
}
================================================
FILE: gossh/app/service/policy_conf.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"strconv"
)
// PolicyConfCreate binds a policy configuration from the request and stores
// it. A bind failure answers code 1, a create failure code 3; on success the
// response is produced by PolicyConfFindAll.
func PolicyConfCreate(c *gin.Context) {
	var policy model.PolicyConf
	bindErr := c.ShouldBind(&policy)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if createErr := policy.Create(&policy); createErr != nil {
		c.JSON(200, gin.H{"code": 3, "msg": createErr.Error()})
		return
	}
	PolicyConfFindAll(c)
}
// PolicyConfFindByID parses the "id" path parameter and returns the matching
// policy configuration. Parse and lookup failures both answer code 2.
func PolicyConfFindByID(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	policyID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var policy model.PolicyConf
	found, findErr := policy.FindByID(uint(policyID))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": found})
}
// PolicyConfFindAll lists policy configurations. The "limit" query value
// defaults to 10000 and "offset" to 0; any failure answers code 2.
func PolicyConfFindAll(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	pageSize, limitErr := strconv.Atoi(c.DefaultQuery("limit", "10000"))
	if limitErr != nil {
		fail(limitErr)
		return
	}
	skip, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil {
		fail(offsetErr)
		return
	}

	var policy model.PolicyConf
	// NOTE(review): arguments are passed as (offset, limit) — confirm this
	// matches the model method's parameter order.
	policies, findErr := policy.FindAll(skip, pageSize)
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": policies})
}
// PolicyConfUpdateById binds a policy from the request, updates the row with
// the bound ID, reloads that row and returns it. A bind failure answers code
// 1; update and reload failures answer code 2.
func PolicyConfUpdateById(c *gin.Context) {
	var policy model.PolicyConf
	bindErr := c.ShouldBind(&policy)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if updateErr := policy.UpdateById(policy.ID, &policy); updateErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": updateErr.Error()})
		return
	}

	reloaded, findErr := policy.FindByID(policy.ID)
	if findErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": findErr.Error()})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": reloaded})
}
// PolicyConfDeleteById parses the "id" path parameter and deletes that policy.
// Failures answer code 2; on success the response is produced by
// PolicyConfFindAll.
func PolicyConfDeleteById(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	policyID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var policy model.PolicyConf
	if delErr := policy.DeleteByID(uint(policyID)); delErr != nil {
		fail(delErr)
		return
	}
	PolicyConfFindAll(c)
}
================================================
FILE: gossh/app/service/service_init.go
================================================
package service
import (
"gossh/app/config"
"io"
"log/slog"
"sync"
"time"
)
// OnlineClients holds the currently connected clients, keyed by session id;
// the code in this package stores and loads *SshConn values.
var OnlineClients = sync.Map{}
// DeleteOnlineClient tears down the session stored under sessionId: it closes
// the websocket, the ssh session, the sftp client and the ssh client (in that
// order) and finally removes the entry from OnlineClients. Unknown session ids
// are logged and ignored.
//
// Each resource is closed in its own deferred function so that a failure or
// panic in one close cannot skip the others. Handles that were never set are
// skipped explicitly: ws is only assigned by RunTerminal and sftpClient stays
// nil when the sftp subsystem could not be created, and calling Close on them
// previously relied on the outer recover.
func DeleteOnlineClient(sessionId string) {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("DeleteOnlineClient recover error:", "err_msg", err)
		}
	}()
	cli, ok := OnlineClients.Load(sessionId)
	if !ok || cli == nil {
		slog.Info("OnlineClient sessionId not exist")
		return
	}
	conn, ok := cli.(*SshConn)
	if !ok || conn == nil {
		slog.Error("DeleteOnlineClient type Asset error")
		return
	}
	// Remove the session from the map (runs last).
	defer OnlineClients.Delete(sessionId)
	// Close the ssh client.
	defer func() {
		if conn.sshClient == nil {
			return
		}
		err := conn.sshClient.Close()
		if err == io.EOF {
			slog.Info("sshClient.Close EOF")
			return
		}
		if err != nil {
			slog.Error("DeleteOnlineClient.Close sshClient error:", "err_msg", err)
		}
	}()
	// Close the sftp client.
	defer func() {
		if conn.sftpClient == nil {
			return
		}
		err := conn.sftpClient.Close()
		if err == io.EOF {
			slog.Info("sftpClient.Close EOF")
			return
		}
		if err != nil {
			slog.Error("DeleteOnlineClient.Close sftpClient error:", "err_msg", err)
		}
	}()
	// Close the ssh session.
	defer func() {
		if conn.sshSession == nil {
			return
		}
		err := conn.sshSession.Close()
		if err == io.EOF {
			slog.Info("sshSession.Close EOF")
			return
		}
		if err != nil {
			slog.Error("DeleteOnlineClient.Close sshSession error:", "err_msg", err)
		}
	}()
	// Close the websocket (runs first).
	defer func() {
		if conn.ws == nil {
			return
		}
		if err := conn.ws.Close(); err != nil {
			slog.Error("DeleteOnlineClient.Close ws error:", "err_msg", err)
		}
	}()
}
// cleanNoActiveSession walks OnlineClients and deletes every *SshConn whose
// LastActiveTime is more than one minute in the past. Entries whose key is not
// a string or whose value is not a *SshConn are skipped. Panics are recovered
// and logged.
func cleanNoActiveSession() {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("cleanNoActiveSession error:", "err_msg", r)
		}
	}()

	OnlineClients.Range(func(key, value any) bool {
		sid, isString := key.(string)
		if !isString {
			return true
		}
		conn, isConn := value.(*SshConn)
		if !isConn {
			return true
		}
		deadline := conn.LastActiveTime.Add(time.Minute)
		if deadline.Before(time.Now()) {
			slog.Info("clean not active session:", "sid", sid)
			DeleteOnlineClient(sid)
		}
		return true
	})
}
// initApp signals isStartSshd when the system is already initialised and then
// loops forever, running cleanNoActiveSession every
// config.DefaultConfig.ClientCheck. A panic ends the loop; the recovered value
// is now included in the log entry instead of being discarded.
func initApp() {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("service init error", "err_msg", err)
		}
	}()
	if config.DefaultConfig.IsInit {
		isStartSshd <- true
	}
	for {
		cleanNoActiveSession()
		time.Sleep(config.DefaultConfig.ClientCheck)
	}
}
// InitSessionClean starts initApp on its own goroutine and returns immediately.
func InitSessionClean() {
	go func() {
		initApp()
	}()
}
================================================
FILE: gossh/app/service/ssh_conf.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"strconv"
)
// ConfCreate binds an ssh configuration from the request, stamps it with the
// "uid" value from the context and stores it. A bind failure answers code 1, a
// create failure code 2; on success the response is produced by ConfFindAll.
func ConfCreate(c *gin.Context) {
	var sshConf model.SshConf
	bindErr := c.ShouldBind(&sshConf)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	sshConf.Uid = c.GetUint("uid")
	if createErr := sshConf.Create(&sshConf); createErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": createErr.Error()})
		return
	}
	ConfFindAll(c)
}
// ConfFindByID parses the "id" path parameter and returns the matching ssh
// configuration for the context's "uid". Parse and lookup failures both answer
// code 2.
func ConfFindByID(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	confID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var sshConf model.SshConf
	found, findErr := sshConf.FindByID(uint(confID), c.GetUint("uid"))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": found})
}
// ConfFindAll lists the ssh configurations of the context's "uid". The "limit"
// query value defaults to 10000 and "offset" to 0; any failure answers code 2.
func ConfFindAll(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	pageSize, limitErr := strconv.Atoi(c.DefaultQuery("limit", "10000"))
	if limitErr != nil {
		fail(limitErr)
		return
	}
	skip, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil {
		fail(offsetErr)
		return
	}

	var sshConf model.SshConf
	// NOTE(review): arguments are passed as (offset, limit, uid) — confirm this
	// matches the model method's parameter order.
	confs, findErr := sshConf.FindAll(skip, pageSize, c.GetUint("uid"))
	if findErr != nil {
		fail(findErr)
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": confs})
}
// ConfUpdateById binds an ssh configuration from the request and updates the
// row with the bound ID for the context's "uid". A bind failure answers code
// 1, an update failure code 2; on success the response is produced by
// ConfFindAll.
func ConfUpdateById(c *gin.Context) {
	var sshConf model.SshConf
	bindErr := c.ShouldBind(&sshConf)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	if updateErr := sshConf.UpdateById(sshConf.ID, c.GetUint("uid"), &sshConf); updateErr != nil {
		c.JSON(200, gin.H{"code": 2, "msg": updateErr.Error()})
		return
	}
	ConfFindAll(c)
}
// ConfDeleteById parses the "id" path parameter and deletes that ssh
// configuration for the context's "uid". Failures answer code 2; on success
// the response is produced by ConfFindAll.
func ConfDeleteById(c *gin.Context) {
	fail := func(e error) {
		c.JSON(200, gin.H{"code": 2, "msg": e.Error()})
	}

	confID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		fail(convErr)
		return
	}

	var sshConf model.SshConf
	if delErr := sshConf.DeleteByID(uint(confID), c.GetUint("uid")); delErr != nil {
		fail(delErr)
		return
	}
	ConfFindAll(c)
}
================================================
FILE: gossh/app/service/ssh_conn.go
================================================
package service
import (
"encoding/json"
"fmt"
"gossh/app/model"
"gossh/app/utils"
"gossh/crypto/ssh"
"gossh/gin"
"gossh/sftp"
"gossh/websocket"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
)
// SshConn is one live client connection: the embedded ssh configuration plus
// the session bookkeeping and the handles opened for it.
type SshConn struct {
	*model.SshConf
	// Session ID.
	SessionId string `json:"session_id"`
	// Last active time (heartbeat).
	LastActiveTime time.Time `json:"last_active_time"`
	// Time the connection was created.
	StartTime time.Time `json:"start_time"`
	// Client IP.
	ClientIP string `json:"client_ip"`
	// ssh client.
	sshClient *ssh.Client
	// sftp client; may be nil, since connect only logs a failure to create it.
	sftpClient *sftp.Client
	// ssh session.
	sshSession *ssh.Session
	// websocket connection; assigned by RunTerminal.
	ws *websocket.Conn
}
// MarshalJSON overrides JSON serialization for SshConn.
//
// The value is marshalled through a local Alias type that has SshConn's fields
// but not this method, so json.Marshal does not recurse back into it. The
// wrapper struct redeclares several JSON keys at the outer level:
// last_active_time and start_time are emitted as "2006-01-02 15:04:05"
// strings, pwd/cert_data/cert_pwd as empty strings and
// created_at/updated_at/deleted_at as 0.
// NOTE(review): the blanked keys presumably shadow same-named fields of the
// embedded model.SshConf so secrets are not serialized — confirm against that type.
func (s *SshConn) MarshalJSON() ([]byte, error) {
	type Alias SshConn
	return json.Marshal(&struct {
		Alias
		LastActiveTime string `json:"last_active_time"`
		StartTime string `json:"start_time"`
		Pwd string `json:"pwd"`
		CertData string `json:"cert_data"`
		CertPwd string `json:"cert_pwd"`
		CreatedAt uint ` json:"created_at"`
		UpdatedAt uint ` json:"updated_at"`
		DeletedAt uint ` json:"deleted_at"`
	}{
		Alias: (Alias)(*s),
		LastActiveTime: s.LastActiveTime.Format("2006-01-02 15:04:05"),
		StartTime: s.StartTime.Format("2006-01-02 15:04:05"),
		Pwd: "",
		CertData: "",
		CertPwd: "",
		CreatedAt: 0,
		UpdatedAt: 0,
		DeletedAt: 0,
	})
}
// connect dials the host described by the embedded configuration and fills in
// sshClient, sftpClient and sshSession. clientIp is recorded as ClientIP.
//
// Authentication uses password plus keyboard-interactive (answering every
// question with Pwd) unless AuthType is "cert", in which case only the private
// key in CertData is used (decrypted with CertPwd when that is non-empty).
// A failure to create the sftp client is logged but not fatal; sftpClient is
// then nil.
//
// It returns the first error from key parsing, dialing or session creation.
// A recovered panic is now reported as an error as well; previously it made
// connect return nil and the caller stored a half-initialised connection.
// When session creation fails, the clients opened so far are closed instead of
// being leaked.
func (s *SshConn) connect(clientIp string) (err error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("ssh connect error:", "err_msg", r)
			err = fmt.Errorf("ssh connect panic: %v", r)
		}
	}()
	s.ClientIP = clientIp
	config := ssh.ClientConfig{
		User: s.User,
		Auth: []ssh.AuthMethod{
			ssh.Password(s.Pwd),
			ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
				answers := make([]string, len(questions))
				for i := range answers {
					answers[i] = s.Pwd
				}
				return answers, nil
			}),
		},
		// SECURITY(review): host keys are not verified, so the connection is
		// open to man-in-the-middle attacks. Left unchanged here.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         30 * time.Second,
	}
	// Certificate (private key) authentication replaces the password methods.
	if s.AuthType == "cert" {
		var signer ssh.Signer
		var parseErr error
		if s.CertPwd != "" {
			// Private key protected by a passphrase.
			signer, parseErr = ssh.ParsePrivateKeyWithPassphrase([]byte(s.CertData), []byte(s.CertPwd))
			if parseErr != nil {
				slog.Error("解析带密码私钥Key错误:", "err_msg", parseErr.Error())
				return parseErr
			}
		} else {
			// Private key without a passphrase.
			signer, parseErr = ssh.ParsePrivateKey([]byte(s.CertData))
			if parseErr != nil {
				slog.Error("解析私钥Key错误:", "err_msg", parseErr.Error())
				return parseErr
			}
		}
		config.Auth = []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		}
	}
	addr := fmt.Sprintf("%s:%d", s.Address, s.Port)
	if s.NetType == "tcp6" {
		// IPv6 literals must be bracketed in host:port form.
		addr = fmt.Sprintf("[%s]:%d", s.Address, s.Port)
	}
	sshClient, err := ssh.Dial(s.NetType, addr, &config)
	if err != nil {
		return err
	}
	s.sshClient = sshClient
	// Build the sftp client on top of the ssh client; failure is tolerated.
	sftpClient, sftpErr := sftp.NewClient(sshClient)
	if sftpErr != nil {
		slog.Error("create sftp sshClient error:", "err_msg", sftpErr)
	}
	s.sftpClient = sftpClient
	sshSession, err := s.sshClient.NewSession()
	if err != nil {
		slog.Error("sshClient.NewSession error:", "err_msg", err.Error())
		// Nothing will ever close these once we return an error, so do it now.
		if s.sftpClient != nil {
			_ = s.sftpClient.Close()
			s.sftpClient = nil
		}
		_ = s.sshClient.Close()
		s.sshClient = nil
		return err
	}
	s.sshSession = sshSession
	return nil
}
// RunTerminal attaches the given streams to the ssh session, requests a pty of
// PtyType with h rows and w columns, starts the shell and blocks until it
// ends. ws is stored on the connection. The shell parameter is not used by the
// body. When the function returns, the session is removed through
// DeleteOnlineClient and any panic is recovered and logged.
//
// It returns the error from RequestPty (also written to ws), Shell or Wait,
// and nil when the shell ends cleanly.
func (s *SshConn) RunTerminal(shell string, stdout, stderr io.Writer, stdin io.Reader, w, h int, ws *websocket.Conn) error {
	defer func() {
		DeleteOnlineClient(s.SessionId)
		if r := recover(); r != nil {
			slog.Error("RunTerminal error:", "err_msg", r)
		}
	}()

	s.ws = ws
	s.sshSession.Stdout = stdout
	s.sshSession.Stderr = stderr
	s.sshSession.Stdin = stdin

	ptyErr := s.sshSession.RequestPty(s.PtyType, h, w, ssh.TerminalModes{})
	if ptyErr != nil {
		slog.Error("sshSession.RequestPty error:", "err_msg", ptyErr.Error())
		ws.WriteMessage(websocket.BinaryMessage, []byte("sshSession.RequestPty error:"+ptyErr.Error()))
		return ptyErr
	}

	shellErr := s.sshSession.Shell()
	if shellErr != nil {
		slog.Error("sshSession.Shell error:", "err_msg", shellErr.Error())
		return shellErr
	}

	waitErr := s.sshSession.Wait()
	if waitErr == nil {
		return nil
	}
	// A remote end that exits without a status is logged at info level only.
	if strings.Contains(waitErr.Error(), "remote command exited without exit status or exit signal") {
		slog.Info("sshSession.Wait remote command exited without exit status or exit signal")
	} else {
		slog.Error("sshSession.Wait error:", "err_msg", waitErr.Error())
	}
	return waitErr
}
// ResizeWindow changes the terminal size of the session named by the
// "session_id" query value to "w" columns (40..8192) and "h" rows (2..4096).
// The receiver is not used; the target connection is looked up in
// OnlineClients.
//
// Invalid sizes, unknown sessions and wrongly typed entries answer code 1.
// A failed WindowChange deletes the session and now also answers code 1;
// previously the handler still reported success. The height error no longer
// says "width". Panics are recovered and logged.
func (s *SshConn) ResizeWindow(c *gin.Context) {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("ResizeWindow recover error:", "err_msg", err)
		}
	}()
	w, err := strconv.Atoi(c.Query("w"))
	if err != nil || (w < 40 || w > 8192) {
		c.JSON(200, gin.H{
			"code": 1,
			"msg":  "connect error window width !!!"})
		return
	}
	h, err := strconv.Atoi(c.Query("h"))
	if err != nil || (h < 2 || h > 4096) {
		c.JSON(200, gin.H{
			"code": 1,
			"msg":  "connect error window height !!!"})
		return
	}
	sessionId := c.Query("session_id")
	cli, ok := OnlineClients.Load(sessionId)
	if !ok || cli == nil {
		c.JSON(200, gin.H{"code": 1, "msg": "the client is disconnected"})
		return
	}
	conn, ok := cli.(*SshConn)
	if !ok || conn == nil {
		DeleteOnlineClient(sessionId)
		c.JSON(200, gin.H{"code": 1, "msg": "to type SshConn error"})
		return
	}
	if err = conn.sshSession.WindowChange(h, w); err != nil {
		// The session is unusable once the resize request fails.
		DeleteOnlineClient(sessionId)
		slog.Error("sshSession.WindowChange error:", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "window change error:" + err.Error()})
		return
	}
	str := fmt.Sprintf("W:%d;H:%d\n", w, h)
	c.JSON(200, gin.H{"code": 0, "data": str, "msg": "ok"})
}
// ResizeWindow is the handler entry point: it invokes the ResizeWindow method
// on a fresh zero-value SshConn.
func ResizeWindow(c *gin.Context) {
	(&SshConn{}).ResizeWindow(c)
}
// NewSshConn upgrades the request to a websocket and runs a terminal on the
// session named by the "session_id" query value, using "w" columns (40..8192)
// and "h" rows (2..4096). It blocks until the terminal ends. Error conditions
// are reported to the client as a binary websocket message.
//
// On return the session is always removed through the deferred
// DeleteOnlineClient and the websocket is closed. The keep-alive goroutine is
// stopped through a done channel as soon as the handler returns; before, it
// lingered until its next ping failed.
func NewSshConn(c *gin.Context) {
	// Configure the websocket upgrader.
	upgrader := websocket.Upgrader{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
		// SECURITY(review): every Origin is accepted, which allows cross-site
		// websocket connections. Left unchanged here.
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}
	// Upgrade the HTTP connection to a websocket connection.
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		slog.Error("Failed to upgrade connection:", "err_msg", err)
		return
	}
	defer ws.Close()

	// Answer pings with a pong.
	ws.SetPingHandler(func(appData string) error {
		return ws.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second*10))
	})

	// Keep the connection alive with a ping every 30 seconds until the handler
	// returns or a ping fails.
	done := make(chan struct{})
	defer close(done)
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil {
					slog.Error("ping error:", "err_msg", err)
					return
				}
			}
		}
	}()

	sessionId := c.Query("session_id")
	// Every exit path below relies on this deferred cleanup.
	defer DeleteOnlineClient(sessionId)

	w, err := strconv.Atoi(c.Query("w"))
	if err != nil || (w < 40 || w > 8192) {
		ws.WriteMessage(websocket.BinaryMessage, []byte("connect error window width !!!"))
		return
	}
	h, err := strconv.Atoi(c.Query("h"))
	if err != nil || (h < 2 || h > 4096) {
		ws.WriteMessage(websocket.BinaryMessage, []byte("connect error window height !!!"))
		return
	}
	cli, ok := OnlineClients.Load(sessionId)
	if !ok || cli == nil {
		ws.WriteMessage(websocket.BinaryMessage, []byte("session_id not exists !!!"))
		return
	}
	conn, ok := cli.(*SshConn)
	if !ok || conn == nil {
		ws.WriteMessage(websocket.BinaryMessage, []byte("to ssh.Session error !!!"))
		return
	}
	// Adapter that exposes the websocket as io.Reader and io.Writer.
	wsReadWriter := &WebSocketReadWriter{ws: ws}
	if err = conn.RunTerminal(conn.Shell, wsReadWriter, wsReadWriter, wsReadWriter, w, h, ws); err != nil {
		ws.WriteMessage(websocket.BinaryMessage, []byte("connect error:"+err.Error()))
	}
}
// WebSocketReadWriter adapts a websocket connection to the io.Reader and
// io.Writer interfaces.
type WebSocketReadWriter struct {
	ws *websocket.Conn
	// pending is the unread tail of the last websocket message, kept when the
	// caller's buffer was too small to take the whole message in one Read.
	pending []byte
}

// Read copies websocket message data into p and returns the number of bytes
// copied, which never exceeds len(p). A message larger than p is delivered
// across several Read calls; empty messages are skipped. The error from
// ReadMessage is returned unchanged.
//
// The previous implementation returned len(message) even when that was larger
// than len(p), which breaks the io.Reader contract and drops the excess bytes.
func (w *WebSocketReadWriter) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	for len(w.pending) == 0 {
		_, message, readErr := w.ws.ReadMessage()
		if readErr != nil {
			return 0, readErr
		}
		w.pending = message
	}
	n = copy(p, w.pending)
	w.pending = w.pending[n:]
	return n, nil
}
// Write sends p as a single binary websocket message. It returns len(p) on
// success and 0 with the WriteMessage error otherwise.
func (w *WebSocketReadWriter) Write(p []byte) (n int, err error) {
	if err = w.ws.WriteMessage(websocket.BinaryMessage, p); err != nil {
		return 0, err
	}
	n = len(p)
	return n, nil
}
// CreateSessionId binds the connection parameters from the request, connects
// to the host and stores the connection in OnlineClients. The session id is
// the "session_id" query value when present, otherwise a random 15-character
// string. Bind and connect failures answer code 1; success answers code 0
// with the session id as data.
// NOTE(review): an existing entry under a client-supplied session id is
// overwritten without being closed — confirm this is intended.
func CreateSessionId(c *gin.Context) {
	var conn SshConn
	bindErr := c.ShouldBind(&conn)
	if bindErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": bindErr.Error()})
		return
	}

	// Prefer the id supplied by the client.
	sid := c.Query("session_id")
	if sid == "" {
		sid = utils.RandString(15)
	}
	conn.SessionId = sid
	conn.LastActiveTime = time.Now()
	conn.StartTime = time.Now()

	if connErr := conn.connect(c.RemoteIP()); connErr != nil {
		c.JSON(200, gin.H{"code": 1, "msg": "CreateSessionId error:" + connErr.Error()})
		return
	}
	OnlineClients.Store(sid, &conn)
	c.JSON(200, gin.H{"code": 0, "data": sid, "msg": "ok"})
}
// Disconnect removes the session named by the "session_id" query parameter
// via DeleteOnlineClient.
//
// The response is always HTTP 200 with a JSON body: code 0 on success, code 1
// when the parameter is empty or when a panic was recovered.
func Disconnect(c *gin.Context) {
	reply := func(code int, msg string) {
		c.JSON(200, gin.H{"code": code, "msg": msg})
	}

	defer func() {
		if r := recover(); r != nil {
			reply(1, "delete connect error")
		}
	}()

	id := c.Query("session_id")
	if id == "" {
		reply(1, "session not exists")
		return
	}

	DeleteOnlineClient(id)
	reply(0, "delete connect success")
}
func ExecCommand(c *gin.Context) {
type Param struct {
SessionId string `form:"session_id" binding:"required,min=10" json:"session_id"`
Cmd string `form:"cmd" binding:"required,min=1" json:"cmd"`
}
var param Param
if err := c.ShouldBind(¶m); err != nil {
c.JSON(200, gin.H{"code": 2, "msg": err.Error()})
return
}
cli, ok := OnlineClients.Load(param.SessionId)
if !ok || cli == nil {
c.JSON(200, gin.H{"code": 3, "msg": "session not exists"})
return
}
conn, ok := cli.(*SshConn)
if !ok || conn == nil {
c.JSON(200, gin.H{"code": 4, "msg": "conn not exists"})
return
}
//创建ssh-session
session, err := conn.sshClient.NewSession()
if err != nil {
c.JSON(200, gin.H{"code": 5, "msg": "create session error"})
return
}
defer func(session *ssh.Session) {
_ = session.Close()
}(session)
//执行命令
out, err := session.CombinedOutput(param.Cmd)
if err != nil {
c.JSON(200, gin.H{"code": 6, "msg": "exec cmd error", "data": string(out)})
return
}
c.JSON(200, gin.H{
"code": 0,
"msg": "ok",
"data": string(out),
})
}
================================================
FILE: gossh/app/service/ssh_sftp.go
================================================
package service
import (
"errors"
"fmt"
"gossh/gin"
"io"
"log/slog"
"net/http"
"net/url"
"path"
"strconv"
"strings"
)
// getSshConn looks up the connection stored in OnlineClients under sessionId.
//
// It returns an error when no entry exists, or when the stored value is not a
// non-nil *SshConn, so a nil error always comes with a usable connection.
func getSshConn(sessionId string) (*SshConn, error) {
	cli, ok := OnlineClients.Load(sessionId)
	if !ok {
		slog.Error("加载ssh连接错误")
		return nil, errors.New("加载ssh连接错误")
	}
	conn, ok := cli.(*SshConn)
	// A failed assertion yields a nil conn, so both conditions must be
	// rejected; otherwise callers would dereference a nil pointer.
	if !ok || conn == nil {
		slog.Error("断言ssh连接错误")
		return nil, errors.New("断言ssh连接错误")
	}
	return conn, nil
}
// SftpList lists the entries of a remote directory over SFTP.
//
// It reads the form fields "path" and "session_id", and responds with HTTP 200
// and a JSON body. On success (code 0) "data" holds the entries ("files"), the
// file and directory counts, the breadcrumb trail for the path ("paths") and
// the requested directory ("current_dir"). Codes 2, 3 and 4 report a missing
// session, a ReadDir failure and a recovered panic respectively.
func SftpList(c *gin.Context) {
	defer func() {
		if r := recover(); r != nil {
			c.JSON(200, gin.H{"code": 4, "msg": "读取目录错误"})
		}
	}()

	dirPath := c.PostForm("path")
	conn, err := getSshConn(c.PostForm("session_id"))
	if err != nil {
		slog.Error(err.Error())
		c.JSON(200, gin.H{"code": 2, "msg": err.Error()})
		return
	}

	entries, err := conn.sftpClient.ReadDir(dirPath)
	if err != nil {
		slog.Error("sftp客户端ReadDir错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "sftp客户端读取目录错误"})
		return
	}

	var (
		listing   []any
		fileCount int
		dirCount  int
	)
	for _, entry := range entries {
		// "d" marks a directory, "f" anything else.
		kind := "f"
		if entry.IsDir() {
			kind = "d"
			dirCount++
		} else {
			fileCount++
		}
		listing = append(listing, map[string]any{
			"path":     path.Join(dirPath, entry.Name()),
			"name":     entry.Name(),
			"mode":     entry.Mode().String(),
			"size":     entry.Size(),
			"mod_time": entry.ModTime().Format("2006-01-02 15:04:05"),
			"type":     kind,
		})
	}

	// breadcrumbs turns a slash-separated path into one {name, dir} pair per
	// path component, where dir is the joined path up to that component.
	breadcrumbs := func(dir string) []map[string]string {
		var segments []string
		if strings.HasPrefix(dir, "/") {
			segments = append(segments, "/")
		}
		for _, part := range strings.Split(dir, "/") {
			if trimmed := strings.TrimSpace(part); trimmed != "" {
				segments = append(segments, trimmed)
			}
		}

		var crumbs []map[string]string
		for end := 1; end <= len(segments); end++ {
			crumbs = append(crumbs, map[string]string{
				"name": segments[end-1],
				"dir":  path.Join(segments[:end]...),
			})
		}
		return crumbs
	}

	c.JSON(200, gin.H{
		"code": 0,
		"data": map[string]any{
			"files":       listing,
			"file_count":  fileCount,
			"dir_count":   dirCount,
			"paths":       breadcrumbs(dirPath),
			"current_dir": dirPath,
		},
		"msg": "ok",
	})
}
// SftpDownLoad streams a remote file to the client as an attachment.
//
// It reads the query parameters "path" (unescaped once more with
// url.QueryUnescape) and "session_id". On failure it responds with HTTP 200
// and a JSON body whose "code" is 1 (recovered panic), 2 (bad path), 3
// (missing session), 4 (open/stat failure) or 5 (transfer failure).
func SftpDownLoad(c *gin.Context) {
	defer func() {
		if err := recover(); err != nil {
			c.JSON(200, gin.H{"code": 1, "msg": "下载错误"})
			return
		}
	}()
	fullPath, err := url.QueryUnescape(c.Query("path"))
	if err != nil {
		slog.Error("获取文件路径参数错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 2, "msg": "获取文件路径参数错误"})
		return
	}
	conn, err := getSshConn(c.Query("session_id"))
	if err != nil {
		slog.Error(err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": err.Error()})
		return
	}
	file, err := conn.sftpClient.Open(fullPath)
	if err != nil {
		slog.Error("sftpClient.Openc错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 4, "msg": "sftp打开文件错误"})
		return
	}
	// Register the close only after Open succeeded: on failure file is nil
	// and closing it would panic after the error response was already sent.
	defer func() {
		_ = file.Close()
	}()
	stat, err := file.Stat()
	if err != nil {
		slog.Error("file.Stat()错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 4, "msg": "读取文件信息错误"})
		return
	}
	c.Writer.WriteHeader(http.StatusOK)
	c.Header("Content-Disposition", "attachment; filename="+stat.Name())
	c.Header("Content-Type", "application/octet-stream")
	c.Header("Content-Length", fmt.Sprintf("%d", stat.Size()))
	_, err = file.WriteTo(c.Writer)
	if err != nil {
		slog.Error("file.WriteTo错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "下载文件错误"})
		return
	}
	c.Writer.Flush()
}
// SftpUpload copies every uploaded file in the multipart field "files" into
// the remote directory given by the form field "path".
//
// Files that cannot be opened, created or copied are skipped (and logged); the
// response lists the names that were uploaded successfully. The response is
// always HTTP 200 with a JSON body: code 0 with the uploaded names in "data",
// 1 for an unreadable form, 2 for a missing session, 4 for a recovered panic.
func SftpUpload(c *gin.Context) {
	defer func() {
		if err := recover(); err != nil {
			c.JSON(200, gin.H{"code": 4, "msg": "上传错误"})
			return
		}
	}()
	dstPath := c.PostForm("path")
	// Fetch the group of uploaded files.
	form, err := c.MultipartForm()
	if err != nil {
		slog.Error("获取form数据错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取form数据错误"})
		return
	}
	files := form.File["files"]
	conn, err := getSshConn(c.PostForm("session_id"))
	if err != nil {
		slog.Error(err.Error())
		c.JSON(200, gin.H{"code": 2, "msg": err.Error()})
		return
	}
	var ret []string
	for _, file := range files {
		fileName := file.Filename
		// Each file is handled in its own function so the deferred closes run
		// per file, including on the failure paths, instead of leaking handles.
		uploaded := func() bool {
			srcFile, err := file.Open()
			if err != nil {
				slog.Error("打开上传文件错误", "file", fileName, "err_msg", err.Error())
				return false
			}
			defer func() { _ = srcFile.Close() }()

			dstFile, err := conn.sftpClient.Create(path.Join(dstPath, fileName))
			if err != nil {
				slog.Error("sftp创建文件错误", "file", fileName, "err_msg", err.Error())
				return false
			}
			defer func() { _ = dstFile.Close() }()

			if _, err = io.Copy(dstFile, srcFile); err != nil {
				slog.Error("sftp写入文件错误", "file", fileName, "err_msg", err.Error())
				return false
			}
			return true
		}()
		if uploaded {
			ret = append(ret, fileName)
		}
	}
	msg := strconv.Itoa(len(ret)) + " 个文件上传成功"
	c.JSON(200, gin.H{"code": 0, "msg": msg, "data": ret})
}
// SftpDelete removes a remote file or directory through the session's SFTP
// client using RemoveAll.
//
// "session_id" and "path" are bound from the request. The response is always
// HTTP 200 with a JSON body: code 0 on success, 1 for invalid input or a
// missing session, 2 when the removal fails, 4 for a recovered panic.
func SftpDelete(c *gin.Context) {
	respond := func(code int, msg string) {
		c.JSON(200, gin.H{"code": code, "msg": msg})
	}
	defer func() {
		if r := recover(); r != nil {
			respond(4, "删除错误")
		}
	}()

	type Body struct {
		SessionId string `form:"session_id" binding:"required,min=1,max=128" json:"session_id"`
		Path      string `form:"path" binding:"required,min=1,max=1024" json:"path"`
	}
	var req Body
	if bindErr := c.ShouldBind(&req); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		respond(1, "输入数据不合法")
		return
	}

	conn, connErr := getSshConn(req.SessionId)
	if connErr != nil {
		slog.Error(connErr.Error())
		respond(1, connErr.Error())
		return
	}

	if rmErr := conn.sftpClient.RemoveAll(req.Path); rmErr != nil {
		slog.Error("sftpClient.Remove错误", "err_msg", rmErr.Error())
		respond(2, "删除文件错误")
		return
	}
	respond(0, "删除成功")
}
// SftpCreateDir creates a remote directory through the session's SFTP client
// using MkdirAll.
//
// "session_id" and "path" are bound from the request. The response is always
// HTTP 200 with a JSON body: code 0 on success, 1 for invalid input or a
// missing session, 2 when creation fails, 4 for a recovered panic.
func SftpCreateDir(c *gin.Context) {
	respond := func(code int, msg string) {
		c.JSON(200, gin.H{"code": code, "msg": msg})
	}
	defer func() {
		if r := recover(); r != nil {
			respond(4, "创建目录错误")
		}
	}()

	type Body struct {
		SessionId string `form:"session_id" binding:"required,min=1,max=128" json:"session_id"`
		Path      string `form:"path" binding:"required,min=1,max=1024" json:"path"`
	}
	var req Body
	if bindErr := c.ShouldBind(&req); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		respond(1, "输入数据不合法")
		return
	}

	conn, connErr := getSshConn(req.SessionId)
	if connErr != nil {
		slog.Error(connErr.Error())
		respond(1, connErr.Error())
		return
	}

	if mkErr := conn.sftpClient.MkdirAll(req.Path); mkErr != nil {
		slog.Error("sftpClient.MkdirAll错误", "err_msg", mkErr.Error())
		respond(2, "创建目录错误")
		return
	}
	respond(0, "创建目录成功")
}
================================================
FILE: gossh/app/service/ssh_status.go
================================================
package service
import (
"gossh/app/config"
"gossh/gin"
"gossh/gin/sse"
"net/http"
"sort"
"time"
)
// SshConnById is a slice of SshConn that sorts by SessionId; it provides the
// Len, Swap and Less methods required by sort.Sort.
type SshConnById []SshConn

// Len reports how many connections the slice holds.
func (a SshConnById) Len() int {
	return len(a)
}

// Swap exchanges the elements at positions i and j.
func (a SshConnById) Swap(i, j int) {
	a[j], a[i] = a[i], a[j]
}

// Less reports whether element i has a smaller SessionId than element j.
func (a SshConnById) Less(i, j int) bool {
	return a[j].SessionId > a[i].SessionId
}
// GetOnlineClient streams the registered connections to the client as
// server-sent events.
//
// Each cycle flushes the writer, sleeps for config.DefaultConfig.StatusRefresh,
// and then, unless the request context is done, renders one event carrying a
// SessionId-sorted copy of every *SshConn in OnlineClients.
func GetOnlineClient(c *gin.Context) {
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Connection", "keep-alive")
	c.Header("Cache-Control", "no-cache")
	c.Header("Content-Type", "text/event-stream")

	for {
		c.Writer.(http.Flusher).Flush()
		time.Sleep(config.DefaultConfig.StatusRefresh)

		// Stop as soon as the request has been cancelled.
		select {
		case <-c.Request.Context().Done():
			return
		default:
		}

		var snapshot SshConnById
		collect := func(_, value any) bool {
			conn, isConn := value.(*SshConn)
			if isConn && conn != nil {
				snapshot = append(snapshot, *conn)
			}
			// Iteration stops at the first value that is not a *SshConn.
			return isConn
		}
		OnlineClients.Range(collect)
		sort.Sort(snapshot)

		event := sse.Event{
			Id:    "200",
			Event: "message",
			Retry: 10000,
			Data: map[string]any{
				"code": 0,
				"data": snapshot,
				"msg":  "ok",
			},
		}
		c.Render(200, event)
	}
}
// RefreshConnTime sets LastActiveTime to the current time for every session
// id listed in the form array "ids". Ids with no registered *SshConn are
// ignored. The response echoes the received ids with code 0.
func RefreshConnTime(c *gin.Context) {
	ids := c.PostFormArray("ids")
	for _, id := range ids {
		if cli, found := OnlineClients.Load(id); found {
			if conn, isConn := cli.(*SshConn); isConn && conn != nil {
				conn.LastActiveTime = time.Now()
			}
		}
	}
	c.JSON(200, gin.H{"code": 0, "data": ids, "msg": "ok"})
}
================================================
FILE: gossh/app/service/sshd_cert.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"log/slog"
"strconv"
)
// SshdCertCreate binds an SshdCert from the request, stores it, and then
// responds with the full certificate list via SshdCertFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (invalid input) or 3
// (create failed).
func SshdCertCreate(c *gin.Context) {
	cert := model.SshdCert{}
	if bindErr := c.ShouldBind(&cert); bindErr != nil {
		slog.Error("SshdCertCreate 绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	if createErr := cert.Create(&cert); createErr != nil {
		slog.Error("创建SSHD证书错误", "err_msg", createErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "创建SSHD证书错误"})
		return
	}

	SshdCertFindAll(c)
}
// CheckSshdCertNameExists reports whether a certificate with the bound "name"
// already exists.
//
// The response is HTTP 200 with JSON code 0 when the name is free, 4 when a
// record with a non-zero ID was found, 1 for invalid input and 3 when the
// lookup fails.
func CheckSshdCertNameExists(c *gin.Context) {
	type Name struct {
		Name string `form:"name" binding:"required,min=1,max=128" json:"name"`
	}
	var input Name
	if bindErr := c.ShouldBind(&input); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	var cert model.SshdCert
	existing, findErr := cert.FindByName(input.Name)
	switch {
	case findErr != nil:
		slog.Error("FindByName错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取SSHD证书信息错误"})
	case existing.ID != 0:
		c.JSON(200, gin.H{"code": 4, "msg": "名称已经存存在"})
	default:
		c.JSON(200, gin.H{"code": 0, "msg": "ok"})
	}
}
// SshdCertFindByID returns the certificate whose id is given by the "id" route
// parameter.
//
// The response is HTTP 200 with JSON code 0 and the record in "data", code 1
// when the id is not an integer, or code 3 when the lookup fails.
func SshdCertFindByID(c *gin.Context) {
	rawID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		slog.Error("获取ID错误", "err_msg", convErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var cert model.SshdCert
	record, findErr := cert.FindByID(uint(rawID))
	if findErr != nil {
		slog.Error("FindByID错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取SSHD证书信息错误"})
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": record})
}
// SshdCertFindAll returns a page of certificates.
//
// The query parameters "limit" (default 10000) and "offset" (default 0) select
// the page. The response is HTTP 200 with JSON code 0 and the records in
// "data", code 1 or 2 when limit or offset is not an integer, or code 4 when
// the query fails.
func SshdCertFindAll(c *gin.Context) {
	// intQuery parses an integer query parameter, falling back to def.
	intQuery := func(key, def string) (int, error) {
		return strconv.Atoi(c.DefaultQuery(key, def))
	}

	limit, limitErr := intQuery("limit", "10000")
	if limitErr != nil {
		slog.Error("获取limit错误", "err_msg", limitErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取limit错误"})
		return
	}

	offset, offsetErr := intQuery("offset", "0")
	if offsetErr != nil {
		slog.Error("获取offset错误", "err_msg", offsetErr.Error())
		c.JSON(200, gin.H{"code": 2, "msg": "获取offset错误"})
		return
	}

	var cert model.SshdCert
	records, findErr := cert.FindAll(limit, offset)
	if findErr != nil {
		slog.Error("sshdCert.FindAll错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 4, "msg": "获取SSHD证书信息错误"})
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": records})
}
// SshdCertUpdateById binds an SshdCert from the request, updates the record
// with the bound ID, and then responds with the full certificate list via
// SshdCertFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (bind failed) or 5
// (update failed).
func SshdCertUpdateById(c *gin.Context) {
	cert := model.SshdCert{}
	if bindErr := c.ShouldBind(&cert); bindErr != nil {
		slog.Error("获取ID错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	if updateErr := cert.UpdateById(cert.ID, &cert); updateErr != nil {
		slog.Error("UpdateById错误", "err_msg", updateErr.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "更新SSHD证书错误"})
		return
	}

	SshdCertFindAll(c)
}
// SshdCertDeleteById deletes the certificate whose id is given by the "id"
// route parameter and then responds with the remaining certificates via
// SshdCertFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (id is not an integer)
// or 5 (delete failed).
func SshdCertDeleteById(c *gin.Context) {
	rawID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		slog.Error("获取ID错误", "err_msg", convErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var cert model.SshdCert
	if delErr := cert.DeleteByID(uint(rawID)); delErr != nil {
		slog.Error("sshdCert.DeleteByID错误", "err_msg", delErr.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "删除SSHD证书错误"})
		return
	}

	SshdCertFindAll(c)
}
// GetSshdCertAuthorizedKeys returns the text produced by
// SshdCert.GetAuthorizedKeys in the "data" field of a JSON response.
//
// On failure the response carries the error text in "err_msg".
func GetSshdCertAuthorizedKeys(c *gin.Context) {
	var sshdCert model.SshdCert
	text, err := sshdCert.GetAuthorizedKeys()
	if err != nil {
		slog.Error("GetAuthorizedKeys failed", "err_msg", err.Error())
		// Send the message text: marshalling the error value itself does not
		// reliably produce the message.
		c.JSON(200, gin.H{"code": 5, "msg": "获取SSHD证书错误", "err_msg": err.Error()})
		return
	}
	// NOTE(review): the success response uses code 5 while every other handler
	// in this package uses 0 for success. Left unchanged in case a client
	// depends on it; confirm and align.
	c.JSON(200, gin.H{"code": 5, "msg": "ok", "data": text})
}
================================================
FILE: gossh/app/service/sshd_server.go
================================================
package service
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/pem"
"errors"
"fmt"
"gossh/app/model"
"gossh/crypto/ssh"
"gossh/pty"
"gossh/sftp"
"io"
"log"
"log/slog"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
)
// Server is the embedded SSH server. It pairs the sshd settings it was created
// with and the ssh.ServerConfig computed from them.
type Server struct {
// cli is the sshd configuration passed to NewServer (host, port, shell, ...).
cli *model.SshdConf
// config is the SSH server configuration returned by computeSSHConfig and
// used for every incoming handshake.
config *ssh.ServerConfig
}
// NewServer builds a Server for the given sshd settings and computes its SSH
// configuration. It returns the error from computeSSHConfig unchanged when the
// configuration cannot be built.
func NewServer(c *model.SshdConf) (*Server, error) {
	srv := new(Server)
	srv.cli = c

	var err error
	if srv.config, err = srv.computeSSHConfig(); err != nil {
		return nil, err
	}
	return srv, nil
}
// Start listens on the configured host and port and serves each accepted TCP
// connection in its own goroutine.
//
// It returns an error when the listener cannot be opened (wrapping the
// underlying cause) or once the listener has been closed; other Accept errors
// are logged and the loop continues.
func (s *Server) Start() error {
	addr := fmt.Sprintf("%s:%d", s.cli.Host, s.cli.Port)
	listen, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}
	defer func() { _ = listen.Close() }()
	slog.Info("Listening on", "host", s.cli.Host, "port", s.cli.Port)
	for {
		tcpConn, err := listen.Accept()
		if err != nil {
			// A closed listener never recovers; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			slog.Error("Failed to accept incoming connection", "err", err)
			continue
		}
		go s.handleConn(tcpConn)
	}
}
// handleConn performs the SSH handshake on an accepted TCP connection and, on
// success, starts goroutines that discard global requests and service the
// connection's channels. Panics are recovered and logged.
func (s *Server) handleConn(tcpConn net.Conn) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("recovered from panic in handleConn", "err", r)
		}
	}()

	sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.config)
	switch {
	case err == io.EOF:
		// The peer went away during the handshake; not worth logging.
		return
	case err != nil:
		slog.Info("Failed to handshake", "err", err)
		return
	}

	slog.Info("New SSH connection from", "address", sshConn.RemoteAddr(), "client_version", sshConn.ClientVersion())
	go ssh.DiscardRequests(reqs)
	go s.handleChannels(chans)
}
// handleChannels hands every new channel request to handleChannel in its own
// goroutine and returns once chans has been closed.
func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
	for {
		pending, open := <-chans
		if !open {
			return
		}
		go s.handleChannel(pending)
	}
}
// handleChannel accepts "session" channels and starts a goroutine that
// services their requests; every other channel type is rejected with
// ssh.UnknownChannelType.
func (s *Server) handleChannel(newChannel ssh.NewChannel) {
	channelType := newChannel.ChannelType()
	if channelType != "session" {
		_ = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", channelType))
		return
	}

	slog.Info("Channel request", "type", newChannel.ChannelType())
	if extra := newChannel.ExtraData(); len(extra) > 0 {
		slog.Info("Channel data", "data", extra)
	}

	channel, requests, acceptErr := newChannel.Accept()
	if acceptErr != nil {
		slog.Error("Could not accept channel failed", "err", acceptErr)
		return
	}

	log.Println("Channel accepted")
	go s.handleRequests(channel, requests)
}
// handleRequests services the requests of one accepted session channel until
// the request stream ends or a terminal request (sftp subsystem, exec) has
// been completed.
//
// Supported request types: "subsystem" (sftp only), "pty-req",
// "window-change", "env", "shell" and "exec". Payloads that are too short to
// contain the fields they must carry are rejected instead of being indexed
// blindly. Panics are recovered and logged.
func (s *Server) handleRequests(connection ssh.Channel, requests <-chan *ssh.Request) {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("handleRequests recovered from panic", "err", err)
		}
	}()
	// Start the keep-alive loop; closing ticking on return stops it.
	if ka := s.cli.KeepAlive; ka > 0 {
		ticking := make(chan bool, 1)
		interval := time.Duration(ka) * time.Second
		go s.keepAlive(connection, interval, ticking)
		defer close(ticking)
	}
	// Prepare to handle client requests.
	env := os.Environ()
	// resizes carries terminal dimension payloads (width first, then height)
	// to the shell attached by attachShell.
	resizes := make(chan []byte, 10)
	defer close(resizes)
	for req := range requests {
		switch req.Type {
		case "subsystem":
			name, _, ok := sshStringFromPayload(req.Payload)
			if !ok || name != "sftp" {
				// Answer explicitly so a client waiting for a reply does not hang.
				_ = req.Reply(false, nil)
				continue
			}
			_ = req.Reply(true, nil)
			server, err := sftp.NewServer(connection, sftp.WithDebug(os.Stderr))
			if err != nil {
				slog.Error("create SFTP server error", "err", err)
				connection.Close()
				return
			}
			err = server.Serve()
			if err == io.EOF {
				server.Close()
			} else if err != nil {
				slog.Error("SFTP server error", "err", err)
			}
			connection.Close()
			return
		case "pty-req":
			// The payload starts with the terminal name as a length-prefixed
			// string; the dimensions follow it.
			_, dims, ok := sshStringFromPayload(req.Payload)
			if !ok || len(dims) < 8 {
				slog.Error("malformed pty-req payload", "payload", req.Payload)
				_ = req.Reply(false, nil)
				continue
			}
			resizes <- dims
			_ = req.Reply(true, nil)
		case "window-change":
			if len(req.Payload) < 8 {
				slog.Error("malformed window-change payload", "payload", req.Payload)
				continue
			}
			resizes <- req.Payload
		case "env":
			e := struct{ Name, Value string }{}
			_ = ssh.Unmarshal(req.Payload, &e)
			kv := e.Name + "=" + e.Value
			slog.Info("env", "data", kv)
			if s.cli.LoadEnv == "Y" {
				env = appendEnv(env, kv)
			}
		case "shell":
			if len(req.Payload) > 0 {
				slog.Info("shell command ignored", "payload", req.Payload)
			}
			slog.Info("sshd attachShell")
			err := s.attachShell(connection, env, resizes)
			if err != nil {
				slog.Error("exec shell", "err", err)
			}
			_ = req.Reply(err == nil, []byte{})
		case "exec":
			command, _, ok := sshStringFromPayload(req.Payload)
			if !ok {
				slog.Error("malformed exec payload", "payload", req.Payload)
				_ = req.Reply(false, nil)
				continue
			}
			if req.WantReply {
				_ = req.Reply(true, nil)
			}
			slog.Info("exec command string", "cmd", command)
			ret := ExecuteForChannel(command, connection)
			slog.Info("exec return", "code", ret)
			connection.Close()
			return
		default:
			slog.Info("unknown request", "type", req.Type, "want_reply", req.WantReply, "payload", req.Payload)
		}
	}
}

// sshStringFromPayload decodes the length-prefixed string at the start of an
// SSH request payload (4-byte big-endian length followed by that many bytes).
// It returns the string, the bytes after it, and false when the payload is too
// short to hold the declared length.
func sshStringFromPayload(payload []byte) (value string, rest []byte, ok bool) {
	if len(payload) < 4 {
		return "", nil, false
	}
	n := binary.BigEndian.Uint32(payload)
	if uint64(n) > uint64(len(payload)-4) {
		return "", nil, false
	}
	return string(payload[4 : 4+n]), payload[4+n:], true
}
// keepAlive sends a "ping" request on the channel every interval until the
// ticking channel yields a value or is closed. A failed ping is logged and the
// loop keeps running; success is logged only when the send succeeded.
func (s *Server) keepAlive(connection ssh.Channel, interval time.Duration, ticking <-chan bool) {
	defer func() {
		if err := recover(); err != nil {
			slog.Error("keepAlive recovered from panic", "err", err)
		}
	}()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if _, err := connection.SendRequest("ping", false, nil); err != nil {
				slog.Error("failed to send keep alive ping", "err", err)
				// Do not report a ping as sent when it was not.
				continue
			}
			slog.Info("sent keep alive ping")
		case <-ticking:
			return
		}
	}
}
// attachShell starts the configured shell on a new pty and wires it to the SSH
// channel.
//
// env becomes the shell's environment, and the working directory is the
// current user's home directory when it can be determined. Payloads received
// on resizes are decoded with parseDims and applied to the pty. The function
// returns as soon as the shell has been started; copying and waiting happen in
// background goroutines. It returns an error (after closing the channel) only
// when the pty cannot be started.
func (s *Server) attachShell(connection ssh.Channel, env []string, resizes <-chan []byte) error {
defer func() {
if err := recover(); err != nil {
// NOTE(review): the message names handleRequests although this is
// attachShell, and it bypasses slog.
fmt.Printf("recovered from panic in handleRequests: %s", err)
}
}()
shell := exec.Command(s.cli.Shell)
dir, err := os.UserHomeDir()
if err == nil {
shell.Dir = dir
}
shell.Env = env
slog.Info("Session env", "data", env)
// closeFunc closes the SSH channel and logs the outcome; it is run at most
// once by the copy goroutines below via sync.Once.
closeFunc := func() {
err := connection.Close()
if err != nil {
slog.Error("close connect failed", "err_msg", err)
return
}
slog.Info("Session closed")
}
shellPty, err := pty.Start(shell)
if err != nil {
closeFunc()
return fmt.Errorf("could not start pty (%s)\n", err)
}
// Apply queued resize payloads to the pty until resizes is closed.
// NOTE(review): this goroutine has no recover of its own.
go func() {
for payload := range resizes {
w, h := parseDims(payload)
SetWindowSize(shellPty, w, h)
}
}()
// Pipe the session to the shell and vice versa; whichever direction
// finishes first closes the channel.
var once sync.Once
go func() {
_, _ = io.Copy(connection, shellPty)
once.Do(closeFunc)
}()
go func() {
_, _ = io.Copy(shellPty, connection)
once.Do(closeFunc)
}()
slog.Info("shell attached")
// Wait for the shell process so its exit is observed and logged.
go func() {
if shell.Process != nil {
if ps, err := shell.Process.Wait(); err != nil && ps != nil {
slog.Error("Failed to exit shell", "err", err)
}
// NOTE(review): the pty is not closed here (the call is commented out).
// shellPty.Close()
}
slog.Info("Shell terminated and Session closed")
}()
return nil
}
// loadCertAuthFile fetches the authorized-keys text from the SshdCert model
// and parses it with parseKeys. Either failure is logged and returned.
func (s *Server) loadCertAuthFile() (map[string]string, error) {
	var cert model.SshdCert
	text, fetchErr := cert.GetAuthorizedKeys()
	if fetchErr != nil {
		slog.Error("GetAuthorizedKeys failed", "err_msg", fetchErr.Error())
		return nil, fetchErr
	}

	parsed, parseErr := parseKeys([]byte(text))
	if parseErr != nil {
		slog.Error("parseKeys failed", "err_msg", parseErr.Error())
		return nil, parseErr
	}
	return parsed, nil
}
// computeSSHConfig builds the ssh.ServerConfig used for incoming connections.
//
// It resolves the configured shell to an absolute path (defaulting to
// "powershell" on Windows and "sh" elsewhere, and storing the result back into
// s.cli.Shell), installs the host key parsed from s.cli.KeyFile, and registers
// the password, public-key and banner callbacks. The server version defaults
// to "SSH-2.0-OpenSSH" and is replaced by the value stored in configuration
// record 1 when that record can be loaded.
func (s *Server) computeSSHConfig() (*ssh.ServerConfig, error) {
	sc := &ssh.ServerConfig{}
	if s.cli.Shell == "" {
		if runtime.GOOS == "windows" {
			s.cli.Shell = "powershell"
		} else {
			s.cli.Shell = "sh"
		}
	}
	p, err := exec.LookPath(s.cli.Shell)
	if err != nil {
		return nil, fmt.Errorf("find shell failed: %s: %w", s.cli.Shell, err)
	}
	s.cli.Shell = p
	slog.Info("Session shell", "shell", s.cli.Shell)
	// User provided key (can be generated with 'ssh-keygen -t rsa').
	pri, err := ssh.ParsePrivateKey([]byte(s.cli.KeyFile))
	if err != nil {
		return nil, fmt.Errorf("failed to parse private key: %w", err)
	}
	sc.AddHostKey(pri)
	slog.Info("RSA key", "fingerprint", fingerprint(pri.PublicKey()))
	// Password authentication callback.
	sc.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
		var sshdUsers model.SshdUser
		u, err := sshdUsers.FindByNameAndPwd(conn.User(), string(pass))
		if err != nil {
			slog.Error("Authentication failed", "username", conn.User(), "err_msg", err)
			return nil, err
		}
		if conn.User() == u.Name && string(pass) == u.Pwd {
			slog.Info("Authentication succeeded with password", "username", conn.User())
			return nil, nil
		}
		// Never write credentials to the log.
		slog.Info("Authentication failed", "username", conn.User())
		return nil, errors.New("auth denied")
	}
	// Public-key authentication callback.
	sc.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
		ks, err := s.loadCertAuthFile()
		if err != nil {
			return nil, err
		}
		slog.Info("Authentication enabled public keys", "size", len(ks))
		return nil, s.matchKeys(key, ks)
	}
	// Banner callback: the returned text is sent to the client.
	sc.BannerCallback = func(conn ssh.ConnMetadata) string {
		return fmt.Sprintf("Welcome to the Go_SSH_Server\n=>user:(%s)\n=>remote addr:(%s)\n=>local addr:(%s)\n=>client ver:(%s)\n=>server ver:(%s)\n\n",
			conn.User(), conn.RemoteAddr().String(), conn.LocalAddr(),
			conn.ClientVersion(), conn.ServerVersion())
	}
	sc.ServerVersion = "SSH-2.0-OpenSSH"
	cf, err := s.cli.FindByID(1)
	if err == nil {
		sc.ServerVersion = cf.ServerVersion
	}
	return sc, nil
}
// matchKeys reports whether key is one of the authorized keys. keys maps the
// marshalled form of each authorized key to its comment. It returns nil on a
// match and an error otherwise; both outcomes are logged with the key's
// fingerprint.
func (s *Server) matchKeys(key ssh.PublicKey, keys map[string]string) error {
	comment, found := keys[string(key.Marshal())]
	if !found {
		slog.Error("User authentication failed with public key", "err", fingerprint(key))
		return fmt.Errorf("denied\n")
	}
	slog.Info("User authenticated with public key", "username", comment, "public_key", fingerprint(key))
	return nil
}
// appendEnv sets the variable described by kv ("NAME=value") in env.
//
// When an entry starting with "NAME=" already exists it is overwritten in
// place and env is returned; otherwise kv is appended. The name is everything
// before the first "=" in kv.
func appendEnv(env []string, kv string) []string {
	name, _, _ := strings.Cut(kv, "=")
	prefix := name + "="
	for idx := range env {
		if strings.HasPrefix(env[idx], prefix) {
			env[idx] = kv
			return env
		}
	}
	return append(env, kv)
}
// parseDims extracts terminal dimensions (width x height) from the provided
// buffer: two consecutive big-endian uint32 values, width first.
//
// A buffer shorter than 8 bytes yields (0, 0) instead of panicking, so a
// malformed client payload cannot crash the goroutine that applies resizes.
func parseDims(b []byte) (uint32, uint32) {
	if len(b) < 8 {
		return 0, 0
	}
	w := binary.BigEndian.Uint32(b)
	h := binary.BigEndian.Uint32(b[4:8])
	return w, h
}
// SetWindowSize resizes the given pty to w columns by h rows. Both values are
// converted to uint16, and any error from pty.Setsize is ignored.
func SetWindowSize(t pty.Pty, w, h uint32) {
	size := pty.Winsize{Cols: uint16(w), Rows: uint16(h)}
	_ = pty.Setsize(t, &size)
}
// generateKey creates a 2048-bit RSA private key and returns it PEM-encoded as
// a PKCS#1 "RSA PRIVATE KEY" block.
//
// With an empty seed the randomness comes from crypto/rand; otherwise it comes
// from the reader built by newDetermRand from the seed.
// NOTE(review): presumably a non-empty seed is meant to give a reproducible
// key — confirm that rsa.GenerateKey behaves that way on the Go version used.
func generateKey(seed string) ([]byte, error) {
	source := io.Reader(rand.Reader)
	if seed != "" {
		source = newDetermRand([]byte(seed))
	}

	key, err := rsa.GenerateKey(source, 2048)
	if err != nil {
		return nil, err
	}
	if err := key.Validate(); err != nil {
		return nil, err
	}

	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	return pem.EncodeToMemory(block), nil
}
// parseKeys parses authorized_keys-style text, one key per line, into a map
// from each key's marshalled form to its comment. Lines that fail to parse are
// skipped; an error is returned when no line yields a key.
func parseKeys(b []byte) (map[string]string, error) {
	parsed := make(map[string]string)
	for _, line := range bytes.Split(b, []byte("\n")) {
		key, comment, _, _, err := ssh.ParseAuthorizedKey(line)
		if err != nil {
			continue
		}
		parsed[string(key.Marshal())] = comment
	}

	// Ensure at least one key was found.
	if len(parsed) == 0 {
		return nil, fmt.Errorf("no keys found\n")
	}
	return parsed, nil
}
// fingerprint returns "SHA256:" followed by the standard base64 encoding of
// the SHA-256 digest of the marshalled key, with one trailing "=" padding
// character, if present, replaced by ".".
func fingerprint(k ssh.PublicKey) string {
	digest := sha256.Sum256(k.Marshal())
	encoded := base64.StdEncoding.EncodeToString(digest[:])
	if trimmed, hadPad := strings.CutSuffix(encoded, "="); hadPad {
		encoded = trimmed + "."
	}
	return "SHA256:" + encoded
}
// newDetermRand returns a reader whose output is derived from seed alone. The
// seed is first run through hash 2048 times; the resulting state and output
// halves initialise the reader.
func newDetermRand(seed []byte) io.Reader {
	const rounds = 2048

	// Strengthen the seed by chaining the hash.
	state := seed
	var block []byte
	for round := 0; round < rounds; round++ {
		state, block = hash(state)
	}
	return &determRand{next: state, out: block}
}
// determRand is a hash-chained byte source: next is the state fed into the
// following hash call, out is the output half produced during initialisation.
type determRand struct {
	next, out []byte
}

// Read fills b with output from the hash chain, advancing the state once per
// chunk, and always reports len(b) bytes with a nil error.
func (d *determRand) Read(b []byte) (int, error) {
	// HACK: combat https://golang.org/src/crypto/rsa/rsa.go#L257
	// A one-byte read is acknowledged without writing data or advancing state.
	if len(b) == 1 {
		return 1, nil
	}

	written := 0
	for written < len(b) {
		var chunk []byte
		d.next, chunk = hash(d.next)
		written += copy(b[written:], chunk)
	}
	return written, nil
}
// hash computes the SHA-512 digest of input and splits it in two: the first
// 32 bytes are returned as next and the last 32 bytes as output.
func hash(input []byte) (next []byte, output []byte) {
	const half = sha512.Size / 2
	digest := sha512.Sum512(input)
	next, output = digest[:half], digest[half:]
	return next, output
}
// ExecuteForChannel runs shellCmd as a child process wired to the SSH channel
// and returns an exit code for it.
//
// The command runs through PowerShell on Windows and "sh -c" elsewhere, with
// LANG and LC_ALL appended to the inherited environment. The working directory
// is the current user's home directory, or the directory of the running
// executable when the user cannot be determined. The channel supplies stdin
// and receives stdout; stderr goes to the channel's stderr stream.
//
// The result is 0 on success, the process exit code when it exited with a
// failure status, and 1 for any other failure.
func ExecuteForChannel(shellCmd string, ch ssh.Channel) uint32 {
	// Windows needs special handling: run through PowerShell with UTF-8 output.
	var proc *exec.Cmd
	switch runtime.GOOS {
	case "windows":
		proc = exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command",
			"[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; $OutputEncoding = [System.Text.Encoding]::UTF8; chcp 65001 | Out-Null; "+shellCmd)
	default:
		proc = exec.Command("sh", "-c", shellCmd)
	}
	proc.Env = append(os.Environ(), "LANG=zh_CN.UTF-8", "LC_ALL=zh_CN.UTF-8")

	exePath, exeErr := os.Executable()
	if exeErr != nil {
		slog.Error("get os.Executable error", "err", exeErr)
	}
	slog.Info("os.Executable() exe:", "path", exePath)
	workDir := filepath.Dir(exePath)
	slog.Info("exec dir info:", "path", workDir)
	proc.Dir = workDir
	if current, userErr := user.Current(); userErr == nil {
		proc.Dir = current.HomeDir
	}

	stdin, pipeErr := proc.StdinPipe()
	if pipeErr != nil {
		slog.Error("create pipe failed", "err", pipeErr)
		return 1
	}
	// Forward channel input to the process until the channel stops yielding data.
	go func() {
		defer stdin.Close()
		_, _ = io.Copy(stdin, ch)
	}()
	proc.Stdout = ch
	proc.Stderr = ch.Stderr()

	// Start the command.
	if startErr := proc.Start(); startErr != nil {
		slog.Error("start command failed", "err", startErr)
		_, _ = fmt.Fprintf(ch, "start command failed: %v\n", startErr)
		return 1
	}

	// Wait for the command to finish.
	waitErr := proc.Wait()
	if waitErr == nil {
		slog.Info("exec command success")
		return 0
	}

	var exitErr *exec.ExitError
	if errors.As(waitErr, &exitErr) {
		slog.Error("wait command failed", "exit code", exitErr.ExitCode())
		return uint32(exitErr.ExitCode())
	}
	slog.Error("exec command failed", "err", waitErr)
	_, _ = fmt.Fprintf(ch, "exec command failed: %v\n", waitErr)
	return 1
}
// startSshServer blocks until a value arrives on isStartSshd, then loads sshd
// configuration record 1, builds a Server from it and runs it. Every failure
// is logged and ends the function.
func startSshServer() {
	if started := <-isStartSshd; started {
		slog.Info("start sshd server")
	}

	var lookup model.SshdConf
	conf, findErr := lookup.FindByID(1)
	if findErr != nil {
		slog.Error("find sshd server config failed", "err_msg", findErr)
		return
	}
	slog.Info("find sshd server config", "conf", conf)

	server, buildErr := NewServer(&conf)
	if buildErr != nil {
		slog.Error("NewServer failed", "err_msg", buildErr)
		return
	}

	slog.Info("start sshd server")
	if runErr := server.Start(); runErr != nil {
		slog.Error("sshd server start failed", "err_msg", runErr)
	}
}
// isStartSshd is an unbuffered channel; startSshServer blocks on it until a
// value is sent.
var isStartSshd = make(chan bool)
// InitSshServer launches startSshServer in a background goroutine and returns
// immediately after logging.
func InitSshServer() {
go startSshServer()
slog.Info("exec sshd init func")
}
================================================
FILE: gossh/app/service/sshd_user.go
================================================
package service
import (
"gossh/app/model"
"gossh/gin"
"log/slog"
"strconv"
)
// SshdUserCreate binds an SshdUser from the request, stores it, and then
// responds with the full user list via SshdUserFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (invalid input) or 3
// (create failed).
func SshdUserCreate(c *gin.Context) {
	account := model.SshdUser{}
	if bindErr := c.ShouldBind(&account); bindErr != nil {
		slog.Error("UserCreate 绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	if createErr := account.Create(&account); createErr != nil {
		slog.Error("创建用户错误", "err_msg", createErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "创建用户错误"})
		return
	}

	SshdUserFindAll(c)
}
// CheckSshdUserNameExists reports whether an sshd user with the bound "name"
// already exists.
//
// The response is HTTP 200 with JSON code 0 when the name is free, 4 when a
// record with a non-zero ID was found, 1 for invalid input and 3 when the
// lookup fails.
func CheckSshdUserNameExists(c *gin.Context) {
	type Name struct {
		Name string `form:"name" binding:"required,min=1,max=128" json:"name"`
	}
	var input Name
	if bindErr := c.ShouldBind(&input); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	var account model.SshdUser
	existing, findErr := account.FindByName(input.Name)
	switch {
	case findErr != nil:
		slog.Error("FindByName错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
	case existing.ID != 0:
		c.JSON(200, gin.H{"code": 4, "msg": "用户名已经存存在"})
	default:
		c.JSON(200, gin.H{"code": 0, "msg": "ok"})
	}
}
// SshdUserFindByID returns the sshd user whose id is given by the "id" route
// parameter.
//
// The response is HTTP 200 with JSON code 0 and the record in "data", code 1
// when the id is not an integer, or code 3 when the lookup fails.
func SshdUserFindByID(c *gin.Context) {
	rawID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		slog.Error("获取ID错误", "err_msg", convErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var account model.SshdUser
	record, findErr := account.FindByID(uint(rawID))
	if findErr != nil {
		slog.Error("FindByID错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": record})
}
// SshdUserFindAll returns a page of sshd users.
//
// The query parameters "limit" (default 10000) and "offset" (default 0) select
// the page. The response is HTTP 200 with JSON code 0 and the records in
// "data", code 1 or 2 when limit or offset is not an integer, or code 4 when
// the query fails.
func SshdUserFindAll(c *gin.Context) {
	// intQuery parses an integer query parameter, falling back to def.
	intQuery := func(key, def string) (int, error) {
		return strconv.Atoi(c.DefaultQuery(key, def))
	}

	limit, limitErr := intQuery("limit", "10000")
	if limitErr != nil {
		slog.Error("获取limit错误", "err_msg", limitErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取limit错误"})
		return
	}

	offset, offsetErr := intQuery("offset", "0")
	if offsetErr != nil {
		slog.Error("获取offset错误", "err_msg", offsetErr.Error())
		c.JSON(200, gin.H{"code": 2, "msg": "获取offset错误"})
		return
	}

	var account model.SshdUser
	records, findErr := account.FindAll(limit, offset)
	if findErr != nil {
		slog.Error("user.FindAl错误", "err_msg", findErr.Error())
		c.JSON(200, gin.H{"code": 4, "msg": "获取用户信息错误"})
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": records})
}
// SshdUserUpdateById binds an SshdUser from the request, updates the record
// with the bound ID, and then responds with the full user list via
// SshdUserFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (bind failed) or 5
// (update failed).
func SshdUserUpdateById(c *gin.Context) {
	account := model.SshdUser{}
	if bindErr := c.ShouldBind(&account); bindErr != nil {
		slog.Error("获取ID错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	if updateErr := account.UpdateById(account.ID, &account); updateErr != nil {
		slog.Error("UpdateById错误", "err_msg", updateErr.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "更新用户错误"})
		return
	}

	SshdUserFindAll(c)
}
// SshdUserDeleteById deletes the sshd user whose id is given by the "id" route
// parameter and then responds with the remaining users via SshdUserFindAll.
//
// Failures are reported as HTTP 200 with JSON code 1 (id is not an integer)
// or 5 (delete failed).
func SshdUserDeleteById(c *gin.Context) {
	rawID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		slog.Error("获取ID错误", "err_msg", convErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var account model.SshdUser
	if delErr := account.DeleteByID(uint(rawID)); delErr != nil {
		slog.Error("user.DeleteByID错误", "err_msg", delErr.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "删除用户错误"})
		return
	}

	SshdUserFindAll(c)
}
================================================
FILE: gossh/app/service/sys_init.go
================================================
package service
import (
"gossh/app/config"
"gossh/app/model"
"gossh/gin"
"log/slog"
"os/exec"
"runtime"
)
// GetRunConf responds with the current config.DefaultConfig in the "data"
// field of a code-0 JSON response.
func GetRunConf(c *gin.Context) {
	payload := gin.H{
		"code": 0,
		"msg":  "ok",
		"data": config.DefaultConfig,
	}
	c.JSON(200, payload)
}
// SetRunConf replaces the application configuration with the one bound from
// the request, but only while the system has not been initialised yet.
//
// The response is HTTP 200 with JSON code 0 and the resulting
// config.DefaultConfig on success, or code 1 with a message when the system is
// already initialised, the input cannot be bound, or the rewrite fails.
func SetRunConf(c *gin.Context) {
	fail := func(msg string) {
		c.JSON(200, gin.H{"code": 1, "msg": msg})
	}

	if config.DefaultConfig.IsInit {
		fail("已经初始化")
		return
	}

	var incoming config.AppConfig
	if bindErr := c.ShouldBind(&incoming); bindErr != nil {
		fail(bindErr.Error())
		return
	}

	if writeErr := config.RewriteConfig(incoming); writeErr != nil {
		fail(writeErr.Error())
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": config.DefaultConfig})
}
// GetIsInit responds with a code-0 JSON body whose "data.is_init" field is the
// current value of config.DefaultConfig.IsInit.
func GetIsInit(c *gin.Context) {
	state := map[string]any{
		"is_init": config.DefaultConfig.IsInit,
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": state})
}
// InitConfig is the request payload bound by SysInit. Every field is required
// and length- or range-checked by its binding tag.
type InitConfig struct {
// DbConnConf is embedded; SysInit reads its DbType and DbDsn fields.
DbConnConf
// JwtSecret and SessionSecret are written into the application config.
JwtSecret string `json:"jwt_secret" binding:"required,min=1,max=128"`
SessionSecret string `json:"session_secret" binding:"required,min=1,max=128"`
// Username and Password are the credentials of the initial web user.
Username string `json:"username" binding:"required,min=1,max=63"`
Password string `json:"password" binding:"required,min=1,max=63"`
// SshdHost and SshdPort are the host and port stored in the initial sshd
// server configuration.
SshdHost string `json:"sshd_host" binding:"required,min=1,max=127"`
SshdPort uint16 `json:"sshd_port" binding:"required,gte=1,lte=65535"`
// SshdUser and SshdPwd are the credentials of the initial sshd account.
SshdUser string `json:"sshd_user" binding:"required,min=1,max=63"`
SshdPwd string `json:"sshd_pwd" binding:"required,min=1,max=63"`
}
// SysInit performs the one-time system initialization from an InitConfig
// request body. In order it: refuses to run when the system is already
// initialized, checks the database connection, migrates the tables, creates
// the initial web user, the default network policy, the SSHD server config
// and SSHD account, rewrites the application config, creates a default SSH
// connection entry, and finally signals on isStartSshd.
//
// Every failure is reported as an HTTP 200 JSON body with a non-zero "code";
// the function returns at the first failing step without undoing earlier ones.
func SysInit(c *gin.Context) {
var initConf InitConfig
if err := c.ShouldBind(&initConf); err != nil {
c.JSON(200, gin.H{"code": 1, "msg": err.Error()})
return
}
// 1. Refuse to run if the system has already been initialized.
if config.DefaultConfig.IsInit {
c.JSON(200, gin.H{"code": 1, "msg": "系统已经初始化"})
return
}
// 2. Verify the database connection.
dbConf := DbConnConf{
DbDsn: initConf.DbDsn,
DbType: initConf.DbType,
}
err := DbConnTestCheck(dbConf)
if err != nil {
c.JSON(200, gin.H{"code": 1, "msg": err.Error()})
return
}
// 3. Migrate the database tables.
err = model.DbMigrate(initConf.DbType, initConf.DbDsn)
if err != nil {
c.JSON(200, gin.H{"code": 1, "msg": err.Error(), "data": "执行数据库迁移错误"})
return
}
// Expiry time for the root user; the same value is reused for the initial
// SSHD account below.
// NOTE(review): the error returned by NewDateTime is discarded here.
dateTime, _ := model.NewDateTime("2099-12-31 00:00:00")
// 4. Create the initial user.
var user = model.WebUser{
ID: 0,
Name: initConf.Username,
Pwd: initConf.Password,
DescInfo: "管理员",
IsAdmin: "Y",
IsEnable: "Y",
IsRoot: "Y",
ExpiryAt: dateTime,
}
err = user.Create(&user)
if err != nil {
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化创建账号错误"})
return
}
// 5. Set the default network policy.
var policyConf = model.PolicyConf{
NetPolicy: "N",
}
err = policyConf.Create(&policyConf)
if err != nil {
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化网络策略错误"})
return
}
// Generate the SSHD host private key from the fixed seed "webssh"; the same
// seed string is stored as KeySeed in the SSHD config below.
privateKey, err := generateKey("webssh")
if err != nil {
slog.Error("failed to generate private key")
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化生成SSHD私钥错误"})
return
}
// Pick the shell for the SSHD server: powershell on Windows, otherwise the
// resolved path of bash when it is on PATH, falling back to "sh".
var shell = "sh"
if runtime.GOOS == "windows" {
shell = "powershell"
} else {
sh, err := exec.LookPath("bash")
if err == nil {
slog.Info("init find shell is bash\n")
shell = sh
}
}
// 6. Initialize the SSHD server configuration.
var sshdConf = model.SshdConf{
ID: 1,
Name: "init",
Host: initConf.SshdHost,
Port: initConf.SshdPort,
Shell: shell,
KeyFile: string(privateKey),
KeySeed: "webssh",
KeepAlive: 60,
LoadEnv: "Y",
AuthType: "all",
ServerVersion: "SSH-2.0-OpenSSH",
}
err = sshdConf.Create(&sshdConf)
if err != nil {
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化SSHD服务端配置错误"})
return
}
// 7. Initialize the SSHD server account.
var sshdUser = model.SshdUser{
ID: 1,
Name: initConf.SshdUser,
Pwd: initConf.SshdPwd,
DescInfo: "服务器初始化账号",
IsEnable: "Y",
WorkDir: "",
ExpiryAt: dateTime,
}
err = sshdUser.Create(&sshdUser)
if err != nil {
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化SSHD服务端账号错误"})
return
}
// 8. Overwrite the default configuration with the submitted values and mark
// the system as initialized.
var defConf = config.DefaultConfig
defConf.IsInit = true
defConf.DbType = initConf.DbType
defConf.DbDsn = initConf.DbDsn
defConf.JwtSecret = initConf.JwtSecret
defConf.SessionSecret = initConf.SessionSecret
err = config.RewriteConfig(defConf)
if err != nil {
c.JSON(200, gin.H{"code": 2, "msg": err.Error(), "data": "写入配置错误"})
return
}
// 9. Create the default connection, pointing at the local SSHD server on
// 127.0.0.1 with the SSHD account created in step 7.
// NOTE(review): Uid is hard-coded to 1, which presumably is the ID assigned
// to the user created in step 4 — confirm.
var sshConf = model.SshConf{
Uid: 1,
Name: "self_sshd",
Address: "127.0.0.1",
User: initConf.SshdUser,
Pwd: initConf.SshdPwd,
AuthType: "pwd",
NetType: "tcp4",
CertData: "",
CertPwd: "",
Port: initConf.SshdPort,
FontSize: 16,
Background: "#000000",
Foreground: "#FFFFFF",
CursorColor: "#FFFFFF",
FontFamily: "Courier",
CursorStyle: "block",
Shell: "sh",
PtyType: "xterm-256color",
InitCmd: "",
InitBanner: "# https://github.com/o8oo8o/WebSSH",
}
err = sshConf.Create(&sshConf)
if err != nil {
c.JSON(200, gin.H{"code": 3, "msg": err.Error(), "data": "初始化创建默认连接错误"})
return
}
// Signal on isStartSshd (declared elsewhere in the package).
// NOTE(review): presumably this starts the SSHD server; the send blocks
// until a receiver is ready if the channel is unbuffered — confirm.
isStartSshd <- true
c.JSON(200, gin.H{"code": 0, "msg": "系统初始化完成"})
}
================================================
FILE: gossh/app/service/web_user.go
================================================
package service
import (
"gossh/app/config"
"gossh/app/middleware"
"gossh/app/model"
"gossh/app/utils"
"gossh/gin"
"log/slog"
"net/http"
"strconv"
"time"
)
// UserCreate binds a model.WebUser from the request, forces it to be a
// non-root account, stores it, and on success replies with the full user list
// by delegating to UserFindAll.
func UserCreate(c *gin.Context) {
	newUser := model.WebUser{}
	if bindErr := c.ShouldBind(&newUser); bindErr != nil {
		slog.Error("UserCreate 绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	// Whatever the client sent, accounts created here are never root.
	newUser.IsRoot = "N"
	if createErr := newUser.Create(&newUser); createErr != nil {
		slog.Error("创建用户错误", "err_msg", createErr.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "创建用户错误"})
		return
	}

	UserFindAll(c)
}
// ModifyPasswd changes the password of the user identified by the "uid"
// value stored on the gin context. The new password is read from the "pwd"
// field of the request.
func ModifyPasswd(c *gin.Context) {
	type password struct {
		Pwd string `form:"pwd" binding:"required,min=1,max=64" json:"pwd"`
	}

	req := password{}
	if bindErr := c.ShouldBind(&req); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	uid := c.GetUint("uid")

	// Load the current record first so the update starts from stored state.
	var lookup model.WebUser
	user, err := lookup.FindByID(uid)
	if err != nil {
		slog.Error("FindByID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 2, "msg": "获取用户信息错误"})
		return
	}

	user.Pwd = req.Pwd
	if err = user.UpdatePassword(uid, &user); err != nil {
		slog.Error("UpdatePassword错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "更新用户密码错误"})
		return
	}

	c.JSON(200, gin.H{"code": 0, "msg": "更新密码成功"})
}
// CheckUserNameExists looks up the "name" field of the request and replies
// with code 0 when no user with a non-zero ID is found, or code 4 when the
// name is already taken.
func CheckUserNameExists(c *gin.Context) {
	type Name struct {
		Name string `form:"name" binding:"required,min=1,max=128" json:"name"`
	}

	req := Name{}
	if bindErr := c.ShouldBind(&req); bindErr != nil {
		slog.Error("绑定数据错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
		return
	}

	var lookup model.WebUser
	found, err := lookup.FindByName(req.Name)
	switch {
	case err != nil:
		slog.Error("FindByName错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
	case found.ID != 0:
		// A non-zero ID means a record with this name was returned.
		c.JSON(200, gin.H{"code": 4, "msg": "用户名已经存存在"})
	default:
		c.JSON(200, gin.H{"code": 0, "msg": "ok"})
	}
}
// UserFindByID returns the user whose ID is given by the ":id" path
// parameter. A non-numeric or negative ID is rejected with code 1; a lookup
// failure is reported with code 3.
func UserFindByID(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		slog.Error("获取ID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}
	// A negative value would wrap around to a huge number when converted to
	// uint below, so reject it explicitly instead of looking up a bogus ID.
	if id < 0 {
		slog.Error("获取ID错误", "err_msg", "id must not be negative", "id", id)
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var user model.WebUser
	data, err := user.FindByID(uint(id))
	if err != nil {
		slog.Error("FindByID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": data})
}
// UserFindAll lists users using the optional "limit" (default 10000) and
// "offset" (default 0) query parameters and writes them as the "data" field
// of the JSON response.
func UserFindAll(c *gin.Context) {
	var (
		limit, offset int
		err           error
	)

	if limit, err = strconv.Atoi(c.DefaultQuery("limit", "10000")); err != nil {
		slog.Error("获取limit错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取limit错误"})
		return
	}
	if offset, err = strconv.Atoi(c.DefaultQuery("offset", "0")); err != nil {
		slog.Error("获取offset错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 2, "msg": "获取offset错误"})
		return
	}

	var lookup model.WebUser
	users, err := lookup.FindAll(limit, offset)
	if err != nil {
		slog.Error("user.FindAl错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 4, "msg": "获取用户信息错误"})
		return
	}
	c.JSON(200, gin.H{"code": 0, "msg": "ok", "data": users})
}
// UserUpdateById updates the user described by the request body, identified
// by its ID field. The built-in root user cannot be updated, and a non-root
// user cannot be promoted to root through this endpoint. On success the full
// user list is returned via UserFindAll.
func UserUpdateById(c *gin.Context) {
	var user model.WebUser
	if bindErr := c.ShouldBind(&user); bindErr != nil {
		slog.Error("获取ID错误", "err_msg", bindErr.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	stored, err := user.FindByID(user.ID)
	if err != nil {
		slog.Error("FindByID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
		return
	}

	// Guard against privilege escalation: the stored record decides the root
	// flag. Any stored value other than "Y" or "N" leaves the submitted flag
	// untouched.
	switch stored.IsRoot {
	case "Y":
		c.JSON(200, gin.H{"code": 4, "msg": "内置Root用户不能更新"})
		return
	case "N":
		user.IsRoot = "N"
	}

	if err = user.UpdateById(user.ID, &user); err != nil {
		slog.Error("UpdateById错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "更新用户错误"})
		return
	}
	UserFindAll(c)
}
// UserDeleteById deletes the user whose ID is given by the ":id" path
// parameter. A non-numeric or negative ID is rejected with code 1, and the
// built-in root user cannot be deleted. On success the full user list is
// returned via UserFindAll.
func UserDeleteById(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		slog.Error("获取ID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}
	// A negative value would wrap around to a huge number when converted to
	// uint below, so reject it explicitly before touching the database.
	if id < 0 {
		slog.Error("获取ID错误", "err_msg", "id must not be negative", "id", id)
		c.JSON(200, gin.H{"code": 1, "msg": "获取ID错误"})
		return
	}

	var user model.WebUser
	tmpUser, err := user.FindByID(uint(id))
	if err != nil {
		slog.Error("user.FindByID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 3, "msg": "获取用户信息错误"})
		return
	}
	// Guard against privilege abuse: the built-in root user is protected.
	if tmpUser.IsRoot == "Y" {
		c.JSON(200, gin.H{"code": 4, "msg": "内置Root用户不能删除"})
		return
	}
	err = user.DeleteByID(uint(id))
	if err != nil {
		slog.Error("user.DeleteByID错误", "err_msg", err.Error())
		c.JSON(200, gin.H{"code": 5, "msg": "删除用户错误"})
		return
	}
	UserFindAll(c)
}
func UserLogin(c *gin.Context) {
if !config.DefaultConfig.IsInit {
slog.Warn("system no init")
c.JSON(401, gin.H{"code": 401, "msg": "请对系统进行初始化"})
return
}
type Param struct {
Name string `form:"name" binding:"required,min=1,max=64" json:"name"`
Pwd string `form:"pwd" binding:"required,min=1,max=64" json:"pwd"`
}
var loginAudit model.LoginAudit
var param Param
audit := model.LoginAudit{
ClientIp: c.ClientIP(),
UserAgent: utils.TruncateString(c.Request.UserAgent(), 500),
ErrMsg: "请求参数错误",
IsSuccess: "N",
OccurAt: model.DateTime(time.Now()),
}
if err := c.ShouldBind(¶m); err != nil {
audit.Name = utils.TruncateString(param.Name, 60)
audit.Pwd = utils.TruncateString(param.Pwd, 60)
_ = loginAudit.Create(&audit)
slog.Error("绑定数据错误", "err_msg", err.Error())
c.JSON(200, gin.H{"code": 1, "msg": "输入数据不合法"})
return
}
audit.Name = utils.TruncateString(param.Name, 60)
audit.Pwd = utils.TruncateString(param.Pwd, 60)
var user model.WebUser
u, err := user.FindByNameAndPwd(param.Name, param.Pwd)
if err != nil {
audit.ErrMsg = "账号密码错误"
_ = loginAudit.Create(&audit)
slog.Error("账号密码错误", "err_msg", err.Error())
c.JSON(401, gin.H{"code": 2, "msg": "账号密码错误"})
return
}
if u.IsEnable == "N" {
audit.ErrMsg = "账号已禁用"
_ = loginAudit.Create(&audit)
c.JSON(401, gin.H{"code": 3, "msg": "账号已禁用"})
return
}
if u.ExpiryAt.ToTime().Unix() < time.Now().Unix() {
audit.ErrMsg = "账号已过期"
_ = loginAudit.Create(&audit)
c.JSON(401, gin.H{"code": 4, "msg": "账号已过期"})
return
}
tokenString, err := middleware.GenerateToken(u.ID)
if err != nil {
audit.ErrMsg = "生成Token错误"
_ = loginAudit.Create(&audit)
c.JSON(401, gin.H{"code": 5, "msg": err.Error()})
return
}
audit.Name = param.Name
audit.Pwd = "*"
audit.ErrMsg = "*"
audit.IsSuccess = "Y"
_ = loginAudit.Create(&audit)
c.JSON(http.StatusOK, gin.H{
"code": 0,
"token": tokenString,
"msg": "登录成功",
"is_root": u.IsRoot,
"is_admin": u.IsAdmin,
"user_name": u.Name,
"user_desc": u.DescInfo,
"user_expiry_at": u.ExpiryAt.String(),
})
}
================================================
FILE: gossh/app/utils/crypto.go
================================================
package utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"strings"
)
// AesEncrypt encrypts orig with AES-256 in CBC mode using PKCS#7 padding and
// returns the ciphertext as standard base64.
//
// orig is trimmed of surrounding whitespace first; an empty result yields
// ("", nil). key must be exactly 32 bytes long, otherwise an error is
// returned. Should an unexpected panic occur it is logged and reported as an
// error instead of being silently turned into an empty, error-free result.
//
// NOTE(review): the IV is the first block-size bytes of the key, so the same
// plaintext always produces the same ciphertext. This is kept unchanged
// because previously stored ciphertexts depend on it.
func AesEncrypt(orig string, key string) (result string, err error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("AES加密错误", "err_msg", r)
			result, err = "", errors.New("AES加密错误")
		}
	}()

	orig = strings.TrimSpace(orig)
	if len(orig) == 0 {
		return "", nil
	}
	if len(key) != 32 {
		return "", errors.New("加密需要的key长度错误")
	}

	k := []byte(key)
	block, err := aes.NewCipher(k)
	if err != nil {
		return "", err
	}
	blockSize := block.BlockSize()

	// PKCS#7: always append between 1 and blockSize bytes, each holding the
	// padding length, so the input becomes a whole number of blocks.
	padding := blockSize - len(orig)%blockSize
	plain := make([]byte, 0, len(orig)+padding)
	plain = append(plain, orig...)
	plain = append(plain, bytes.Repeat([]byte{byte(padding)}, padding)...)

	crypt := make([]byte, len(plain))
	cipher.NewCBCEncrypter(block, k[:blockSize]).CryptBlocks(crypt, plain)
	return base64.StdEncoding.EncodeToString(crypt), nil
}
// AesDecrypt reverses AesEncrypt: it base64-decodes crypt, decrypts it with
// AES-256 in CBC mode (IV = first block-size bytes of key) and strips the
// PKCS#7 padding.
//
// crypt is trimmed of surrounding whitespace first; an empty result yields
// ("", nil). An error is returned when key is not exactly 32 bytes, when
// crypt is not valid base64, when the decoded data is empty or not a whole
// number of blocks, or when the padding is malformed. Should an unexpected
// panic occur it is logged and reported as an error as well.
func AesDecrypt(crypt string, key string) (result string, err error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("AES解密错误", "err_msg", r)
			result, err = "", errors.New("AES解密错误")
		}
	}()

	crypt = strings.TrimSpace(crypt)
	if len(crypt) == 0 {
		return "", nil
	}
	if len(key) != 32 {
		return "", errors.New("解密需要的key长度错误")
	}

	cryptByte, err := base64.StdEncoding.DecodeString(crypt)
	if err != nil {
		return "", err
	}

	k := []byte(key)
	block, err := aes.NewCipher(k)
	if err != nil {
		return "", err
	}
	blockSize := block.BlockSize()

	// CryptBlocks panics on input that is not a whole number of blocks, so
	// validate the length up front and report it as an ordinary error.
	if len(cryptByte) == 0 || len(cryptByte)%blockSize != 0 {
		return "", errors.New("ciphertext is not a multiple of the block size")
	}

	orig := make([]byte, len(cryptByte))
	cipher.NewCBCDecrypter(block, k[:blockSize]).CryptBlocks(orig, cryptByte)

	// PKCS#7: the last byte gives the padding length (1..blockSize) and every
	// padding byte must carry that same value.
	unPadding := int(orig[len(orig)-1])
	if unPadding < 1 || unPadding > blockSize {
		return "", errors.New("invalid padding")
	}
	for _, b := range orig[len(orig)-unPadding:] {
		if int(b) != unPadding {
			return "", errors.New("invalid padding")
		}
	}
	return string(orig[:len(orig)-unPadding]), nil
}
// EncryptString encrypts plaintext with AES-256-GCM under key and returns
// base64(nonce || ciphertext || tag), using a freshly generated random nonce.
//
// plaintext is trimmed of surrounding whitespace first; an empty result
// yields ("", nil). key must be exactly 32 bytes long. A recovered panic is
// only logged.
func EncryptString(plaintext, key string) (string, error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("AES加密错误", "err_msg", r)
		}
	}()

	plaintext = strings.TrimSpace(plaintext)
	switch {
	case plaintext == "":
		return "", nil
	case len(key) != 32:
		return "", errors.New("加密需要的key长度错误")
	}

	block, err := aes.NewCipher([]byte(key))
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	// A fresh random nonce per message.
	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Seal appends to its first argument, so the nonce ends up as the prefix
	// of the returned bytes.
	sealed := aead.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// DecryptString reverses EncryptString: it base64-decodes encrypted, splits
// off the leading nonce and opens the remainder with AES-256-GCM under key.
//
// encrypted is trimmed of surrounding whitespace first; an empty result
// yields ("", nil). An error is returned when key is not exactly 32 bytes,
// when the input is not valid base64, when it is shorter than a nonce, or
// when authentication fails. A recovered panic is only logged.
func DecryptString(encrypted, key string) (string, error) {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("AES解密错误", "err_msg", r)
		}
	}()

	encrypted = strings.TrimSpace(encrypted)
	switch {
	case encrypted == "":
		return "", nil
	case len(key) != 32:
		return "", errors.New("解密需要的key长度错误")
	}

	raw, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher([]byte(key))
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	// The nonce is stored as the prefix of the decoded bytes.
	nonceSize := aead.NonceSize()
	if len(raw) < nonceSize {
		return "", fmt.Errorf("ciphertext too short")
	}
	nonce, sealed := raw[:nonceSize], raw[nonceSize:]

	opened, err := aead.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", err
	}
	return string(opened), nil
}
================================================
FILE: gossh/app/utils/utils.go
================================================
package utils
import (
"math/rand"
"time"
"unicode/utf8"
)
// RandString returns a random string of the given length drawn from the
// characters 0-9 and a-z. A non-positive length yields the empty string.
//
// The generator is math/rand seeded from the current time, so the output is
// not cryptographically secure and must not be used for secrets or tokens.
func RandString(length int) string {
	const alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
	if length <= 0 {
		return ""
	}
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	// The final size is known, so fill a pre-sized buffer instead of growing
	// one with append.
	result := make([]byte, length)
	for i := range result {
		result[i] = alphabet[r.Intn(len(alphabet))]
	}
	return string(result)
}
// TruncateString returns s cut down to at most length runes (not bytes).
// A length below 1 always yields the empty string; a string that already
// fits is returned unchanged.
func TruncateString(s string, length int) string {
	if length < 1 {
		return ""
	}
	if utf8.RuneCountInString(s) > length {
		return string([]rune(s)[:length])
	}
	return s
}
================================================
FILE: gossh/crypto/blowfish/block.go
================================================
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package blowfish
// getNextWord reads four bytes from b starting at *pos, wrapping around to
// the start of b whenever the end is reached, and returns them as a
// big-endian uint32. *pos is advanced to the position after the last byte
// read.
func getNextWord(b []byte, pos *int) uint32 {
	idx := *pos
	var word uint32
	for remaining := 4; remaining > 0; remaining-- {
		word = (word << 8) | uint32(b[idx])
		if idx++; idx >= len(b) {
			idx = 0
		}
	}
	*pos = idx
	return word
}
// ExpandKey runs the Blowfish key schedule on c: the key bytes are XORed
// cyclically into the P-array, after which the P-array and the four S-boxes
// are rewritten, two words at a time, with the successive outputs of
// encryptBlock. It exists mainly so that bcrypt can reuse the key schedule
// during its setup; most callers do not need it directly.
func ExpandKey(key []byte, c *Cipher) {
	keyPos := 0
	for i := range c.p {
		// Read the next big-endian word from the key, wrapping around; this
		// is deliberately inlined rather than calling a helper, for speed.
		var word uint32
		for n := 0; n < 4; n++ {
			word = word<<8 | uint32(key[keyPos])
			if keyPos++; keyPos >= len(key) {
				keyPos = 0
			}
		}
		c.p[i] ^= word
	}

	// Each encryption output feeds the next one and replaces two table words.
	var l, r uint32
	for i := 0; i < len(c.p); i += 2 {
		l, r = encryptBlock(l, r, c)
		c.p[i], c.p[i+1] = l, r
	}
	for _, box := range [...]*[256]uint32{&c.s0, &c.s1, &c.s2, &c.s3} {
		for i := 0; i < len(box); i += 2 {
			l, r = encryptBlock(l, r, c)
			box[i], box[i+1] = l, r
		}
	}
}
// expandKeyWithSalt is the salted variant of ExpandKey: besides mixing the
// key into the P-array, it XORs successive words of the salt into the running
// block before every encryption of the key schedule. ExpandKey is effectively
// this function with an all-zero salt, but it is kept as a separate
// specialization because routing both through one function proved
// inefficient.
func expandKeyWithSalt(key []byte, salt []byte, c *Cipher) {
	keyPos := 0
	for i := range c.p {
		c.p[i] ^= getNextWord(key, &keyPos)
	}

	saltPos := 0
	var l, r uint32
	for i := 0; i < len(c.p); i += 2 {
		l ^= getNextWord(salt, &saltPos)
		r ^= getNextWord(salt, &saltPos)
		l, r = encryptBlock(l, r, c)
		c.p[i], c.p[i+1] = l, r
	}
	// The salt position keeps advancing across all four S-boxes.
	for _, box := range [...]*[256]uint32{&c.s0, &c.s1, &c.s2, &c.s3} {
		for i := 0; i < len(box); i += 2 {
			l ^= getNextWord(salt, &saltPos)
			r ^= getNextWord(salt, &saltPos)
			l, r = encryptBlock(l, r, c)
			box[i], box[i+1] = l, r
		}
	}
}
// encryptBlock encrypts one block, given as two 32-bit halves l and r, with
// the subkeys and S-boxes held in c. The sixteen rounds are written out
// explicitly, using c.p[0] through c.p[17] in ascending order, and the two
// halves are returned swapped (xr first).
//
// In Go, + and ^ share the same precedence and associate left to right, so
// each round line evaluates as ((((s0 + s1) ^ s2) + s3) ^ p[i]).
func encryptBlock(l, r uint32, c *Cipher) (uint32, uint32) {
xl, xr := l, r
xl ^= c.p[0]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[1]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[2]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[3]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[4]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[5]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[6]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[7]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[8]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[9]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[10]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[11]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[12]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[13]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[14]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[15]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[16]
xr ^= c.p[17]
return xr, xl
}
// decryptBlock is the inverse of encryptBlock: it applies the same round
// structure to the halves l and r but walks the subkeys in descending order,
// from c.p[17] down to c.p[0], and returns the two halves swapped (xr first).
//
// In Go, + and ^ share the same precedence and associate left to right, so
// each round line evaluates as ((((s0 + s1) ^ s2) + s3) ^ p[i]).
func decryptBlock(l, r uint32, c *Cipher) (uint32, uint32) {
xl, xr := l, r
xl ^= c.p[17]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[16]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[15]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[14]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[13]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[12]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[11]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[10]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[9]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[8]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[7]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[6]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[5]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[4]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[3]
xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[2]
xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[1]
xr ^= c.p[0]
return xr, xl
}
================================================
FILE: gossh/crypto/blowfish/cipher.go
================================================
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm.
//
// Blowfish is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package blowfish // import "golang.org/x/crypto/blowfish"
// The code is a port of Bruce Schneier's C implementation.
// See https://www.schneier.com/blowfish.html.
import "strconv"
// BlockSize is the Blowfish block size in bytes.
const BlockSize = 8
// A Cipher is an instance of Blowfish encryption using a particular key.
type Cipher struct {
// p is the array of 18 subkeys used by the encryption and decryption rounds.
p [18]uint32
// s0 through s3 are the four 256-entry substitution boxes.
s0, s1, s2, s3 [256]uint32
}
// KeySizeError is the error returned for a key of unsupported length; its
// value is the offending key size in bytes.
type KeySizeError int

// Error implements the error interface.
func (k KeySizeError) Error() string {
	size := strconv.Itoa(int(k))
	return "crypto/blowfish: invalid key size " + size
}
// NewCipher creates and returns a Cipher for the given Blowfish key.
// The key must be between 1 and 56 bytes long; any other length yields a
// KeySizeError.
func NewCipher(key []byte) (*Cipher, error) {
	n := len(key)
	if n < 1 || n > 56 {
		return nil, KeySizeError(n)
	}
	c := new(Cipher)
	initCipher(c)
	ExpandKey(key, c)
	return c, nil
}
// NewSaltedCipher creates and returns a Cipher that folds a salt into its key
// schedule. For most purposes NewCipher, rather than NewSaltedCipher, is
// sufficient and desirable. For bcrypt compatibility the key may be longer
// than 56 bytes; only an empty key is rejected. An empty salt falls back to
// NewCipher, including its key-length limits.
func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
	if len(salt) == 0 {
		return NewCipher(key)
	}
	n := len(key)
	if n < 1 {
		return nil, KeySizeError(n)
	}
	c := new(Cipher)
	initCipher(c)
	expandKeyWithSalt(key, salt, c)
	return c, nil
}
// BlockSize returns the Blowfish block size, 8 bytes.
// It is required to satisfy the Block interface of package crypto/cipher.
func (c *Cipher) BlockSize() int {
	return BlockSize
}
// Encrypt encrypts the 8-byte block in src with the cipher's key and writes
// the result to dst. Both halves are read and written big-endian.
//
// For data longer than one block it is not safe to simply call Encrypt on
// successive blocks; use an encryption mode such as CBC instead
// (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) {
	left := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
	right := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])

	left, right = encryptBlock(left, right, c)

	dst[0] = byte(left >> 24)
	dst[1] = byte(left >> 16)
	dst[2] = byte(left >> 8)
	dst[3] = byte(left)
	dst[4] = byte(right >> 24)
	dst[5] = byte(right >> 16)
	dst[6] = byte(right >> 8)
	dst[7] = byte(right)
}
// Decrypt decrypts the 8-byte block in src with the cipher's key and writes
// the result to dst. Both halves are read and written big-endian.
func (c *Cipher) Decrypt(dst, src []byte) {
	left := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
	right := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])

	left, right = decryptBlock(left, right, c)

	dst[0] = byte(left >> 24)
	dst[1] = byte(left >> 16)
	dst[2] = byte(left >> 8)
	dst[3] = byte(left)
	dst[4] = byte(right >> 24)
	dst[5] = byte(right >> 16)
	dst[6] = byte(right >> 8)
	dst[7] = byte(right)
}
// initCipher resets c to the package-level initial tables: the subkey array
// p and the four substitution boxes s0 through s3.
func initCipher(c *Cipher) {
	// Array assignment copies every element, just like an element-wise copy.
	c.p = p
	c.s0 = s0
	c.s1 = s1
	c.s2 = s2
	c.s3 = s3
}
================================================
FILE: gossh/crypto/blowfish/const.go
================================================
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The startup permutation array and substitution boxes.
// They are the hexadecimal digits of PI; see:
// https://www.schneier.com/code/constants.txt.
package blowfish
var s0 = [256]uint32{
0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, 0xb8e1afed, 0x6a267e96,
0xba7c9045, 0xf12c7f99, 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16,
0x636920d8, 0x71574e69, 0xa458fea3, 0xf4933d7e, 0x0d95748f, 0x728eb658,
0x718bcd58, 0x82154aee, 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013,
0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, 0x8e79dcb0, 0x603a180e,
0x6c9e0e8b, 0xb01e8a3e, 0xd71577c1, 0xbd314b27, 0x78af2fda, 0x55605c60,
0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, 0x55ca396a, 0x2aab10b6,
0xb4cc5c34, 0x1141e8ce, 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a,
0x2ba9c55d, 0x741831f6, 0xce5c3e16, 0x9b87931e, 0xafd6ba33, 0x6c24cf5c,
0x7a325381, 0x28958677, 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193,
0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, 0xef845d5d, 0xe98575b1,
0xdc262302, 0xeb651b88, 0x23893e81, 0xd396acc5, 0x0f6d6ff3, 0x83f44239,
0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, 0x21c66842, 0xf6e96c9a,
0x670c9c61, 0xabd388f0, 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3,
0x6eef0b6c, 0x137a3be4, 0xba3bf050, 0x7efb2a98, 0xa1f1651d, 0x39af0176,
0x66ca593e, 0x82430e88, 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe,
0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, 0x4ed3aa62, 0x363f7706,
0x1bfedf72, 0x429b023d, 0x37d0d724, 0xd00a1248, 0xdb0fead3, 0x49f1c09b,
0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, 0xe3fe501a, 0xb6794c3b,
0x976ce0bd, 0x04c006ba, 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463,
0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, 0x3b52ec6f, 0x6dfc511f, 0x9b30952c,
0xcc814544, 0xaf5ebd09, 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3,
0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, 0x5579c0bd, 0x1a60320a,
0xd6a100c6, 0x402c7279, 0x679f25fe, 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8,
0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, 0x323db5fa, 0xfd238760,
0x53317b48, 0x3e00df82, 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db,
0xd542a8f6, 0x287effc3, 0xac6732c6, 0x8c4f5573, 0x695b27b0, 0xbbca58c8,
0xe1ffa35d, 0xb8f011a0, 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b,
0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, 0xe1ddf2da, 0xa4cb7e33,
0x62fb1341, 0xcee4c6e8, 0xef20cada, 0x36774c01, 0xd07e9efe, 0x2bf11fb4,
0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, 0xd08ed1d0, 0xafc725e0,
0x8e3c5b2f, 0x8e7594b7, 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c,
0x4fad5ea0, 0x688fc31c, 0xd1cff191, 0xb3a8c1ad, 0x2f2f2218, 0xbe0e1777,
0xea752dfe, 0x8b021fa1, 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299,
0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, 0x165fa266, 0x80957705,
0x93cc7314, 0x211a1477, 0xe6ad2065, 0x77b5fa86, 0xc75442f5, 0xfb9d35cf,
0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, 0x00250e2d, 0x2071b35e,
0x226800bb, 0x57b8e0af, 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa,
0x78c14389, 0xd95a537f, 0x207d5ba2, 0x02e5b9c5, 0x83260376, 0x6295cfa9,
0x11c81968, 0x4e734a41, 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915,
0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, 0x08ba6fb5, 0x571be91f,
0xf296ec6b, 0x2a0dd915, 0xb6636521, 0xe7b9f9b6, 0xff34052e, 0xc5855664,
0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a,
}
var s1 = [256]uint32{
0x4b7a70e9, 0xb5b32944, 0xdb75092e, 0xc4192623, 0xad6ea6b0, 0x49a7df7d,
0x9cee60b8, 0x8fedb266, 0xecaa8c71, 0x699a17ff, 0x5664526c, 0xc2b19ee1,
0x193602a5, 0x75094c29, 0xa0591340, 0xe4183a3e, 0x3f54989a, 0x5b429d65,
0x6b8fe4d6, 0x99f73fd6, 0xa1d29c07, 0xefe830f5, 0x4d2d38e6, 0xf0255dc1,
0x4cdd2086, 0x8470eb26, 0x6382e9c6, 0x021ecc5e, 0x09686b3f, 0x3ebaefc9,
0x3c971814, 0x6b6a70a1, 0x687f3584, 0x52a0e286, 0xb79c5305, 0xaa500737,
0x3e07841c, 0x7fdeae5c, 0x8e7d44ec, 0x5716f2b8, 0xb03ada37, 0xf0500c0d,
0xf01c1f04, 0x0200b3ff, 0xae0cf51a, 0x3cb574b2, 0x25837a58, 0xdc0921bd,
0xd19113f9, 0x7ca92ff6, 0x94324773, 0x22f54701, 0x3ae5e581, 0x37c2dadc,
0xc8b57634, 0x9af3dda7, 0xa9446146, 0x0fd0030e, 0xecc8c73e, 0xa4751e41,
0xe238cd99, 0x3bea0e2f, 0x3280bba1, 0x183eb331, 0x4e548b38, 0x4f6db908,
0x6f420d03, 0xf60a04bf, 0x2cb81290, 0x24977c79, 0x5679b072, 0xbcaf89af,
0xde9a771f, 0xd9930810, 0xb38bae12, 0xdccf3f2e, 0x5512721f, 0x2e6b7124,
0x501adde6, 0x9f84cd87, 0x7a584718, 0x7408da17, 0xbc9f9abc, 0xe94b7d8c,
0xec7aec3a, 0xdb851dfa, 0x63094366, 0xc464c3d2, 0xef1c1847, 0x3215d908,
0xdd433b37, 0x24c2ba16, 0x12a14d43, 0x2a65c451, 0x50940002, 0x133ae4dd,
0x71dff89e, 0x10314e55, 0x81ac77d6, 0x5f11199b, 0x043556f1, 0xd7a3c76b,
0x3c11183b, 0x5924a509, 0xf28fe6ed, 0x97f1fbfa, 0x9ebabf2c, 0x1e153c6e,
0x86e34570, 0xeae96fb1, 0x860e5e0a, 0x5a3e2ab3, 0x771fe71c, 0x4e3d06fa,
0x2965dcb9, 0x99e71d0f, 0x803e89d6, 0x5266c825, 0x2e4cc978, 0x9c10b36a,
0xc6150eba, 0x94e2ea78, 0xa5fc3c53, 0x1e0a2df4, 0xf2f74ea7, 0x361d2b3d,
0x1939260f, 0x19c27960, 0x5223a708, 0xf71312b6, 0xebadfe6e, 0xeac31f66,
0xe3bc4595, 0xa67bc883, 0xb17f37d1, 0x018cff28, 0xc332ddef, 0xbe6c5aa5,
0x65582185, 0x68ab9802, 0xeecea50f, 0xdb2f953b, 0x2aef7dad, 0x5b6e2f84,
0x1521b628, 0x29076170, 0xecdd4775, 0x619f1510, 0x13cca830, 0xeb61bd96,
0x0334fe1e, 0xaa0363cf, 0xb5735c90, 0x4c70a239, 0xd59e9e0b, 0xcbaade14,
0xeecc86bc, 0x60622ca7, 0x9cab5cab, 0xb2f3846e, 0x648b1eaf, 0x19bdf0ca,
0xa02369b9, 0x655abb50, 0x40685a32, 0x3c2ab4b3, 0x319ee9d5, 0xc021b8f7,
0x9b540b19, 0x875fa099, 0x95f7997e, 0x623d7da8, 0xf837889a, 0x97e32d77,
0x11ed935f, 0x16681281, 0x0e358829, 0xc7e61fd6, 0x96dedfa1, 0x7858ba99,
0x57f584a5, 0x1b227263, 0x9b83c3ff, 0x1ac24696, 0xcdb30aeb, 0x532e3054,
0x8fd948e4, 0x6dbc3128, 0x58ebf2ef, 0x34c6ffea, 0xfe28ed61, 0xee7c3c73,
0x5d4a14d9, 0xe864b7e3, 0x42105d14, 0x203e13e0, 0x45eee2b6, 0xa3aaabea,
0xdb6c4f15, 0xfacb4fd0, 0xc742f442, 0xef6abbb5, 0x654f3b1d, 0x41cd2105,
0xd81e799e, 0x86854dc7, 0xe44b476a, 0x3d816250, 0xcf62a1f2, 0x5b8d2646,
0xfc8883a0, 0xc1c7b6a3, 0x7f1524c3, 0x69cb7492, 0x47848a0b, 0x5692b285,
0x095bbf00, 0xad19489d, 0x1462b174, 0x23820e00, 0x58428d2a, 0x0c55f5ea,
0x1dadf43e, 0x233f7061, 0x3372f092, 0x8d937e41, 0xd65fecf1, 0x6c223bdb,
0x7cde3759, 0xcbee7460, 0x4085f2a7, 0xce77326e, 0xa6078084, 0x19f8509e,
0xe8efd855, 0x61d99735, 0xa969a7aa, 0xc50c06c2, 0x5a04abfc, 0x800bcadc,
0x9e447a2e, 0xc3453484, 0xfdd56705, 0x0e1e9ec9, 0xdb73dbd3, 0x105588cd,
0x675fda79, 0xe3674340, 0xc5c43465, 0x713e38d8, 0x3d28f89e, 0xf16dff20,
0x153e21e7, 0x8fb03d4a, 0xe6e39f2b, 0xdb83adf7,
}
var s2 = [256]uint32{
0xe93d5a68, 0x948140f7, 0xf64c261c, 0x94692934, 0x411520f7, 0x7602d4f7,
0xbcf46b2e, 0xd4a20068, 0xd4082471, 0x3320f46a, 0x43b7d4b7, 0x500061af,
0x1e39f62e, 0x97244546, 0x14214f74, 0xbf8b8840, 0x4d95fc1d, 0x96b591af,
0x70f4ddd3, 0x66a02f45, 0xbfbc09ec, 0x03bd9785, 0x7fac6dd0, 0x31cb8504,
0x96eb27b3, 0x55fd3941, 0xda2547e6, 0xabca0a9a, 0x28507825, 0x530429f4,
0x0a2c86da, 0xe9b66dfb, 0x68dc1462, 0xd7486900, 0x680ec0a4, 0x27a18dee,
0x4f3ffea2, 0xe887ad8c, 0xb58ce006, 0x7af4d6b6, 0xaace1e7c, 0xd3375fec,
0xce78a399, 0x406b2a42, 0x20fe9e35, 0xd9f385b9, 0xee39d7ab, 0x3b124e8b,
0x1dc9faf7, 0x4b6d1856, 0x26a36631, 0xeae397b2, 0x3a6efa74, 0xdd5b4332,
0x6841e7f7, 0xca7820fb, 0xfb0af54e, 0xd8feb397, 0x454056ac, 0xba489527,
0x55533a3a, 0x20838d87, 0xfe6ba9b7, 0xd096954b, 0x55a867bc, 0xa1159a58,
0xcca92963, 0x99e1db33, 0xa62a4a56, 0x3f3125f9, 0x5ef47e1c, 0x9029317c,
0xfdf8e802, 0x04272f70, 0x80bb155c, 0x05282ce3, 0x95c11548, 0xe4c66d22,
0x48c1133f, 0xc70f86dc, 0x07f9c9ee, 0x41041f0f, 0x404779a4, 0x5d886e17,
0x325f51eb, 0xd59bc0d1, 0xf2bcc18f, 0x41113564, 0x257b7834, 0x602a9c60,
0xdff8e8a3, 0x1f636c1b, 0x0e12b4c2, 0x02e1329e, 0xaf664fd1, 0xcad18115,
0x6b2395e0, 0x333e92e1, 0x3b240b62, 0xeebeb922, 0x85b2a20e, 0xe6ba0d99,
0xde720c8c, 0x2da2f728, 0xd0127845, 0x95b794fd, 0x647d0862, 0xe7ccf5f0,
0x5449a36f, 0x877d48fa, 0xc39dfd27, 0xf33e8d1e, 0x0a476341, 0x992eff74,
0x3a6f6eab, 0xf4f8fd37, 0xa812dc60, 0xa1ebddf8, 0x991be14c, 0xdb6e6b0d,
0xc67b5510, 0x6d672c37, 0x2765d43b, 0xdcd0e804, 0xf1290dc7, 0xcc00ffa3,
0xb5390f92, 0x690fed0b, 0x667b9ffb, 0xcedb7d9c, 0xa091cf0b, 0xd9155ea3,
0xbb132f88, 0x515bad24, 0x7b9479bf, 0x763bd6eb, 0x37392eb3, 0xcc115979,
0x8026e297, 0xf42e312d, 0x6842ada7, 0xc66a2b3b, 0x12754ccc, 0x782ef11c,
0x6a124237, 0xb79251e7, 0x06a1bbe6, 0x4bfb6350, 0x1a6b1018, 0x11caedfa,
0x3d25bdd8, 0xe2e1c3c9, 0x44421659, 0x0a121386, 0xd90cec6e, 0xd5abea2a,
0x64af674e, 0xda86a85f, 0xbebfe988, 0x64e4c3fe, 0x9dbc8057, 0xf0f7c086,
0x60787bf8, 0x6003604d, 0xd1fd8346, 0xf6381fb0, 0x7745ae04, 0xd736fccc,
0x83426b33, 0xf01eab71, 0xb0804187, 0x3c005e5f, 0x77a057be, 0xbde8ae24,
0x55464299, 0xbf582e61, 0x4e58f48f, 0xf2ddfda2, 0xf474ef38, 0x8789bdc2,
0x5366f9c3, 0xc8b38e74, 0xb475f255, 0x46fcd9b9, 0x7aeb2661, 0x8b1ddf84,
0x846a0e79, 0x915f95e2, 0x466e598e, 0x20b45770, 0x8cd55591, 0xc902de4c,
0xb90bace1, 0xbb8205d0, 0x11a86248, 0x7574a99e, 0xb77f19b6, 0xe0a9dc09,
0x662d09a1, 0xc4324633, 0xe85a1f02, 0x09f0be8c, 0x4a99a025, 0x1d6efe10,
0x1ab93d1d, 0x0ba5a4df, 0xa186f20f, 0x2868f169, 0xdcb7da83, 0x573906fe,
0xa1e2ce9b, 0x4fcd7f52, 0x50115e01, 0xa70683fa, 0xa002b5c4, 0x0de6d027,
0x9af88c27, 0x773f8641, 0xc3604c06, 0x61a806b5, 0xf0177a28, 0xc0f586e0,
0x006058aa, 0x30dc7d62, 0x11e69ed7, 0x2338ea63, 0x53c2dd94, 0xc2c21634,
0xbbcbee56, 0x90bcb6de, 0xebfc7da1, 0xce591d76, 0x6f05e409, 0x4b7c0188,
0x39720a3d, 0x7c927c24, 0x86e3725f, 0x724d9db9, 0x1ac15bb4, 0xd39eb8fc,
0xed545578, 0x08fca5b5, 0xd83d7cd3, 0x4dad0fc4, 0x1e50ef5e, 0xb161e6f8,
0xa28514d9, 0x6c51133c, 0x6fd5c7e7, 0x56e14ec4, 0x362abfce, 0xddc6c837,
0xd79a3234, 0x92638212, 0x670efa8e, 0x406000e0,
}
// s3 is a 256-entry table of 32-bit words holding initial Blowfish S-box
// values; it follows s2 in this file and is indexed by a single byte.
// NOTE(review): presumably the fourth S-box (after s0..s2) of the standard
// Blowfish initialisation data; confirm against the cipher's key schedule.
// The values are fixed constants of the algorithm and must not be edited.
var s3 = [256]uint32{
0x3a39ce37, 0xd3faf5cf, 0xabc27737, 0x5ac52d1b, 0x5cb0679e, 0x4fa33742,
0xd3822740, 0x99bc9bbe, 0xd5118e9d, 0xbf0f7315, 0xd62d1c7e, 0xc700c47b,
0xb78c1b6b, 0x21a19045, 0xb26eb1be, 0x6a366eb4, 0x5748ab2f, 0xbc946e79,
0xc6a376d2, 0x6549c2c8, 0x530ff8ee, 0x468dde7d, 0xd5730a1d, 0x4cd04dc6,
0x2939bbdb, 0xa9ba4650, 0xac9526e8, 0xbe5ee304, 0xa1fad5f0, 0x6a2d519a,
0x63ef8ce2, 0x9a86ee22, 0xc089c2b8, 0x43242ef6, 0xa51e03aa, 0x9cf2d0a4,
0x83c061ba, 0x9be96a4d, 0x8fe51550, 0xba645bd6, 0x2826a2f9, 0xa73a3ae1,
0x4ba99586, 0xef5562e9, 0xc72fefd3, 0xf752f7da, 0x3f046f69, 0x77fa0a59,
0x80e4a915, 0x87b08601, 0x9b09e6ad, 0x3b3ee593, 0xe990fd5a, 0x9e34d797,
0x2cf0b7d9, 0x022b8b51, 0x96d5ac3a, 0x017da67d, 0xd1cf3ed6, 0x7c7d2d28,
0x1f9f25cf, 0xadf2b89b, 0x5ad6b472, 0x5a88f54c, 0xe029ac71, 0xe019a5e6,
0x47b0acfd, 0xed93fa9b, 0xe8d3c48d, 0x283b57cc, 0xf8d56629, 0x79132e28,
0x785f0191, 0xed756055, 0xf7960e44, 0xe3d35e8c, 0x15056dd4, 0x88f46dba,
0x03a16125, 0x0564f0bd, 0xc3eb9e15, 0x3c9057a2, 0x97271aec, 0xa93a072a,
0x1b3f6d9b, 0x1e6321f5, 0xf59c66fb, 0x26dcf319, 0x7533d928, 0xb155fdf5,
0x03563482, 0x8aba3cbb, 0x28517711, 0xc20ad9f8, 0xabcc5167, 0xccad925f,
0x4de81751, 0x3830dc8e, 0x379d5862, 0x9320f991, 0xea7a90c2, 0xfb3e7bce,
0x5121ce64, 0x774fbe32, 0xa8b6e37e, 0xc3293d46, 0x48de5369, 0x6413e680,
0xa2ae0810, 0xdd6db224, 0x69852dfd, 0x09072166, 0xb39a460a, 0x6445c0dd,
0x586cdecf, 0x1c20c8ae, 0x5bbef7dd, 0x1b588d40, 0xccd2017f, 0x6bb4e3bb,
0xdda26a7e, 0x3a59ff45, 0x3e350a44, 0xbcb4cdd5, 0x72eacea8, 0xfa6484bb,
0x8d6612ae, 0xbf3c6f47, 0xd29be463, 0x542f5d9e, 0xaec2771b, 0xf64e6370,
0x740e0d8d, 0xe75b1357, 0xf8721671, 0xaf537d5d, 0x4040cb08, 0x4eb4e2cc,
0x34d2466a, 0x0115af84, 0xe1b00428, 0x95983a1d, 0x06b89fb4, 0xce6ea048,
0x6f3f3b82, 0x3520ab82, 0x011a1d4b, 0x277227f8, 0x611560b1, 0xe7933fdc,
0xbb3a792b, 0x344525bd, 0xa08839e1, 0x51ce794b, 0x2f32c9b7, 0xa01fbac9,
0xe01cc87e, 0xbcc7d1f6, 0xcf0111c3, 0xa1e8aac7, 0x1a908749, 0xd44fbd9a,
0xd0dadecb, 0xd50ada38, 0x0339c32a, 0xc6913667, 0x8df9317c, 0xe0b12b4f,
0xf79e59b7, 0x43f5bb3a, 0xf2d519ff, 0x27d9459c, 0xbf97222c, 0x15e6fc2a,
0x0f91fc71, 0x9b941525, 0xfae59361, 0xceb69ceb, 0xc2a86459, 0x12baa8d1,
0xb6c1075e, 0xe3056a0c, 0x10d25065, 0xcb03a442, 0xe0ec6e0e, 0x1698db3b,
0x4c98a0be, 0x3278e964, 0x9f1f9532, 0xe0d392df, 0xd3a0342b, 0x8971f21e,
0x1b0a7441, 0x4ba3348c, 0xc5be7120, 0xc37632d8, 0xdf359f8d, 0x9b992f2e,
0xe60b6f47, 0x0fe3f11d, 0xe54cda54, 0x1edad891, 0xce6279cf, 0xcd3e7e6f,
0x1618b166, 0xfd2c1d05, 0x848fd2c5, 0xf6fb2299, 0xf523f357, 0xa6327623,
0x93a83531, 0x56cccd02, 0xacf08162, 0x5a75ebb5, 0x6e163697, 0x88d273cc,
0xde966292, 0x81b949d0, 0x4c50901b, 0x71c65614, 0xe6c6c7bd, 0x327a140a,
0x45e1d006, 0xc3f27b9a, 0xc9aa53fd, 0x62a80f00, 0xbb25bfe2, 0x35bdd2f6,
0x71126905, 0xb2040222, 0xb6cbcf7c, 0xcd769c2b, 0x53113ec0, 0x1640e3d3,
0x38abbd60, 0x2547adf0, 0xba38209c, 0xf746ce76, 0x77afa1c5, 0x20756060,
0x85cbfe4e, 0x8ae88dd8, 0x7aaaf9b0, 0x4cf9aa7e, 0x1948c25c, 0x02fb8a8c,
0x01c36ae4, 0xd6ebe1f9, 0x90d4f869, 0xa65cdea0, 0x3f09252d, 0xc208e69f,
0xb74e6132, 0xce77e25b, 0x578fdfe3, 0x3ac372e6,
}
// p holds 18 32-bit words of initial Blowfish subkey data. The values are the
// leading hexadecimal digits of the fractional part of pi (3.243f6a88 85a308d3
// ...), i.e. fixed "nothing up my sleeve" constants that must not be edited.
// NOTE(review): presumably the initial P-array that the key schedule mixes the
// user key into; confirm against the cipher setup code.
var p = [18]uint32{
0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, 0xa4093822, 0x299f31d0,
0x082efa98, 0xec4e6c89, 0x452821e6, 0x38d01377, 0xbe5466cf, 0x34e90c6c,
0xc0ac29b7, 0xc97c50dd, 0x3f84d5b5, 0xb5470917, 0x9216d5d9, 0x8979fb1b,
}
================================================
FILE: gossh/crypto/chacha20/chacha_arm64.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package chacha20
// bufSize is the size in bytes of Cipher.buf on arm64 and the unit in which
// XORKeyStream hands data to xorKeyStreamBlocks: 256 bytes, i.e. four 64-byte
// ChaCha20 blocks per call into the assembly.
const bufSize = 256
// xorKeyStreamVX is implemented in chacha_arm64.s. It XORs src with ChaCha20
// key stream derived from key, nonce and *counter, writes the result to dst,
// and stores the advanced block counter back through counter. The assembly
// works in 256-byte steps and always executes at least one step, so src (and
// dst) must be a non-zero multiple of 256 bytes long.
//
//go:noescape
func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32)
// xorKeyStreamBlocks is the arm64 block routine used by XORKeyStream. It
// forwards to the assembly implementation, passing pointers to the cipher's
// key, nonce and counter; the assembly updates c.counter in place.
func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) {
xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter)
}
================================================
FILE: gossh/crypto/chacha20/chacha_arm64.s
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// NUM_ROUNDS is the number of iterations of the "chacha" loop below. Each
// iteration performs one column round followed by one diagonal round, so 10
// iterations give the 20 rounds of ChaCha20.
#define NUM_ROUNDS 10
// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32)
//
// XORs src with ChaCha20 key stream and writes the result to dst. Every pass
// through "loop" computes four 64-byte blocks in parallel (one block per
// 32-bit vector lane), consumes/produces 256 bytes, and advances the counter
// by 4. The loop is bottom-tested, so at least 256 bytes are always processed.
//
// Register use:
//   R1  dst pointer (post-incremented by the stores)
//   R2  src pointer (post-incremented by the loads)
//   R3  len(src)
//   R4  key pointer, R6 nonce pointer, R7 counter pointer
//   R10 address of ·constants, R11 address of ·incRotMatrix
//   R12 src + (len(src) rounded down to a multiple of 256): end of input
//   R20 running block counter value, R21 remaining double rounds
TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0
MOVD dst+0(FP), R1
MOVD src+24(FP), R2
MOVD src_len+32(FP), R3
MOVD key+48(FP), R4
MOVD nonce+56(FP), R6
MOVD counter+64(FP), R7
MOVD $·constants(SB), R10
MOVD $·incRotMatrix(SB), R11
MOVW (R7), R20
AND $~255, R3, R13
ADD R2, R13, R12 // R12 for block end
// R13 = len(src) % 256; it is not read again in this routine.
AND $255, R3, R13
loop:
MOVD $NUM_ROUNDS, R21
// V30 = lane offsets {0,1,2,3}, V31 = byte-shuffle table for rotate-left-by-8.
VLD1 (R11), [V30.S4, V31.S4]
// load constants
// VLD4R (R10), [V0.S4, V1.S4, V2.S4, V3.S4]
WORD $0x4D60E940
// load keys
// VLD4R 16(R4), [V4.S4, V5.S4, V6.S4, V7.S4]
WORD $0x4DFFE884
// VLD4R 16(R4), [V8.S4, V9.S4, V10.S4, V11.S4]
WORD $0x4DFFE888
// Undo the pointer advance of the two 16-byte key loads above.
SUB $32, R4
// load counter + nonce
// VLD1R (R7), [V12.S4]
WORD $0x4D40C8EC
// VLD3R (R6), [V13.S4, V14.S4, V15.S4]
WORD $0x4D40E8CD
// update counter
// Gives each of the four lanes its own block counter: counter+0..counter+3.
VADD V30.S4, V12.S4, V12.S4
chacha:
// Column round.
// V0..V3 += V4..V7
// V12..V15 <<<= ((V12..V15 XOR V0..V3), 16)
VADD V0.S4, V4.S4, V0.S4
VADD V1.S4, V5.S4, V1.S4
VADD V2.S4, V6.S4, V2.S4
VADD V3.S4, V7.S4, V3.S4
VEOR V12.B16, V0.B16, V12.B16
VEOR V13.B16, V1.B16, V13.B16
VEOR V14.B16, V2.B16, V14.B16
VEOR V15.B16, V3.B16, V15.B16
// Swapping the 16-bit halves of each 32-bit word is a rotate by 16.
VREV32 V12.H8, V12.H8
VREV32 V13.H8, V13.H8
VREV32 V14.H8, V14.H8
VREV32 V15.H8, V15.H8
// V8..V11 += V12..V15
// V4..V7 <<<= ((V4..V7 XOR V8..V11), 12)
VADD V8.S4, V12.S4, V8.S4
VADD V9.S4, V13.S4, V9.S4
VADD V10.S4, V14.S4, V10.S4
VADD V11.S4, V15.S4, V11.S4
VEOR V8.B16, V4.B16, V16.B16
VEOR V9.B16, V5.B16, V17.B16
VEOR V10.B16, V6.B16, V18.B16
VEOR V11.B16, V7.B16, V19.B16
// Rotate left by 12: shift left 12, then insert the bits shifted right by 20.
VSHL $12, V16.S4, V4.S4
VSHL $12, V17.S4, V5.S4
VSHL $12, V18.S4, V6.S4
VSHL $12, V19.S4, V7.S4
VSRI $20, V16.S4, V4.S4
VSRI $20, V17.S4, V5.S4
VSRI $20, V18.S4, V6.S4
VSRI $20, V19.S4, V7.S4
// V0..V3 += V4..V7
// V12..V15 <<<= ((V12..V15 XOR V0..V3), 8)
VADD V0.S4, V4.S4, V0.S4
VADD V1.S4, V5.S4, V1.S4
VADD V2.S4, V6.S4, V2.S4
VADD V3.S4, V7.S4, V3.S4
VEOR V12.B16, V0.B16, V12.B16
VEOR V13.B16, V1.B16, V13.B16
VEOR V14.B16, V2.B16, V14.B16
VEOR V15.B16, V3.B16, V15.B16
// Rotate left by 8 as a byte shuffle through the table held in V31.
VTBL V31.B16, [V12.B16], V12.B16
VTBL V31.B16, [V13.B16], V13.B16
VTBL V31.B16, [V14.B16], V14.B16
VTBL V31.B16, [V15.B16], V15.B16
// V8..V11 += V12..V15
// V4..V7 <<<= ((V4..V7 XOR V8..V11), 7)
VADD V12.S4, V8.S4, V8.S4
VADD V13.S4, V9.S4, V9.S4
VADD V14.S4, V10.S4, V10.S4
VADD V15.S4, V11.S4, V11.S4
VEOR V8.B16, V4.B16, V16.B16
VEOR V9.B16, V5.B16, V17.B16
VEOR V10.B16, V6.B16, V18.B16
VEOR V11.B16, V7.B16, V19.B16
VSHL $7, V16.S4, V4.S4
VSHL $7, V17.S4, V5.S4
VSHL $7, V18.S4, V6.S4
VSHL $7, V19.S4, V7.S4
VSRI $25, V16.S4, V4.S4
VSRI $25, V17.S4, V5.S4
VSRI $25, V18.S4, V6.S4
VSRI $25, V19.S4, V7.S4
// Diagonal round: same steps with the register pairing rotated.
// V0..V3 += V5..V7, V4
// V15,V12-V14 <<<= ((V15,V12-V14 XOR V0..V3), 16)
VADD V0.S4, V5.S4, V0.S4
VADD V1.S4, V6.S4, V1.S4
VADD V2.S4, V7.S4, V2.S4
VADD V3.S4, V4.S4, V3.S4
VEOR V15.B16, V0.B16, V15.B16
VEOR V12.B16, V1.B16, V12.B16
VEOR V13.B16, V2.B16, V13.B16
VEOR V14.B16, V3.B16, V14.B16
VREV32 V12.H8, V12.H8
VREV32 V13.H8, V13.H8
VREV32 V14.H8, V14.H8
VREV32 V15.H8, V15.H8
// V10 += V15; V5 <<<= ((V10 XOR V5), 12)
// ...
VADD V15.S4, V10.S4, V10.S4
VADD V12.S4, V11.S4, V11.S4
VADD V13.S4, V8.S4, V8.S4
VADD V14.S4, V9.S4, V9.S4
VEOR V10.B16, V5.B16, V16.B16
VEOR V11.B16, V6.B16, V17.B16
VEOR V8.B16, V7.B16, V18.B16
VEOR V9.B16, V4.B16, V19.B16
VSHL $12, V16.S4, V5.S4
VSHL $12, V17.S4, V6.S4
VSHL $12, V18.S4, V7.S4
VSHL $12, V19.S4, V4.S4
VSRI $20, V16.S4, V5.S4
VSRI $20, V17.S4, V6.S4
VSRI $20, V18.S4, V7.S4
VSRI $20, V19.S4, V4.S4
// V0 += V5; V15 <<<= ((V0 XOR V15), 8)
// ...
VADD V5.S4, V0.S4, V0.S4
VADD V6.S4, V1.S4, V1.S4
VADD V7.S4, V2.S4, V2.S4
VADD V4.S4, V3.S4, V3.S4
VEOR V0.B16, V15.B16, V15.B16
VEOR V1.B16, V12.B16, V12.B16
VEOR V2.B16, V13.B16, V13.B16
VEOR V3.B16, V14.B16, V14.B16
VTBL V31.B16, [V12.B16], V12.B16
VTBL V31.B16, [V13.B16], V13.B16
VTBL V31.B16, [V14.B16], V14.B16
VTBL V31.B16, [V15.B16], V15.B16
// V10 += V15; V5 <<<= ((V10 XOR V5), 7)
// ...
VADD V15.S4, V10.S4, V10.S4
VADD V12.S4, V11.S4, V11.S4
VADD V13.S4, V8.S4, V8.S4
VADD V14.S4, V9.S4, V9.S4
VEOR V10.B16, V5.B16, V16.B16
VEOR V11.B16, V6.B16, V17.B16
VEOR V8.B16, V7.B16, V18.B16
VEOR V9.B16, V4.B16, V19.B16
VSHL $7, V16.S4, V5.S4
VSHL $7, V17.S4, V6.S4
VSHL $7, V18.S4, V7.S4
VSHL $7, V19.S4, V4.S4
VSRI $25, V16.S4, V5.S4
VSRI $25, V17.S4, V6.S4
VSRI $25, V18.S4, V7.S4
VSRI $25, V19.S4, V4.S4
SUB $1, R21
CBNZ R21, chacha
// Rounds done. Reload the initial state (constants, key, counter, nonce)
// into V16..V31 and add it to the working state to form the key stream.
// VLD4R (R10), [V16.S4, V17.S4, V18.S4, V19.S4]
WORD $0x4D60E950
// VLD4R 16(R4), [V20.S4, V21.S4, V22.S4, V23.S4]
WORD $0x4DFFE894
// Add the per-lane counter offsets (V30) before V30 is overwritten below.
VADD V30.S4, V12.S4, V12.S4
VADD V16.S4, V0.S4, V0.S4
VADD V17.S4, V1.S4, V1.S4
VADD V18.S4, V2.S4, V2.S4
VADD V19.S4, V3.S4, V3.S4
// VLD4R 16(R4), [V24.S4, V25.S4, V26.S4, V27.S4]
WORD $0x4DFFE898
// restore R4
SUB $32, R4
// load counter + nonce
// VLD1R (R7), [V28.S4]
WORD $0x4D40C8FC
// VLD3R (R6), [V29.S4, V30.S4, V31.S4]
WORD $0x4D40E8DD
VADD V20.S4, V4.S4, V4.S4
VADD V21.S4, V5.S4, V5.S4
VADD V22.S4, V6.S4, V6.S4
VADD V23.S4, V7.S4, V7.S4
VADD V24.S4, V8.S4, V8.S4
VADD V25.S4, V9.S4, V9.S4
VADD V26.S4, V10.S4, V10.S4
VADD V27.S4, V11.S4, V11.S4
VADD V28.S4, V12.S4, V12.S4
VADD V29.S4, V13.S4, V13.S4
VADD V30.S4, V14.S4, V14.S4
VADD V31.S4, V15.S4, V15.S4
// Transpose: the state is held one word per register with one block per
// lane; the zips below regroup it so that each register holds 16 contiguous
// key stream bytes of a single block.
VZIP1 V1.S4, V0.S4, V16.S4
VZIP2 V1.S4, V0.S4, V17.S4
VZIP1 V3.S4, V2.S4, V18.S4
VZIP2 V3.S4, V2.S4, V19.S4
VZIP1 V5.S4, V4.S4, V20.S4
VZIP2 V5.S4, V4.S4, V21.S4
VZIP1 V7.S4, V6.S4, V22.S4
VZIP2 V7.S4, V6.S4, V23.S4
VZIP1 V9.S4, V8.S4, V24.S4
VZIP2 V9.S4, V8.S4, V25.S4
VZIP1 V11.S4, V10.S4, V26.S4
VZIP2 V11.S4, V10.S4, V27.S4
VZIP1 V13.S4, V12.S4, V28.S4
VZIP2 V13.S4, V12.S4, V29.S4
VZIP1 V15.S4, V14.S4, V30.S4
VZIP2 V15.S4, V14.S4, V31.S4
// The 64-byte source loads are interleaved with the second zip stage.
VZIP1 V18.D2, V16.D2, V0.D2
VZIP2 V18.D2, V16.D2, V4.D2
VZIP1 V19.D2, V17.D2, V8.D2
VZIP2 V19.D2, V17.D2, V12.D2
VLD1.P 64(R2), [V16.B16, V17.B16, V18.B16, V19.B16]
VZIP1 V22.D2, V20.D2, V1.D2
VZIP2 V22.D2, V20.D2, V5.D2
VZIP1 V23.D2, V21.D2, V9.D2
VZIP2 V23.D2, V21.D2, V13.D2
VLD1.P 64(R2), [V20.B16, V21.B16, V22.B16, V23.B16]
VZIP1 V26.D2, V24.D2, V2.D2
VZIP2 V26.D2, V24.D2, V6.D2
VZIP1 V27.D2, V25.D2, V10.D2
VZIP2 V27.D2, V25.D2, V14.D2
VLD1.P 64(R2), [V24.B16, V25.B16, V26.B16, V27.B16]
VZIP1 V30.D2, V28.D2, V3.D2
VZIP2 V30.D2, V28.D2, V7.D2
VZIP1 V31.D2, V29.D2, V11.D2
VZIP2 V31.D2, V29.D2, V15.D2
VLD1.P 64(R2), [V28.B16, V29.B16, V30.B16, V31.B16]
// XOR key stream into the source data and store 4 x 64 bytes to dst.
VEOR V0.B16, V16.B16, V16.B16
VEOR V1.B16, V17.B16, V17.B16
VEOR V2.B16, V18.B16, V18.B16
VEOR V3.B16, V19.B16, V19.B16
VST1.P [V16.B16, V17.B16, V18.B16, V19.B16], 64(R1)
VEOR V4.B16, V20.B16, V20.B16
VEOR V5.B16, V21.B16, V21.B16
VEOR V6.B16, V22.B16, V22.B16
VEOR V7.B16, V23.B16, V23.B16
VST1.P [V20.B16, V21.B16, V22.B16, V23.B16], 64(R1)
VEOR V8.B16, V24.B16, V24.B16
VEOR V9.B16, V25.B16, V25.B16
VEOR V10.B16, V26.B16, V26.B16
VEOR V11.B16, V27.B16, V27.B16
VST1.P [V24.B16, V25.B16, V26.B16, V27.B16], 64(R1)
VEOR V12.B16, V28.B16, V28.B16
VEOR V13.B16, V29.B16, V29.B16
VEOR V14.B16, V30.B16, V30.B16
VEOR V15.B16, V31.B16, V31.B16
VST1.P [V28.B16, V29.B16, V30.B16, V31.B16], 64(R1)
// Four blocks were consumed; publish the new counter to the caller.
ADD $4, R20
MOVW R20, (R7) // update counter
// Repeat while the block end (R12) is still beyond the source pointer (R2).
CMP R2, R12
BGT loop
RET
// ·constants holds the four constant words of the ChaCha20 state (the ASCII
// string "expand 32-byte k" as little-endian 32-bit words). xorKeyStreamVX
// loads them into V0..V3 and again into V16..V19.
// Note: the symbol is declared as 32 bytes, but only the first 16 bytes are
// given DATA values here.
DATA ·constants+0x00(SB)/4, $0x61707865
DATA ·constants+0x04(SB)/4, $0x3320646e
DATA ·constants+0x08(SB)/4, $0x79622d32
DATA ·constants+0x0c(SB)/4, $0x6b206574
GLOBL ·constants(SB), NOPTR|RODATA, $32
// ·incRotMatrix is loaded as two 16-byte vectors by xorKeyStreamVX.
//
// Bytes 0x00-0x0f (V30): the words {0, 1, 2, 3}, added to the replicated
// block counter so that each vector lane works on a different block.
//
// Bytes 0x10-0x1f (V31): a VTBL byte-index table. Within every 32-bit word it
// selects source bytes 3,0,1,2, which rotates each little-endian word left by
// 8 bits.
DATA ·incRotMatrix+0x00(SB)/4, $0x00000000
DATA ·incRotMatrix+0x04(SB)/4, $0x00000001
DATA ·incRotMatrix+0x08(SB)/4, $0x00000002
DATA ·incRotMatrix+0x0c(SB)/4, $0x00000003
DATA ·incRotMatrix+0x10(SB)/4, $0x02010003
DATA ·incRotMatrix+0x14(SB)/4, $0x06050407
DATA ·incRotMatrix+0x18(SB)/4, $0x0A09080B
DATA ·incRotMatrix+0x1c(SB)/4, $0x0E0D0C0F
GLOBL ·incRotMatrix(SB), NOPTR|RODATA, $32
================================================
FILE: gossh/crypto/chacha20/chacha_generic.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package chacha20 implements the ChaCha20 and XChaCha20 encryption algorithms
// as specified in RFC 8439 and draft-irtf-cfrg-xchacha-01.
package chacha20
import (
"crypto/cipher"
"encoding/binary"
"errors"
"math/bits"
"gossh/crypto/internal/alias"
)
// Sizes, in bytes, of the key and nonce inputs accepted by
// NewUnauthenticatedCipher. Any other length is rejected with an error.
const (
// KeySize is the size of the key used by this cipher, in bytes.
KeySize = 32
// NonceSize is the size of the nonce used with the standard variant of this
// cipher, in bytes.
//
// Note that this is too short to be safely generated at random if the same
// key is reused more than 2³² times.
NonceSize = 12
// NonceSizeX is the size of the nonce used with the XChaCha20 variant of
// this cipher, in bytes. Passing a nonce of this length selects XChaCha20.
NonceSizeX = 24
)
// Cipher is a stateful instance of ChaCha20 or XChaCha20 using a particular key
// and nonce. A *Cipher implements the cipher.Stream interface.
type Cipher struct {
// The ChaCha20 state is 16 words: 4 constant, 8 of key, 1 of counter
// (incremented after each block), and 3 of nonce.
//
// Do not reorder counter and nonce: the ppc64le assembly is handed only
// &counter and loads four consecutive words from that address to obtain the
// counter and the nonce together, so nonce must directly follow counter.
key [8]uint32
counter uint32
nonce [3]uint32
// The last len bytes of buf are leftover key stream bytes from the previous
// XORKeyStream invocation. The size of buf depends on how many blocks are
// computed at a time by xorKeyStreamBlocks.
buf [bufSize]byte
len int
// overflow is set when the counter overflowed, no more blocks can be
// generated, and the next XORKeyStream call should panic.
overflow bool
// The counter-independent results of the first round are cached after they
// are computed the first time (by xorKeyStreamBlocksGeneric); precompDone
// records whether p1..p15 below are valid.
precompDone bool
p1, p5, p9, p13 uint32
p2, p6, p10, p14 uint32
p3, p7, p11, p15 uint32
}
// Compile-time check that *Cipher satisfies cipher.Stream.
var _ cipher.Stream = (*Cipher)(nil)
// NewUnauthenticatedCipher returns a ChaCha20 stream cipher keyed with the
// given 32-byte key. The nonce must be either 12 bytes (ChaCha20) or 24 bytes,
// in which case the XChaCha20 construction is used. Any other key or nonce
// length yields an error.
//
// ChaCha20, like every stream cipher, provides no authentication: an attacker
// can modify the plaintext undetected by flipping ciphertext bits. It is
// therefore better suited as a building block than as a complete encryption
// scheme; consider package golang.org/x/crypto/chacha20poly1305 instead.
func NewUnauthenticatedCipher(key, nonce []byte) (*Cipher, error) {
	// Kept as a thin wrapper so that it can be inlined: the Cipher allocation
	// then happens in the caller and, depending on how the result is used,
	// need not escape to the heap.
	return newUnauthenticatedCipher(&Cipher{}, key, nonce)
}
// newUnauthenticatedCipher validates key and nonce and initialises c with
// them. It returns c on success, or a nil *Cipher and an error when the key is
// not KeySize bytes or the nonce is neither NonceSize nor NonceSizeX bytes
// long. c is left untouched when an error is returned.
func newUnauthenticatedCipher(c *Cipher, key, nonce []byte) (*Cipher, error) {
	if len(key) != KeySize {
		return nil, errors.New("chacha20: wrong key size")
	}

	switch len(nonce) {
	case NonceSize:
		// Plain ChaCha20: key and nonce are used as supplied.
	case NonceSizeX:
		// XChaCha20 uses the ChaCha20 core to mix the first 16 nonce bytes
		// into a derived key, which lets it accept a 24-byte nonce. See
		// draft-irtf-cfrg-xchacha-01, Section 2.3.
		//
		// The error is ignored because HChaCha20 only fails on bad lengths,
		// and both lengths are fixed by the checks above.
		subKey, _ := HChaCha20(key, nonce[0:16])
		// The remaining 8 nonce bytes form the low two words of the ChaCha20
		// nonce; its first word stays zero.
		shortNonce := make([]byte, NonceSize)
		copy(shortNonce[4:12], nonce[16:24])
		key, nonce = subKey, shortNonce
	default:
		return nil, errors.New("chacha20: wrong nonce size")
	}

	key, nonce = key[:KeySize], nonce[:NonceSize] // bounds check elimination hint
	for i := range c.key {
		c.key[i] = binary.LittleEndian.Uint32(key[4*i : 4*i+4])
	}
	for i := range c.nonce {
		c.nonce[i] = binary.LittleEndian.Uint32(nonce[4*i : 4*i+4])
	}
	return c, nil
}
// The constant first 4 words of the ChaCha20 state. Read as little-endian
// bytes they spell the ASCII string "expand 32-byte k".
const (
j0 uint32 = 0x61707865 // expa
j1 uint32 = 0x3320646e // nd 3
j2 uint32 = 0x79622d32 // 2-by
j3 uint32 = 0x6b206574 // te k
)
// blockSize is the size in bytes of one ChaCha20 key stream block: the 16
// state words of 4 bytes each.
const blockSize = 64
// quarterRound is the ChaCha20 mixing primitive. It takes four state words and
// returns them after four add / xor / rotate steps (rotations by 16, 12, 8 and
// 7 bits). Each of the 20 ChaCha20 rounds applies it four times, covering all
// 16 state words, grouped either by column or by diagonal.
func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) {
	a += b
	d = bits.RotateLeft32(d^a, 16)

	c += d
	b = bits.RotateLeft32(b^c, 12)

	a += b
	d = bits.RotateLeft32(d^a, 8)

	c += d
	b = bits.RotateLeft32(b^c, 7)

	return a, b, c, d
}
// SetCounter positions the key stream at block number counter: the next
// XORKeyStream call behaves as if (64 * counter) bytes had already been
// encrypted.
//
// SetCounter panics if counter is lower than the number of blocks already
// output, or if the counter has overflowed, so that key stream is never
// reused by accident.
//
// Note that the execution time of XORKeyStream is not independent of the
// counter value.
func (s *Cipher) SetCounter(counter uint32) {
	// s.counter counts generated blocks, some of which may still sit unused in
	// s.buf. Subtract those whole buffered blocks to find the first block that
	// has not been handed out yet.
	bufferedBlocks := uint32(s.len) / blockSize
	if s.overflow || counter < s.counter-bufferedBlocks {
		panic("chacha20: SetCounter attempted to rollback counter")
	}

	if counter >= s.counter {
		// Jumping to or past the generation point: discard the buffer so the
		// next XORKeyStream call refills it starting at the new counter.
		s.counter = counter
		s.len = 0
		return
	}

	// The target block is already in the buffer; just skip ahead to it.
	s.len = int(s.counter-counter) * blockSize
}
// XORKeyStream XORs every byte of src with the next byte of the cipher's key
// stream and stores the result in dst. Dst and src must overlap entirely or
// not at all.
//
// XORKeyStream panics if len(dst) < len(src). A dst longer than src is
// allowed; only dst[:len(src)] is written and the rest is left untouched.
//
// The Cipher keeps its position between calls: several XORKeyStream calls
// produce the same output as one call on the concatenation of their src
// buffers.
func (s *Cipher) XORKeyStream(dst, src []byte) {
	if len(src) == 0 {
		return
	}
	if len(dst) < len(src) {
		panic("chacha20: output smaller than input")
	}
	dst = dst[:len(src)]
	if alias.InexactOverlap(dst, src) {
		panic("chacha20: invalid buffer overlap")
	}

	// Step 1: use up key stream left over from the previous call.
	if s.len != 0 {
		leftover := s.buf[bufSize-s.len:]
		if len(leftover) > len(src) {
			leftover = leftover[:len(src)]
		}
		n := len(leftover)
		_ = src[n-1] // bounds check elimination hint
		for i := range leftover {
			dst[i] = src[i] ^ leftover[i]
		}
		s.len -= n
		dst, src = dst[n:], src[n:]
	}
	if len(src) == 0 {
		return
	}

	// Step 2: refuse to wrap the 32-bit block counter. If this call consumes
	// exactly the last available block, record that so that no further key
	// stream is generated once the buffer has been drained.
	blocksNeeded := (uint64(len(src)) + blockSize - 1) / blockSize
	lastCounter := uint64(s.counter) + blocksNeeded
	switch {
	case s.overflow || lastCounter > 1<<32:
		panic("chacha20: counter overflow")
	case lastCounter == 1<<32:
		s.overflow = true
	}

	// Step 3: process all whole buffers directly. xorKeyStreamBlocks requires
	// lengths that are a multiple of bufSize; platform-specific versions work
	// on several blocks at once, so bufSize is a multiple of blockSize.
	if whole := len(src) - len(src)%bufSize; whole > 0 {
		s.xorKeyStreamBlocks(dst[:whole], src[:whole])
		dst, src = dst[whole:], src[whole:]
	}

	// Step 4: if one more multi-block call would push the counter past its
	// limit, finish with the generic routine, which advances one block at a
	// time. The blocks are placed at the end of s.buf so that the unused key
	// stream ends up where step 1 expects it.
	const blocksPerBuf = bufSize / blockSize
	if uint64(s.counter)+blocksPerBuf > 1<<32 {
		s.buf = [bufSize]byte{}
		tailBlocks := (len(src) + blockSize - 1) / blockSize
		tail := s.buf[bufSize-tailBlocks*blockSize:]
		copy(tail, src)
		s.xorKeyStreamBlocksGeneric(tail, tail)
		s.len = len(tail) - copy(dst, tail)
		return
	}

	// Step 5: a partial buffer remains. Zero-pad it to bufSize, encrypt it in
	// place, copy out the requested bytes and keep the rest of the key stream
	// for the next call.
	if len(src) == 0 {
		return
	}
	s.buf = [bufSize]byte{}
	copy(s.buf[:], src)
	s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
	s.len = bufSize - copy(dst, s.buf[:])
}
// xorKeyStreamBlocksGeneric is the portable block routine. It XORs src with
// key stream one 64-byte block at a time, writes the result to dst, and
// increments s.counter once per block. dst and src must have the same length,
// which must be a multiple of blockSize; otherwise it panics.
func (s *Cipher) xorKeyStreamBlocksGeneric(dst, src []byte) {
if len(dst) != len(src) || len(dst)%blockSize != 0 {
panic("chacha20: internal error: wrong dst and/or src length")
}
// To generate each block of key stream, the initial cipher state
// (represented below) is passed through 20 rounds of shuffling,
// alternatively applying quarterRounds by columns (like 1, 5, 9, 13)
// or by diagonals (like 1, 6, 11, 12).
//
// 0:cccccccc 1:cccccccc 2:cccccccc 3:cccccccc
// 4:kkkkkkkk 5:kkkkkkkk 6:kkkkkkkk 7:kkkkkkkk
// 8:kkkkkkkk 9:kkkkkkkk 10:kkkkkkkk 11:kkkkkkkk
// 12:bbbbbbbb 13:nnnnnnnn 14:nnnnnnnn 15:nnnnnnnn
//
// c=constant k=key b=blockcount n=nonce
//
// Word 12 (the block counter) is not captured in a local: it changes on
// every loop iteration, so s.counter is read directly where it is needed.
var (
c0, c1, c2, c3 = j0, j1, j2, j3
c4, c5, c6, c7 = s.key[0], s.key[1], s.key[2], s.key[3]
c8, c9, c10, c11 = s.key[4], s.key[5], s.key[6], s.key[7]
_, c13, c14, c15 = s.counter, s.nonce[0], s.nonce[1], s.nonce[2]
)
// Three quarters of the first round don't depend on the counter, so we can
// calculate them here, and reuse them for multiple blocks in the loop, and
// for future XORKeyStream invocations.
if !s.precompDone {
s.p1, s.p5, s.p9, s.p13 = quarterRound(c1, c5, c9, c13)
s.p2, s.p6, s.p10, s.p14 = quarterRound(c2, c6, c10, c14)
s.p3, s.p7, s.p11, s.p15 = quarterRound(c3, c7, c11, c15)
s.precompDone = true
}
// A condition of len(src) > 0 would be sufficient, but this also
// acts as a bounds check elimination hint.
for len(src) >= 64 && len(dst) >= 64 {
// The remainder of the first column round: the only column that
// contains the counter word.
fcr0, fcr4, fcr8, fcr12 := quarterRound(c0, c4, c8, s.counter)
// The second diagonal round, combining the cached columns with the
// counter column computed just above.
x0, x5, x10, x15 := quarterRound(fcr0, s.p5, s.p10, s.p15)
x1, x6, x11, x12 := quarterRound(s.p1, s.p6, s.p11, fcr12)
x2, x7, x8, x13 := quarterRound(s.p2, s.p7, fcr8, s.p13)
x3, x4, x9, x14 := quarterRound(s.p3, fcr4, s.p9, s.p14)
// The remaining 18 rounds, as 9 column/diagonal pairs.
for i := 0; i < 9; i++ {
// Column round.
x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
// Diagonal round.
x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
}
// Add back the initial state to generate the key stream, then
// XOR the key stream with the source and write out the result.
// Word 12 adds the current counter value, matching the initial state
// used for this block.
addXor(dst[0:4], src[0:4], x0, c0)
addXor(dst[4:8], src[4:8], x1, c1)
addXor(dst[8:12], src[8:12], x2, c2)
addXor(dst[12:16], src[12:16], x3, c3)
addXor(dst[16:20], src[16:20], x4, c4)
addXor(dst[20:24], src[20:24], x5, c5)
addXor(dst[24:28], src[24:28], x6, c6)
addXor(dst[28:32], src[28:32], x7, c7)
addXor(dst[32:36], src[32:36], x8, c8)
addXor(dst[36:40], src[36:40], x9, c9)
addXor(dst[40:44], src[40:44], x10, c10)
addXor(dst[44:48], src[44:48], x11, c11)
addXor(dst[48:52], src[48:52], x12, s.counter)
addXor(dst[52:56], src[52:56], x13, c13)
addXor(dst[56:60], src[56:60], x14, c14)
addXor(dst[60:64], src[60:64], x15, c15)
// Advance to the next block only after the current counter value has
// been folded into the output above.
s.counter += 1
src, dst = src[blockSize:], dst[blockSize:]
}
}
// HChaCha20 derives a 32-byte subkey from a 32-byte key and a 16-byte nonce
// using the ChaCha20 core. It returns an error if key or nonce has any other
// length. It is one of the building blocks of the XChaCha20 construction.
func HChaCha20(key, nonce []byte) ([]byte, error) {
	// Kept as a thin wrapper so that it can be inlined: the output slice is
	// then allocated in the caller and, depending on how the result is used,
	// need not escape to the heap.
	return hChaCha20(make([]byte, 32), key, nonce)
}
// hChaCha20 computes the HChaCha20 function of key and nonce into out and
// returns out. It returns a nil slice and an error if key is not KeySize bytes
// or nonce is not 16 bytes long. out must hold at least 32 bytes.
func hChaCha20(out, key, nonce []byte) ([]byte, error) {
	if len(key) != KeySize {
		return nil, errors.New("chacha20: wrong HChaCha20 key size")
	}
	if len(nonce) != 16 {
		return nil, errors.New("chacha20: wrong HChaCha20 nonce size")
	}

	le := binary.LittleEndian

	// Initial state: 4 constant words, 8 key words, 4 nonce words. Unlike in
	// ChaCha20 proper there is no block counter; the nonce fills words 12-15.
	x0, x1, x2, x3 := j0, j1, j2, j3
	x4, x5, x6, x7 := le.Uint32(key[0:4]), le.Uint32(key[4:8]), le.Uint32(key[8:12]), le.Uint32(key[12:16])
	x8, x9, x10, x11 := le.Uint32(key[16:20]), le.Uint32(key[20:24]), le.Uint32(key[24:28]), le.Uint32(key[28:32])
	x12, x13, x14, x15 := le.Uint32(nonce[0:4]), le.Uint32(nonce[4:8]), le.Uint32(nonce[8:12]), le.Uint32(nonce[12:16])

	// 20 rounds, as 10 column/diagonal pairs.
	for pair := 10; pair > 0; pair-- {
		// Column round.
		x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
		x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
		x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
		x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)

		// Diagonal round.
		x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
		x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
		x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
		x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
	}

	// The output is the first and last rows of the final state. The initial
	// state is not added back.
	_ = out[31] // bounds check elimination hint
	le.PutUint32(out[0:4], x0)
	le.PutUint32(out[4:8], x1)
	le.PutUint32(out[8:12], x2)
	le.PutUint32(out[12:16], x3)
	le.PutUint32(out[16:20], x12)
	le.PutUint32(out[20:24], x13)
	le.PutUint32(out[24:28], x14)
	le.PutUint32(out[28:32], x15)
	return out, nil
}
================================================
FILE: gossh/crypto/chacha20/chacha_noasm.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (!arm64 && !s390x && !ppc64le) || !gc || purego
package chacha20
// bufSize is the size of Cipher.buf when no assembly implementation is used.
// The generic routine works one block at a time, so a single block suffices.
const bufSize = blockSize
// xorKeyStreamBlocks is the block routine used by XORKeyStream on platforms
// selected by this file's build constraint; it simply delegates to the
// portable Go implementation.
func (s *Cipher) xorKeyStreamBlocks(dst, src []byte) {
s.xorKeyStreamBlocksGeneric(dst, src)
}
================================================
FILE: gossh/crypto/chacha20/chacha_ppc64le.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package chacha20
// bufSize is the size in bytes of Cipher.buf on ppc64le and the unit in which
// XORKeyStream hands data to xorKeyStreamBlocks: 256 bytes, i.e. four 64-byte
// ChaCha20 blocks.
const bufSize = 256
// chaCha20_ctr32_vsx is implemented in chacha_ppc64le.s. It XORs len bytes at
// inp with ChaCha20 key stream and writes them to out, then adds the number of
// 64-byte blocks processed to *counter.
//
// The assembly loads four consecutive 32-bit words starting at counter and
// treats them as the block counter followed by the three nonce words, so
// counter must point at memory laid out that way.
//
//go:noescape
func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
// xorKeyStreamBlocks is the ppc64le block routine used by XORKeyStream. It
// forwards to the VSX assembly implementation, which updates c.counter in
// place. No nonce pointer is passed: the assembly reads the nonce from the
// words that follow c.counter in the Cipher struct.
//
// dst and src must be non-empty, since their first elements are addressed
// unconditionally.
func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) {
chaCha20_ctr32_vsx(&dst[0], &src[0], len(src), &c.key, &c.counter)
}
================================================
FILE: gossh/crypto/chacha20/chacha_ppc64le.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on CRYPTOGAMS code with the following comment:
// # ====================================================================
// # Written by Andy Polyakov for the OpenSSL
// # project. The module is, however, dual licensed under OpenSSL and
// # CRYPTOGAMS licenses depending on where you obtain it. For further
// # details see http://www.openssl.org/~appro/cryptogams/.
// # ====================================================================
// Code for the perl script that generates the ppc64 assembler
// can be found in the cryptogams repository at the link below. It is based on
// the original from openssl.
// https://github.com/dot-asm/cryptogams/commit/a60f5b50ed908e91
// The differences in this and the original implementation are
// due to the calling conventions and initialization of constants.
//go:build gc && !purego
#include "textflag.h"
// Symbolic names for the general-purpose registers used by
// chaCha20_ctr32_vsx.
#define OUT R3 // destination pointer
#define INP R4 // source pointer
#define LEN R5 // remaining length in bytes
#define KEY R6 // pointer to the 8 key words
#define CNT R7 // pointer to the block counter (followed in memory by the nonce)
#define TMP R15
#define CONSTBASE R16 // base address used for loads from consts<>
#define BLOCKS R17 // LEN >> 6: number of 64-byte blocks, added to the counter on exit
// consts<> is the read-only constant pool for chaCha20_ctr32_vsx.
//
//   0x00-0x0f  the four ChaCha20 constant words ("expand 32-byte k"), loaded
//              as one vector
//   0x10-0x1f  the value 1
//   0x20-0x2f  the value 4
//   0x30-0x4f  two 16-byte byte-index patterns
//              NOTE(review): these look like byte-permute masks for 16-bit
//              and 8-bit word rotations; they and the 0x10/0x20 entries are
//              not referenced by the part of the routine reviewed here -
//              confirm whether they are still needed.
//   0x50-0x8f  each of the four constant words replicated across all four
//              lanes, loaded into V0..V3 at the top of the outer loop
//   0x90-0x9f  the words {0, 1, 2, 3}, added to the replicated block counter
//              so that each lane works on a different block
DATA consts<>+0x00(SB)/8, $0x3320646e61707865
DATA consts<>+0x08(SB)/8, $0x6b20657479622d32
DATA consts<>+0x10(SB)/8, $0x0000000000000001
DATA consts<>+0x18(SB)/8, $0x0000000000000000
DATA consts<>+0x20(SB)/8, $0x0000000000000004
DATA consts<>+0x28(SB)/8, $0x0000000000000000
DATA consts<>+0x30(SB)/8, $0x0a0b08090e0f0c0d
DATA consts<>+0x38(SB)/8, $0x0203000106070405
DATA consts<>+0x40(SB)/8, $0x090a0b080d0e0f0c
DATA consts<>+0x48(SB)/8, $0x0102030005060704
DATA consts<>+0x50(SB)/8, $0x6170786561707865
DATA consts<>+0x58(SB)/8, $0x6170786561707865
DATA consts<>+0x60(SB)/8, $0x3320646e3320646e
DATA consts<>+0x68(SB)/8, $0x3320646e3320646e
DATA consts<>+0x70(SB)/8, $0x79622d3279622d32
DATA consts<>+0x78(SB)/8, $0x79622d3279622d32
DATA consts<>+0x80(SB)/8, $0x6b2065746b206574
DATA consts<>+0x88(SB)/8, $0x6b2065746b206574
DATA consts<>+0x90(SB)/8, $0x0000000100000000
DATA consts<>+0x98(SB)/8, $0x0000000300000002
GLOBL consts<>(SB), RODATA, $0xa0
//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD out+0(FP), OUT
MOVD inp+8(FP), INP
MOVD len+16(FP), LEN
MOVD key+24(FP), KEY
MOVD counter+32(FP), CNT
// Addressing for constants
MOVD $consts<>+0x00(SB), CONSTBASE
MOVD $16, R8
MOVD $32, R9
MOVD $48, R10
MOVD $64, R11
SRD $6, LEN, BLOCKS
// V16
LXVW4X (CONSTBASE)(R0), VS48
ADD $80,CONSTBASE
// Load key into V17,V18
LXVW4X (KEY)(R0), VS49
LXVW4X (KEY)(R8), VS50
// Load CNT, NONCE into V19
LXVW4X (CNT)(R0), VS51
// Clear V27
VXOR V27, V27, V27
// V28
LXVW4X (CONSTBASE)(R11), VS60
// splat slot from V19 -> V26
VSPLTW $0, V19, V26
VSLDOI $4, V19, V27, V19
VSLDOI $12, V27, V19, V19
VADDUWM V26, V28, V26
MOVD $10, R14
MOVD R14, CTR
loop_outer_vsx:
// V0, V1, V2, V3
LXVW4X (R0)(CONSTBASE), VS32
LXVW4X (R8)(CONSTBASE), VS33
LXVW4X (R9)(CONSTBASE), VS34
LXVW4X (R10)(CONSTBASE), VS35
// splat values from V17, V18 into V4-V11
VSPLTW $0, V17, V4
VSPLTW $1, V17, V5
VSPLTW $2, V17, V6
VSPLTW $3, V17, V7
VSPLTW $0, V18, V8
VSPLTW $1, V18, V9
VSPLTW $2, V18, V10
VSPLTW $3, V18, V11
// VOR
VOR V26, V26, V12
// splat values from V19 -> V13, V14, V15
VSPLTW $1, V19, V13
VSPLTW $2, V19, V14
VSPLTW $3, V19, V15
// splat const values
VSPLTISW $-16, V27
VSPLTISW $12, V28
VSPLTISW $8, V29
VSPLTISW $7, V30
loop_vsx:
VADDUWM V0, V4, V0
VADDUWM V1, V5, V1
VADDUWM V2, V6, V2
VADDUWM V3, V7, V3
VXOR V12, V0, V12
VXOR V13, V1, V13
VXOR V14, V2, V14
VXOR V15, V3, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VRLW V15, V27, V15
VADDUWM V8, V12, V8
VADDUWM V9, V13, V9
VADDUWM V10, V14, V10
VADDUWM V11, V15, V11
VXOR V4, V8, V4
VXOR V5, V9, V5
VXOR V6, V10, V6
VXOR V7, V11, V7
VRLW V4, V28, V4
VRLW V5, V28, V5
VRLW V6, V28, V6
VRLW V7, V28, V7
VADDUWM V0, V4, V0
VADDUWM V1, V5, V1
VADDUWM V2, V6, V2
VADDUWM V3, V7, V3
VXOR V12, V0, V12
VXOR V13, V1, V13
VXOR V14, V2, V14
VXOR V15, V3, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VRLW V15, V29, V15
VADDUWM V8, V12, V8
VADDUWM V9, V13, V9
VADDUWM V10, V14, V10
VADDUWM V11, V15, V11
VXOR V4, V8, V4
VXOR V5, V9, V5
VXOR V6, V10, V6
VXOR V7, V11, V7
VRLW V4, V30, V4
VRLW V5, V30, V5
VRLW V6, V30, V6
VRLW V7, V30, V7
VADDUWM V0, V5, V0
VADDUWM V1, V6, V1
VADDUWM V2, V7, V2
VADDUWM V3, V4, V3
VXOR V15, V0, V15
VXOR V12, V1, V12
VXOR V13, V2, V13
VXOR V14, V3, V14
VRLW V15, V27, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VADDUWM V10, V15, V10
VADDUWM V11, V12, V11
VADDUWM V8, V13, V8
VADDUWM V9, V14, V9
VXOR V5, V10, V5
VXOR V6, V11, V6
VXOR V7, V8, V7
VXOR V4, V9, V4
VRLW V5, V28, V5
VRLW V6, V28, V6
VRLW V7, V28, V7
VRLW V4, V28, V4
VADDUWM V0, V5, V0
VADDUWM V1, V6, V1
VADDUWM V2, V7, V2
VADDUWM V3, V4, V3
VXOR V15, V0, V15
VXOR V12, V1, V12
VXOR V13, V2, V13
VXOR V14, V3, V14
VRLW V15, V29, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VADDUWM V10, V15, V10
VADDUWM V11, V12, V11
VADDUWM V8, V13, V8
VADDUWM V9, V14, V9
VXOR V5, V10, V5
VXOR V6, V11, V6
VXOR V7, V8, V7
VXOR V4, V9, V4
VRLW V5, V30, V5
VRLW V6, V30, V6
VRLW V7, V30, V7
VRLW V4, V30, V4
BC 16, LT, loop_vsx
VADDUWM V12, V26, V12
WORD $0x13600F8C // VMRGEW V0, V1, V27
WORD $0x13821F8C // VMRGEW V2, V3, V28
WORD $0x10000E8C // VMRGOW V0, V1, V0
WORD $0x10421E8C // VMRGOW V2, V3, V2
WORD $0x13A42F8C // VMRGEW V4, V5, V29
WORD $0x13C63F8C // VMRGEW V6, V7, V30
XXPERMDI VS32, VS34, $0, VS33
XXPERMDI VS32, VS34, $3, VS35
XXPERMDI VS59, VS60, $0, VS32
XXPERMDI VS59, VS60, $3, VS34
WORD $0x10842E8C // VMRGOW V4, V5, V4
WORD $0x10C63E8C // VMRGOW V6, V7, V6
WORD $0x13684F8C // VMRGEW V8, V9, V27
WORD $0x138A5F8C // VMRGEW V10, V11, V28
XXPERMDI VS36, VS38, $0, VS37
XXPERMDI VS36, VS38, $3, VS39
XXPERMDI VS61, VS62, $0, VS36
XXPERMDI VS61, VS62, $3, VS38
WORD $0x11084E8C // VMRGOW V8, V9, V8
WORD $0x114A5E8C // VMRGOW V10, V11, V10
WORD $0x13AC6F8C // VMRGEW V12, V13, V29
WORD $0x13CE7F8C // VMRGEW V14, V15, V30
XXPERMDI VS40, VS42, $0, VS41
XXPERMDI VS40, VS42, $3, VS43
XXPERMDI VS59, VS60, $0, VS40
XXPERMDI VS59, VS60, $3, VS42
WORD $0x118C6E8C // VMRGOW V12, V13, V12
WORD $0x11CE7E8C // VMRGOW V14, V15, V14
VSPLTISW $4, V27
VADDUWM V26, V27, V26
XXPERMDI VS44, VS46, $0, VS45
XXPERMDI VS44, VS46, $3, VS47
XXPERMDI VS61, VS62, $0, VS44
XXPERMDI VS61, VS62, $3, VS46
VADDUWM V0, V16, V0
VADDUWM V4, V17, V4
VADDUWM V8, V18, V8
VADDUWM V12, V19, V12
CMPU LEN, $64
BLT tail_vsx
// Bottom of loop
LXVW4X (INP)(R0), VS59
LXVW4X (INP)(R8), VS60
LXVW4X (INP)(R9), VS61
LXVW4X (INP)(R10), VS62
VXOR V27, V0, V27
VXOR V28, V4, V28
VXOR V29, V8, V29
VXOR V30, V12, V30
STXVW4X VS59, (OUT)(R0)
STXVW4X VS60, (OUT)(R8)
ADD $64, INP
STXVW4X VS61, (OUT)(R9)
ADD $-64, LEN
STXVW4X VS62, (OUT)(R10)
ADD $64, OUT
BEQ done_vsx
// Second 64-byte block of this pass: add V16-V19 to V1/V5/V9/V13 and place
// the sums in V0/V4/V8/V12, the registers the XOR below consumes.
VADDUWM V1, V16, V0
VADDUWM V5, V17, V4
VADDUWM V9, V18, V8
VADDUWM V13, V19, V12
// Fewer than 64 bytes remain: finish byte by byte in tail_vsx.
CMPU LEN, $64
BLT tail_vsx
// Load 64 bytes of input. R0/R8/R9/R10 are the index registers used by every
// load/store group in this loop (set up before this excerpt; presumably the
// offsets 0/16/32/48 - TODO confirm).
LXVW4X (INP)(R0), VS59
LXVW4X (INP)(R8), VS60
LXVW4X (INP)(R9), VS61
LXVW4X (INP)(R10), VS62
// XOR the input with the keystream.
VXOR V27, V0, V27
VXOR V28, V4, V28
VXOR V29, V8, V29
VXOR V30, V12, V30
// Store 64 bytes of output, interleaved with the pointer/length updates.
STXVW4X VS59, (OUT)(R0)
STXVW4X VS60, (OUT)(R8)
ADD $64, INP
STXVW4X VS61, (OUT)(R9)
ADD $-64, LEN
// The index must be the general-purpose register R10, mirroring the load
// from (INP)(R10) above; a vector register (V10) is not a valid index here.
STXVW4X VS62, (OUT)(R10)
ADD $64, OUT
// NOTE(review): BEQ appears to consume the result of the CMPU above (LEN was
// exactly 64, so nothing is left); no later compare is issued in this group.
BEQ done_vsx
VADDUWM V2, V16, V0
VADDUWM V6, V17, V4
VADDUWM V10, V18, V8
VADDUWM V14, V19, V12
CMPU LEN, $64
BLT tail_vsx
LXVW4X (INP)(R0), VS59
LXVW4X (INP)(R8), VS60
LXVW4X (INP)(R9), VS61
LXVW4X (INP)(R10), VS62
VXOR V27, V0, V27
VXOR V28, V4, V28
VXOR V29, V8, V29
VXOR V30, V12, V30
STXVW4X VS59, (OUT)(R0)
STXVW4X VS60, (OUT)(R8)
ADD $64, INP
STXVW4X VS61, (OUT)(R9)
ADD $-64, LEN
STXVW4X VS62, (OUT)(R10)
ADD $64, OUT
BEQ done_vsx
VADDUWM V3, V16, V0
VADDUWM V7, V17, V4
VADDUWM V11, V18, V8
VADDUWM V15, V19, V12
CMPU LEN, $64
BLT tail_vsx
LXVW4X (INP)(R0), VS59
LXVW4X (INP)(R8), VS60
LXVW4X (INP)(R9), VS61
LXVW4X (INP)(R10), VS62
VXOR V27, V0, V27
VXOR V28, V4, V28
VXOR V29, V8, V29
VXOR V30, V12, V30
STXVW4X VS59, (OUT)(R0)
STXVW4X VS60, (OUT)(R8)
ADD $64, INP
STXVW4X VS61, (OUT)(R9)
ADD $-64, LEN
STXVW4X VS62, (OUT)(R10)
ADD $64, OUT
MOVD $10, R14
MOVD R14, CTR
BNE loop_outer_vsx
done_vsx:
// Increment counter by number of 64 byte blocks
MOVD (CNT), R14
ADD BLOCKS, R14
MOVD R14, (CNT)
RET
tail_vsx:
ADD $32, R1, R11
MOVD LEN, CTR
// Save values on stack to copy from
STXVW4X VS32, (R11)(R0)
STXVW4X VS36, (R11)(R8)
STXVW4X VS40, (R11)(R9)
STXVW4X VS44, (R11)(R10)
ADD $-1, R11, R12
ADD $-1, INP
ADD $-1, OUT
looptail_vsx:
// Copying the result to OUT
// in bytes.
MOVBZU 1(R12), KEY
MOVBZU 1(INP), TMP
XOR KEY, TMP, KEY
MOVBU KEY, 1(OUT)
BC 16, LT, looptail_vsx
// Clear the stack values
STXVW4X VS48, (R11)(R0)
STXVW4X VS48, (R11)(R8)
STXVW4X VS48, (R11)(R9)
STXVW4X VS48, (R11)(R10)
BR done_vsx
================================================
FILE: gossh/crypto/chacha20/chacha_s390x.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package chacha20
import "gossh/sys/cpu"
// haveAsm is true when the cpu package reports that the s390x vector
// facility (VX) is available.
var haveAsm = cpu.S390X.HasVX
// bufSize is 256 bytes, the amount of keystream the s390x assembly routine
// produces per pass (four 64-byte blocks computed in parallel).
const bufSize = 256
// xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only
// be called when the vector facility is available. Implementation in
// chacha_s390x.s.
//
// The assembly consumes 256 bytes of src per pass and leaves its loop only
// when the remaining length reaches exactly zero, so len(src) is expected to
// be a positive multiple of 256. On return the advanced block counter has been
// written back through counter.
//
//go:noescape
func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32)
// xorKeyStreamBlocks XORs src with the keystream into dst. It uses the vector
// assembly routine when the CPU has the vector facility and the generic Go
// implementation otherwise.
func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) {
	if !cpu.S390X.HasVX {
		c.xorKeyStreamBlocksGeneric(dst, src)
		return
	}
	xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter)
}
================================================
FILE: gossh/crypto/chacha20/chacha_s390x.s
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "go_asm.h"
#include "textflag.h"
// This is an implementation of the ChaCha20 encryption algorithm as
// specified in RFC 7539. It uses vector instructions to compute
// 4 keystream blocks in parallel (256 bytes) which are then XORed
// with the bytes in the input slice.
// constants holds two 16-byte vectors that xorKeyStreamVX loads with a single
// VLM: the BSWAP permute mask at offset 0x00 and the J0 words at offset 0x10.
GLOBL ·constants<>(SB), RODATA|NOPTR, $32
// BSWAP: swap bytes in each 4-byte element
DATA ·constants<>+0x00(SB)/4, $0x03020100
DATA ·constants<>+0x04(SB)/4, $0x07060504
DATA ·constants<>+0x08(SB)/4, $0x0b0a0908
DATA ·constants<>+0x0c(SB)/4, $0x0f0e0d0c
// J0: [j0, j1, j2, j3]
// These four words are replicated into state words X0-X3 at the top of each
// 256-byte pass.
DATA ·constants<>+0x10(SB)/4, $0x61707865
DATA ·constants<>+0x14(SB)/4, $0x3320646e
DATA ·constants<>+0x18(SB)/4, $0x79622d32
DATA ·constants<>+0x1c(SB)/4, $0x6b206574
// Vector register assignments for xorKeyStreamVX.
// BSWAP and J0 are loaded from ·constants<>; KEY0/KEY1 hold the eight key words.
#define BSWAP V5
#define J0 V6
#define KEY0 V7
#define KEY1 V8
// NONCE holds the nonce; state words X13-X15 are replicated from its
// elements 1-3.
#define NONCE V9
// CTR holds the block counters of the four blocks computed in parallel.
#define CTR V10
// M0-M3 are scratch vectors: temporaries in SHUFFLE and the loaded input
// data in XORV.
#define M0 V11
#define M1 V12
#define M2 V13
#define M3 V14
// INC is the per-element increment applied to CTR.
#define INC V15
// X0-X15 are the sixteen ChaCha20 state words, one vector register per word.
#define X0 V16
#define X1 V17
#define X2 V18
#define X3 V19
#define X4 V20
#define X5 V21
#define X6 V22
#define X7 V23
#define X8 V24
#define X9 V25
#define X10 V26
#define X11 V27
#define X12 V28
#define X13 V29
#define X14 V30
#define X15 V31
// NUM_ROUNDS is the ChaCha round count; the round loop runs NUM_ROUNDS/2
// times with two ROUND4 invocations per pass.
#define NUM_ROUNDS 20
// ROUND4 applies the ChaCha quarter round to four register groups (a, b, c, d)
// at once, issuing each step for all four groups before moving to the next.
// Within a group (x0, x1, x2, x3) the sequence is:
//   x0 += x1; x2 ^= x0; x2 <<<= 16
//   x3 += x2; x1 ^= x3; x1 <<<= 12
//   x0 += x1; x2 ^= x0; x2 <<<= 8
//   x3 += x2; x1 ^= x3; x1 <<<= 7
// All operations act on 32-bit elements.
#define ROUND4(a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3) \
VAF a1, a0, a0 \
VAF b1, b0, b0 \
VAF c1, c0, c0 \
VAF d1, d0, d0 \
VX a0, a2, a2 \
VX b0, b2, b2 \
VX c0, c2, c2 \
VX d0, d2, d2 \
VERLLF $16, a2, a2 \
VERLLF $16, b2, b2 \
VERLLF $16, c2, c2 \
VERLLF $16, d2, d2 \
VAF a2, a3, a3 \
VAF b2, b3, b3 \
VAF c2, c3, c3 \
VAF d2, d3, d3 \
VX a3, a1, a1 \
VX b3, b1, b1 \
VX c3, c1, c1 \
VX d3, d1, d1 \
VERLLF $12, a1, a1 \
VERLLF $12, b1, b1 \
VERLLF $12, c1, c1 \
VERLLF $12, d1, d1 \
VAF a1, a0, a0 \
VAF b1, b0, b0 \
VAF c1, c0, c0 \
VAF d1, d0, d0 \
VX a0, a2, a2 \
VX b0, b2, b2 \
VX c0, c2, c2 \
VX d0, d2, d2 \
VERLLF $8, a2, a2 \
VERLLF $8, b2, b2 \
VERLLF $8, c2, c2 \
VERLLF $8, d2, d2 \
VAF a2, a3, a3 \
VAF b2, b3, b3 \
VAF c2, c3, c3 \
VAF d2, d3, d3 \
VX a3, a1, a1 \
VX b3, b1, b1 \
VX c3, c1, c1 \
VX d3, d1, d1 \
VERLLF $7, a1, a1 \
VERLLF $7, b1, b1 \
VERLLF $7, c1, c1 \
VERLLF $7, d1, d1
// PERMUTE applies the byte permutation in mask to each of v0-v3 in place.
#define PERMUTE(mask, v0, v1, v2, v3) \
VPERM v0, v0, mask, v0 \
VPERM v1, v1, mask, v1 \
VPERM v2, v2, mask, v2 \
VPERM v3, v3, mask, v3
// ADDV adds x to each of v0-v3 in place, as 32-bit elements.
#define ADDV(x, v0, v1, v2, v3) \
VAF x, v0, v0 \
VAF x, v1, v1 \
VAF x, v2, v2 \
VAF x, v3, v3
// XORV processes one 64-byte block: it loads 64 bytes from off(src) into
// M0-M3, byte-swaps the keystream vectors v0-v3 with the BSWAP mask, XORs
// them into M0-M3 and stores the 64 result bytes at off(dst).
// Side effects: v0-v3 are left byte-swapped and M0-M3 are overwritten.
#define XORV(off, dst, src, v0, v1, v2, v3) \
VLM off(src), M0, M3 \
PERMUTE(BSWAP, v0, v1, v2, v3) \
VX v0, M0, M0 \
VX v1, M1, M1 \
VX v2, M2, M2 \
VX v3, M3, M3 \
VSTM M0, M3, off(dst)
// SHUFFLE transposes the 4x4 matrix of 32-bit elements held in a, b, c, d,
// using t, u, v, w as temporaries (all four are overwritten). Afterwards
// a holds element 0 of each input, b element 1, c element 2 and d element 3.
#define SHUFFLE(a, b, c, d, t, u, v, w) \
VMRHF a, c, t \ // t = {a[0], c[0], a[1], c[1]}
VMRHF b, d, u \ // u = {b[0], d[0], b[1], d[1]}
VMRLF a, c, v \ // v = {a[2], c[2], a[3], c[3]}
VMRLF b, d, w \ // w = {b[2], d[2], b[3], d[3]}
VMRHF t, u, a \ // a = {a[0], b[0], c[0], d[0]}
VMRLF t, u, b \ // b = {a[1], b[1], c[1], d[1]}
VMRHF v, w, c \ // c = {a[2], b[2], c[2], d[2]}
VMRLF v, w, d // d = {a[3], b[3], c[3], d[3]}
// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32)
//
// Each pass of the chacha loop generates four keystream blocks (256 bytes),
// XORs them with 256 bytes of src and writes the result to dst. The loop only
// exits when the remaining length is exactly zero, so len(src) is expected to
// be a positive multiple of 256. The advanced counter is stored back through
// the counter pointer before returning.
TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0
MOVD $·constants<>(SB), R1
MOVD dst+0(FP), R2 // R2=&dst[0]
LMG src+24(FP), R3, R4 // R3=&src[0] R4=len(src)
MOVD key+48(FP), R5 // R5=key
MOVD nonce+56(FP), R6 // R6=nonce
MOVD counter+64(FP), R7 // R7=counter
// load BSWAP and J0
VLM (R1), BSWAP, J0
// setup
// Load both key vectors and the nonce, then shift NONCE using the amount
// placed in M0 so that the nonce words end up in elements 1-3 (the elements
// read by VREPF below).
// NOTE(review): R0 is the length operand of VLL; confirm its exact meaning
// against the instruction definition.
MOVD $95, R0
VLM (R5), KEY0, KEY1
VLL R0, (R6), NONCE
VZERO M0
VLEIB $7, $32, M0
VSRLB M0, NONCE, NONCE
// initialize counter values
// CTR = {counter, counter+1, counter+2, counter+3}; INC then becomes
// {4, 4, 4, 4} so each pass advances all four counters by four blocks.
VLREPF (R7), CTR
VZERO INC
VLEIF $1, $1, INC
VLEIF $2, $2, INC
VLEIF $3, $3, INC
VAF INC, CTR, CTR
VREPIF $4, INC
chacha:
// Build the initial state: every state word is replicated across all four
// elements, except X12 which takes the four distinct block counters.
VREPF $0, J0, X0
VREPF $1, J0, X1
VREPF $2, J0, X2
VREPF $3, J0, X3
VREPF $0, KEY0, X4
VREPF $1, KEY0, X5
VREPF $2, KEY0, X6
VREPF $3, KEY0, X7
VREPF $0, KEY1, X8
VREPF $1, KEY1, X9
VREPF $2, KEY1, X10
VREPF $3, KEY1, X11
VLR CTR, X12
VREPF $1, NONCE, X13
VREPF $2, NONCE, X14
VREPF $3, NONCE, X15
// R1 counts the remaining double rounds.
MOVD $(NUM_ROUNDS/2), R1
loop:
// First ROUND4: columns (X0,X4,X8,X12) ... (X3,X7,X11,X15).
// Second ROUND4: diagonals (X0,X5,X10,X15) ... (X3,X4,X9,X14).
ROUND4(X0, X4, X12, X8, X1, X5, X13, X9, X2, X6, X14, X10, X3, X7, X15, X11)
ROUND4(X0, X5, X15, X10, X1, X6, X12, X11, X2, X7, X13, X8, X3, X4, X14, X9)
ADD $-1, R1
BNE loop
// decrement length
ADD $-256, R4
// rearrange vectors
// Each SHUFFLE transposes four state words so that every vector holds four
// consecutive words of a single block; ADDV then adds the matching input
// words. X12 receives the per-block counters before its transpose because
// they differ between blocks.
SHUFFLE(X0, X1, X2, X3, M0, M1, M2, M3)
ADDV(J0, X0, X1, X2, X3)
SHUFFLE(X4, X5, X6, X7, M0, M1, M2, M3)
ADDV(KEY0, X4, X5, X6, X7)
SHUFFLE(X8, X9, X10, X11, M0, M1, M2, M3)
ADDV(KEY1, X8, X9, X10, X11)
VAF CTR, X12, X12
SHUFFLE(X12, X13, X14, X15, M0, M1, M2, M3)
ADDV(NONCE, X12, X13, X14, X15)
// increment counters
VAF INC, CTR, CTR
// xor keystream with plaintext
XORV(0*64, R2, R3, X0, X4, X8, X12)
XORV(1*64, R2, R3, X1, X5, X9, X13)
XORV(2*64, R2, R3, X2, X6, X10, X14)
XORV(3*64, R2, R3, X3, X7, X11, X15)
// increment pointers
MOVD $256(R2), R2
MOVD $256(R3), R3
// Repeat until the remaining length is zero.
CMPBNE R4, $0, chacha
// Write element 0 of CTR (the next unused block counter) back to *counter.
VSTEF $0, CTR, (R7)
RET
================================================
FILE: gossh/crypto/chacha20/xor.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chacha20
import "runtime"
// unaligned is true on the platforms that have fast unaligned 32-bit little
// endian accesses.
const unaligned = runtime.GOARCH == "386" ||
	runtime.GOARCH == "amd64" ||
	runtime.GOARCH == "arm64" ||
	runtime.GOARCH == "ppc64le" ||
	runtime.GOARCH == "s390x"

// addXor reads a little endian uint32 from src, XORs it with (a + b) and
// places the result in little endian byte order in dst. Both slices must hold
// at least four bytes.
func addXor(dst, src []byte, a, b uint32) {
	_, _ = src[3], dst[3] // bounds check elimination hint

	sum := a + b
	if !unaligned {
		// Generic path: XOR one byte at a time.
		dst[0] = src[0] ^ byte(sum)
		dst[1] = src[1] ^ byte(sum>>8)
		dst[2] = src[2] ^ byte(sum>>16)
		dst[3] = src[3] ^ byte(sum>>24)
		return
	}

	// The compiler should optimize this code into
	// 32-bit unaligned little endian loads and stores.
	// TODO: delete once the compiler does a reliably
	// good job with the generic code above.
	// See issue #25111 for more details.
	word := uint32(src[0]) |
		uint32(src[1])<<8 |
		uint32(src[2])<<16 |
		uint32(src[3])<<24
	word ^= sum
	dst[0] = byte(word)
	dst[1] = byte(word >> 8)
	dst[2] = byte(word >> 16)
	dst[3] = byte(word >> 24)
}
================================================
FILE: gossh/crypto/chacha20poly1305/chacha20poly1305.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD and its
// extended nonce variant XChaCha20-Poly1305, as specified in RFC 8439 and
// draft-irtf-cfrg-xchacha-01.
package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305"
import (
"crypto/cipher"
"errors"
)
// Sizes, in bytes, of the inputs and overhead of the AEADs in this package.
const (
// KeySize is the size of the key used by this AEAD, in bytes.
KeySize = 32
// NonceSize is the size of the nonce used with the standard variant of this
// AEAD, in bytes.
//
// Note that this is too short to be safely generated at random if the same
// key is reused more than 2³² times.
NonceSize = 12
// NonceSizeX is the size of the nonce used with the XChaCha20-Poly1305
// variant of this AEAD, in bytes.
NonceSizeX = 24
// Overhead is the size of the Poly1305 authentication tag, in bytes, and
// the difference between the length of a ciphertext and the length of its
// plaintext.
Overhead = 16
)
// chacha20poly1305 is the value New returns as a cipher.AEAD. It holds a
// private copy of the KeySize-byte key.
type chacha20poly1305 struct {
key [KeySize]byte
}
// New returns a ChaCha20-Poly1305 AEAD that uses the given 256-bit key.
//
// The key bytes are copied, so the caller may reuse or overwrite key after New
// returns. An error is returned when len(key) is not KeySize.
func New(key []byte) (cipher.AEAD, error) {
	if len(key) != KeySize {
		return nil, errors.New("chacha20poly1305: bad key length")
	}
	aead := &chacha20poly1305{}
	copy(aead.key[:], key)
	return aead, nil
}
// NonceSize returns the nonce length, in bytes, that Seal and Open require
// (the NonceSize constant).
func (c *chacha20poly1305) NonceSize() int {
return NonceSize
}
// Overhead returns the Overhead constant: the number of bytes by which a
// ciphertext is longer than its plaintext.
func (c *chacha20poly1305) Overhead() int {
return Overhead
}
// Seal validates its arguments and then delegates to the platform-specific
// seal implementation, returning its result.
//
// It panics if len(nonce) != NonceSize or if the plaintext is longer than
// 2^38 - 64 bytes.
func (c *chacha20poly1305) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
	const maxPlaintext = (1 << 38) - 64

	switch {
	case len(nonce) != NonceSize:
		panic("chacha20poly1305: bad nonce length passed to Seal")
	case uint64(len(plaintext)) > maxPlaintext:
		panic("chacha20poly1305: plaintext too large")
	}
	return c.seal(dst, nonce, plaintext, additionalData)
}
// errOpen is the single error reported for any authentication failure.
var errOpen = errors.New("chacha20poly1305: message authentication failed")

// Open validates its arguments and then delegates to the platform-specific
// open implementation, returning its result.
//
// It panics if len(nonce) != NonceSize or if the ciphertext is longer than
// 2^38 - 48 bytes. A ciphertext shorter than 16 bytes cannot contain a tag and
// yields errOpen.
func (c *chacha20poly1305) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	const maxCiphertext = (1 << 38) - 48

	if len(nonce) != NonceSize {
		panic("chacha20poly1305: bad nonce length passed to Open")
	}
	switch n := len(ciphertext); {
	case n < 16:
		return nil, errOpen
	case uint64(n) > maxCiphertext:
		panic("chacha20poly1305: ciphertext too large")
	}
	return c.open(dst, nonce, ciphertext, additionalData)
}
// sliceForAppend extends in by n bytes. head is in followed by n further
// bytes; tail aliases head and covers only those n extra bytes. When in
// already has enough capacity it is resliced in place and nothing is
// allocated; otherwise a new buffer is allocated and in is copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	oldLen := len(in)
	newLen := oldLen + n
	if newLen <= cap(in) {
		head = in[:newLen]
	} else {
		head = make([]byte, newLen)
		copy(head, in)
	}
	return head, head[oldLen:]
}
================================================
FILE: gossh/crypto/chacha20poly1305/chacha20poly1305_amd64.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package chacha20poly1305
import (
"encoding/binary"
"gossh/crypto/internal/alias"
"gossh/sys/cpu"
)
// chacha20Poly1305Open is implemented in chacha20poly1305_amd64.s. key is the
// 16-word ChaCha20 state produced by setupState. src is the ciphertext without
// its tag; the assembly compares the computed tag against the bytes that
// follow src in memory, so callers must pass a slice of a buffer that still
// contains the tag. It reports whether the tags are equal.
//
//go:noescape
func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
// chacha20Poly1305Seal is implemented in assembly. key is the 16-word
// ChaCha20 state produced by setupState; the caller sizes dst to
// len(src)+16 bytes (presumably ciphertext followed by the 16-byte tag -
// the routine itself is not shown here).
//
//go:noescape
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
// useAVX2 is set when the CPU supports both AVX2 and BMI2. It is read by the
// assembly to select the AVX2 code path.
var (
useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2
)
// setupState writes a ChaCha20 input matrix to state. See
// https://tools.ietf.org/html/rfc7539#section-2.3.
//
// Layout: words 0-3 are the fixed constants, words 4-11 the key, word 12 the
// block counter (zero) and words 13-15 the nonce. Key and nonce bytes are read
// as little-endian 32-bit words; nonce must provide at least 12 bytes.
func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
	state[0], state[1], state[2], state[3] = 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574

	for i := 0; i < 8; i++ {
		state[4+i] = binary.LittleEndian.Uint32(key[4*i : 4*i+4])
	}

	state[12] = 0
	for i := 0; i < 3; i++ {
		state[13+i] = binary.LittleEndian.Uint32(nonce[4*i : 4*i+4])
	}
}
// seal appends the sealed message to dst and returns the extended slice. It
// uses the assembly routine when SSSE3 is available and the generic
// implementation otherwise. It panics if the output region overlaps plaintext
// inexactly.
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
	if cpu.X86.HasSSSE3 {
		var st [16]uint32
		setupState(&st, &c.key, nonce)

		// Reserve room for the ciphertext plus the 16 extra bytes.
		ret, out := sliceForAppend(dst, len(plaintext)+16)
		if alias.InexactOverlap(out, plaintext) {
			panic("chacha20poly1305: invalid buffer overlap")
		}
		chacha20Poly1305Seal(out, st[:], plaintext, additionalData)
		return ret
	}
	return c.sealGeneric(dst, nonce, plaintext, additionalData)
}
// open authenticates and decrypts ciphertext, appending the plaintext to dst.
// It uses the assembly routine when SSSE3 is available and the generic
// implementation otherwise. On authentication failure the output region is
// zeroed and errOpen is returned. It panics if the output region overlaps the
// ciphertext inexactly.
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	if cpu.X86.HasSSSE3 {
		var st [16]uint32
		setupState(&st, &c.key, nonce)

		// Drop the trailing 16 bytes from the slice; the assembly still reads
		// them from memory directly after the body.
		body := ciphertext[:len(ciphertext)-16]
		ret, out := sliceForAppend(dst, len(body))
		if alias.InexactOverlap(out, body) {
			panic("chacha20poly1305: invalid buffer overlap")
		}

		if chacha20Poly1305Open(out, st[:], body, additionalData) {
			return ret, nil
		}
		// Do not leak unauthenticated plaintext.
		for i := range out {
			out[i] = 0
		}
		return nil, errOpen
	}
	return c.openGeneric(dst, nonce, ciphertext, additionalData)
}
================================================
FILE: gossh/crypto/chacha20poly1305/chacha20poly1305_amd64.s
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file was originally from https://golang.org/cl/24717 by Vlad Krasnov of CloudFlare.
//go:build gc && !purego
#include "textflag.h"
// General register allocation
// oup, inp and inl are the output pointer, the input pointer and the
// remaining input length.
// Note the deliberate sharing: adp and itr1 are both CX, and keyp and t3 are
// both R8, so each pair cannot be live at the same time.
#define oup DI
#define inp SI
#define inl BX
#define adp CX // free to reuse, after we hash the additional data
#define keyp R8 // free to reuse, when we copy the key to stack
#define itr2 R9 // general iterator
#define itr1 CX // general iterator
// acc0-acc2 hold the Poly1305 accumulator (acc0 is the least significant word).
#define acc0 R10
#define acc1 R11
#define acc2 R12
// t0-t3 are temporaries used by the polyMul stages.
#define t0 R13
#define t1 R14
#define t2 R15
#define t3 R8
// Register and stack allocation for the SSE code
// All stack slots are 16 bytes wide and addressed relative to BP.
// rStore and sStore hold the two halves of the Poly1305 key; the polyMul
// stages read r directly as (0*8)(BP) and (1*8)(BP).
#define rStore (0*16)(BP)
#define sStore (1*16)(BP)
#define state1Store (2*16)(BP)
#define state2Store (3*16)(BP)
#define tmpStore (4*16)(BP)
#define ctr0Store (5*16)(BP)
#define ctr1Store (6*16)(BP)
#define ctr2Store (7*16)(BP)
#define ctr3Store (8*16)(BP)
// An, Bn, Cn, Dn are the four rows of ChaCha20 block n.
#define A0 X0
#define A1 X1
#define A2 X2
#define B0 X3
#define B1 X4
#define B2 X5
#define C0 X6
#define C1 X7
#define C2 X8
#define D0 X9
#define D1 X10
#define D2 X11
#define T0 X12
#define T1 X13
#define T2 X14
#define T3 X15
// The fourth block reuses the temporaries T0-T3, so code that works on four
// blocks must spill a register to tmpStore whenever it needs a scratch.
#define A3 T0
#define B3 T1
#define C3 T2
#define D3 T3
// Register and stack allocation for the AVX2 code
// All stack slots are 32 bytes wide and addressed relative to BP.
#define rsStoreAVX2 (0*32)(BP)
#define state1StoreAVX2 (1*32)(BP)
#define state2StoreAVX2 (2*32)(BP)
#define ctr0StoreAVX2 (3*32)(BP)
#define ctr1StoreAVX2 (4*32)(BP)
#define ctr2StoreAVX2 (5*32)(BP)
#define ctr3StoreAVX2 (6*32)(BP)
#define tmpStoreAVX2 (7*32)(BP) // 256 bytes on stack
// AAn, BBn, CCn, DDn are the 256-bit row registers of register set n.
#define AA0 Y0
#define AA1 Y5
#define AA2 Y6
#define AA3 Y7
#define BB0 Y14
#define BB1 Y9
#define BB2 Y10
#define BB3 Y11
#define CC0 Y12
#define CC1 Y13
#define CC2 Y8
#define CC3 Y15
#define DD0 Y4
#define DD1 Y1
#define DD2 Y2
#define DD3 Y3
// The TTn temporaries alias the registers of set 3.
#define TT0 DD3
#define TT1 AA3
#define TT2 BB3
#define TT3 CC3
// ChaCha20 constants
// (the four constant words, stored twice to fill a 32-byte vector)
DATA ·chacha20Constants<>+0x00(SB)/4, $0x61707865
DATA ·chacha20Constants<>+0x04(SB)/4, $0x3320646e
DATA ·chacha20Constants<>+0x08(SB)/4, $0x79622d32
DATA ·chacha20Constants<>+0x0c(SB)/4, $0x6b206574
DATA ·chacha20Constants<>+0x10(SB)/4, $0x61707865
DATA ·chacha20Constants<>+0x14(SB)/4, $0x3320646e
DATA ·chacha20Constants<>+0x18(SB)/4, $0x79622d32
DATA ·chacha20Constants<>+0x1c(SB)/4, $0x6b206574
// <<< 16 with PSHUFB
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
DATA ·rol16<>+0x10(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
// <<< 8 with PSHUFB
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
DATA ·rol8<>+0x10(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
// avx2InitMask: zero in the low 128 bits, one in the high 128 bits.
DATA ·avx2InitMask<>+0x00(SB)/8, $0x0
DATA ·avx2InitMask<>+0x08(SB)/8, $0x0
DATA ·avx2InitMask<>+0x10(SB)/8, $0x1
DATA ·avx2InitMask<>+0x18(SB)/8, $0x0
// avx2IncMask: two in each 128-bit half.
DATA ·avx2IncMask<>+0x00(SB)/8, $0x2
DATA ·avx2IncMask<>+0x08(SB)/8, $0x0
DATA ·avx2IncMask<>+0x10(SB)/8, $0x2
DATA ·avx2IncMask<>+0x18(SB)/8, $0x0
// Poly1305 key clamp
// (the first 16 bytes are the clamp mask; the second 16 bytes are all ones)
DATA ·polyClampMask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
DATA ·polyClampMask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
DATA ·polyClampMask<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFF
DATA ·polyClampMask<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF
// sseIncMask: adds one to the lowest 32-bit word of a row (the block counter).
DATA ·sseIncMask<>+0x00(SB)/8, $0x1
DATA ·sseIncMask<>+0x08(SB)/8, $0x0
// To load/store the last < 16 bytes in a buffer
// Entry k (16 bytes at offset 16*k) keeps the low k+1 bytes; the table has
// entries for 1 to 15 bytes.
DATA ·andMask<>+0x00(SB)/8, $0x00000000000000ff
DATA ·andMask<>+0x08(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x10(SB)/8, $0x000000000000ffff
DATA ·andMask<>+0x18(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x20(SB)/8, $0x0000000000ffffff
DATA ·andMask<>+0x28(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x30(SB)/8, $0x00000000ffffffff
DATA ·andMask<>+0x38(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x40(SB)/8, $0x000000ffffffffff
DATA ·andMask<>+0x48(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x50(SB)/8, $0x0000ffffffffffff
DATA ·andMask<>+0x58(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x60(SB)/8, $0x00ffffffffffffff
DATA ·andMask<>+0x68(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x70(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0x78(SB)/8, $0x0000000000000000
DATA ·andMask<>+0x80(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0x88(SB)/8, $0x00000000000000ff
DATA ·andMask<>+0x90(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0x98(SB)/8, $0x000000000000ffff
DATA ·andMask<>+0xa0(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0xa8(SB)/8, $0x0000000000ffffff
DATA ·andMask<>+0xb0(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0xb8(SB)/8, $0x00000000ffffffff
DATA ·andMask<>+0xc0(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0xc8(SB)/8, $0x000000ffffffffff
DATA ·andMask<>+0xd0(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0xd8(SB)/8, $0x0000ffffffffffff
DATA ·andMask<>+0xe0(SB)/8, $0xffffffffffffffff
DATA ·andMask<>+0xe8(SB)/8, $0x00ffffffffffffff
// Symbol declarations: all tables are read-only and contain no pointers.
GLOBL ·chacha20Constants<>(SB), (NOPTR+RODATA), $32
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $32
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $32
GLOBL ·sseIncMask<>(SB), (NOPTR+RODATA), $16
GLOBL ·avx2IncMask<>(SB), (NOPTR+RODATA), $32
GLOBL ·avx2InitMask<>(SB), (NOPTR+RODATA), $32
GLOBL ·polyClampMask<>(SB), (NOPTR+RODATA), $32
GLOBL ·andMask<>(SB), (NOPTR+RODATA), $240
// No PALIGNR in Go ASM yet (but VPALIGNR is present).
// Each macro is a hand-encoded PALIGNR of a register with itself, i.e. a
// rotation of the 128-bit row by whole 32-bit words.
// The "Left" set rotates the B row by 4 bytes, the C row by 8 and the D row
// by 12; it is used between the two chachaQR calls of a double round.
#define shiftB0Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x04 // PALIGNR $4, X3, X3
#define shiftB1Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xe4; BYTE $0x04 // PALIGNR $4, X4, X4
#define shiftB2Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x04 // PALIGNR $4, X5, X5
#define shiftB3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x04 // PALIGNR $4, X13, X13
#define shiftC0Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xf6; BYTE $0x08 // PALIGNR $8, X6, X6
#define shiftC1Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x08 // PALIGNR $8, X7, X7
#define shiftC2Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc0; BYTE $0x08 // PALIGNR $8, X8, X8
#define shiftC3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xf6; BYTE $0x08 // PALIGNR $8, X14, X14
#define shiftD0Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc9; BYTE $0x0c // PALIGNR $12, X9, X9
#define shiftD1Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xd2; BYTE $0x0c // PALIGNR $12, X10, X10
#define shiftD2Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x0c // PALIGNR $12, X11, X11
#define shiftD3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x0c // PALIGNR $12, X15, X15
// The "Right" set undoes the "Left" set: B by 12 bytes, C by 8, D by 4.
#define shiftB0Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x0c // PALIGNR $12, X3, X3
#define shiftB1Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xe4; BYTE $0x0c // PALIGNR $12, X4, X4
#define shiftB2Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x0c // PALIGNR $12, X5, X5
#define shiftB3Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x0c // PALIGNR $12, X13, X13
// Rotating by 8 bytes is its own inverse, so the C macros are shared.
#define shiftC0Right shiftC0Left
#define shiftC1Right shiftC1Left
#define shiftC2Right shiftC2Left
#define shiftC3Right shiftC3Left
#define shiftD0Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc9; BYTE $0x04 // PALIGNR $4, X9, X9
#define shiftD1Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xd2; BYTE $0x04 // PALIGNR $4, X10, X10
#define shiftD2Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x04 // PALIGNR $4, X11, X11
#define shiftD3Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x04 // PALIGNR $4, X15, X15
// Some macros
// ROL rotates the uint32s in register R left by N bits, using temporary T.
#define ROL(N, R, T) \
MOVO R, T; PSLLL $(N), T; PSRLL $(32-(N)), R; PXOR T, R
// ROL16 rotates the uint32s in register R left by 16, using temporary T if needed.
// With GOAMD64_v2 defined a single PSHUFB with the rol16 table is used and T
// is left untouched.
#ifdef GOAMD64_v2
#define ROL16(R, T) PSHUFB ·rol16<>(SB), R
#else
#define ROL16(R, T) ROL(16, R, T)
#endif
// ROL8 rotates the uint32s in register R left by 8, using temporary T if needed.
// With GOAMD64_v2 defined a single PSHUFB with the rol8 table is used and T
// is left untouched.
#ifdef GOAMD64_v2
#define ROL8(R, T) PSHUFB ·rol8<>(SB), R
#else
#define ROL8(R, T) ROL(8, R, T)
#endif
// chachaQR performs one ChaCha quarter round on the rows A, B, C, D of a
// block, using T as a scratch register:
//   A += B; D ^= A; D <<<= 16
//   C += D; B ^= C; B <<<= 12
//   A += B; D ^= A; D <<<= 8
//   C += D; B ^= C; B <<<= 7
#define chachaQR(A, B, C, D, T) \
PADDD B, A; PXOR A, D; ROL16(D, T) \
PADDD D, C; PXOR C, B; MOVO B, T; PSLLL $12, T; PSRLL $20, B; PXOR T, B \
PADDD B, A; PXOR A, D; ROL8(D, T) \
PADDD D, C; PXOR C, B; MOVO B, T; PSLLL $7, T; PSRLL $25, B; PXOR T, B
// chachaQR_AVX2 is the same quarter round on 256-bit registers; the 16- and
// 8-bit rotations always use VPSHUFB with the rol16/rol8 tables.
#define chachaQR_AVX2(A, B, C, D, T) \
VPADDD B, A, A; VPXOR A, D, D; VPSHUFB ·rol16<>(SB), D, D \
VPADDD D, C, C; VPXOR C, B, B; VPSLLD $12, B, T; VPSRLD $20, B, B; VPXOR T, B, B \
VPADDD B, A, A; VPXOR A, D, D; VPSHUFB ·rol8<>(SB), D, D \
VPADDD D, C, C; VPXOR C, B, B; VPSLLD $7, B, T; VPSRLD $25, B, B; VPXOR T, B, B
// polyAdd adds the 16-byte block at S to the accumulator acc2:acc1:acc0 and
// adds 1 to acc2 (plus carry), i.e. the 2^128 padding bit of a full block.
#define polyAdd(S) ADDQ S, acc0; ADCQ 8+S, acc1; ADCQ $1, acc2
// polyMulStage1-3 multiply the accumulator by the two 64-bit words of r read
// from (0*8)(BP) and (1*8)(BP), leaving the product in t3:t2:t1:t0.
// They clobber AX and DX; stage 2 also overwrites acc0.
#define polyMulStage1 MOVQ (0*8)(BP), AX; MOVQ AX, t2; MULQ acc0; MOVQ AX, t0; MOVQ DX, t1; MOVQ (0*8)(BP), AX; MULQ acc1; IMULQ acc2, t2; ADDQ AX, t1; ADCQ DX, t2
#define polyMulStage2 MOVQ (1*8)(BP), AX; MOVQ AX, t3; MULQ acc0; ADDQ AX, t1; ADCQ $0, DX; MOVQ DX, acc0; MOVQ (1*8)(BP), AX; MULQ acc1; ADDQ AX, t2; ADCQ $0, DX
#define polyMulStage3 IMULQ acc2, t3; ADDQ acc0, t2; ADCQ DX, t3
// polyMulReduceStage folds the product back into acc2:acc1:acc0: the low
// 130 bits are kept and the bits above them are added back in twice, once
// as-is with the low two bits of t2 cleared (x*4) and once shifted right by
// two (x), which together multiply the overflow by 5.
#define polyMulReduceStage MOVQ t0, acc0; MOVQ t1, acc1; MOVQ t2, acc2; ANDQ $3, acc2; MOVQ t2, t0; ANDQ $-4, t0; MOVQ t3, t1; SHRQ $2, t3, t2; SHRQ $2, t3; ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $0, acc2; ADDQ t2, acc0; ADCQ t3, acc1; ADCQ $0, acc2
// The _AVX2 stages compute the same product using MULXQ.
#define polyMulStage1_AVX2 MOVQ (0*8)(BP), DX; MOVQ DX, t2; MULXQ acc0, t0, t1; IMULQ acc2, t2; MULXQ acc1, AX, DX; ADDQ AX, t1; ADCQ DX, t2
#define polyMulStage2_AVX2 MOVQ (1*8)(BP), DX; MULXQ acc0, acc0, AX; ADDQ acc0, t1; MULXQ acc1, acc1, t3; ADCQ acc1, t2; ADCQ $0, t3
#define polyMulStage3_AVX2 IMULQ acc2, DX; ADDQ AX, t2; ADCQ DX, t3
// polyMul and polyMulAVX2 run the full multiply followed by the reduction.
#define polyMul polyMulStage1; polyMulStage2; polyMulStage3; polyMulReduceStage
#define polyMulAVX2 polyMulStage1_AVX2; polyMulStage2_AVX2; polyMulStage3_AVX2; polyMulReduceStage
// ----------------------------------------------------------------------------
// polyHashADInternal absorbs the additional data into a freshly zeroed
// Poly1305 accumulator.
//
// Inputs:  adp = pointer to the additional data, itr2 = its length in bytes,
//          BP  = base of the stack area whose first 16 bytes hold r.
// Outputs: acc0, acc1, acc2 = accumulator after hashing the additional data.
// Clobbers adp, itr2, t0-t3, AX and DX. Since t3 is R8, keyp does not survive
// this call.
TEXT polyHashADInternal<>(SB), NOSPLIT, $0
// adp points to beginning of additional data
// itr2 holds ad length
XORQ acc0, acc0
XORQ acc1, acc1
XORQ acc2, acc2
CMPQ itr2, $13
JNE hashADLoop
openFastTLSAD:
// Special treatment for the TLS case of 13 bytes
// Bytes 0-7 go to acc0. The second load covers bytes 5-12; shifting it right
// by 24 bits drops bytes 5-7 and leaves bytes 8-12 in acc1. acc2 = 1 supplies
// the padding bit, exactly as polyAdd would.
MOVQ (adp), acc0
MOVQ 5(adp), acc1
SHRQ $24, acc1
MOVQ $1, acc2
polyMul
RET
hashADLoop:
// Hash in 16 byte chunks
CMPQ itr2, $16
JB hashADTail
polyAdd(0(adp))
LEAQ (1*16)(adp), adp
SUBQ $16, itr2
polyMul
JMP hashADLoop
hashADTail:
CMPQ itr2, $0
JE hashADDone
// Hash last < 16 byte tail
// The tail is assembled in t1:t0 by walking backwards from the end of the
// data, so no byte beyond the buffer is read.
XORQ t0, t0
XORQ t1, t1
XORQ t2, t2
ADDQ itr2, adp
hashADTailLoop:
// Shift t1:t0 left by one byte and insert the next byte at the bottom.
SHLQ $8, t0, t1
SHLQ $8, t0
MOVB -1(adp), t2
XORQ t2, t0
DECQ adp
DECQ itr2
JNE hashADTailLoop
hashADTailFinish:
// Same operation as polyAdd, with the zero-padded tail as the block.
ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2
polyMul
// Finished AD
hashADDone:
RET
// ----------------------------------------------------------------------------
// func chacha20Poly1305Open(dst, key, src, ad []byte) bool
TEXT ·chacha20Poly1305Open(SB), 0, $288-97
// For aligned stack access
MOVQ SP, BP
ADDQ $32, BP
ANDQ $-32, BP
MOVQ dst+0(FP), oup
MOVQ key+24(FP), keyp
MOVQ src+48(FP), inp
MOVQ src_len+56(FP), inl
MOVQ ad+72(FP), adp
// Check for AVX2 support
CMPB ·useAVX2(SB), $1
JE chacha20Poly1305Open_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
JBE openSSE128 // About 16% faster
// For long buffers, prepare the poly key first
MOVOU ·chacha20Constants<>(SB), A0
MOVOU (1*16)(keyp), B0
MOVOU (2*16)(keyp), C0
MOVOU (3*16)(keyp), D0
MOVO D0, T1
// Store state on stack for future use
MOVO B0, state1Store
MOVO C0, state2Store
MOVO D0, ctr3Store
MOVQ $10, itr2
openSSEPreparePolyKey:
chachaQR(A0, B0, C0, D0, T0)
shiftB0Left; shiftC0Left; shiftD0Left
chachaQR(A0, B0, C0, D0, T0)
shiftB0Right; shiftC0Right; shiftD0Right
DECQ itr2
JNE openSSEPreparePolyKey
// A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded
PADDL ·chacha20Constants<>(SB), A0; PADDL state1Store, B0
// Clamp and store the key
PAND ·polyClampMask<>(SB), A0
MOVO A0, rStore; MOVO B0, sStore
// Hash AAD
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
openSSEMainLoop:
CMPQ inl, $256
JB openSSEMainLoopDone
// Load state, increment counter blocks
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3
// Store counters
MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store
// There are 10 ChaCha20 iterations of 2QR each, so for 6 iterations we hash 2 blocks, and for the remaining 4 only 1 block - for a total of 16
MOVQ $4, itr1
MOVQ inp, itr2
openSSEInternalLoop:
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
polyAdd(0(itr2))
shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left
shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left
shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left
polyMulStage1
polyMulStage2
LEAQ (2*8)(itr2), itr2
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
polyMulStage3
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
polyMulReduceStage
shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right
shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right
shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right
DECQ itr1
JGE openSSEInternalLoop
polyAdd(0(itr2))
polyMul
LEAQ (2*8)(itr2), itr2
CMPQ itr1, $-6
JG openSSEInternalLoop
// Add in the state
PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3
PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3
PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3
PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3
// Load - xor - store
MOVO D3, tmpStore
MOVOU (0*16)(inp), D3; PXOR D3, A0; MOVOU A0, (0*16)(oup)
MOVOU (1*16)(inp), D3; PXOR D3, B0; MOVOU B0, (1*16)(oup)
MOVOU (2*16)(inp), D3; PXOR D3, C0; MOVOU C0, (2*16)(oup)
MOVOU (3*16)(inp), D3; PXOR D3, D0; MOVOU D0, (3*16)(oup)
MOVOU (4*16)(inp), D0; PXOR D0, A1; MOVOU A1, (4*16)(oup)
MOVOU (5*16)(inp), D0; PXOR D0, B1; MOVOU B1, (5*16)(oup)
MOVOU (6*16)(inp), D0; PXOR D0, C1; MOVOU C1, (6*16)(oup)
MOVOU (7*16)(inp), D0; PXOR D0, D1; MOVOU D1, (7*16)(oup)
MOVOU (8*16)(inp), D0; PXOR D0, A2; MOVOU A2, (8*16)(oup)
MOVOU (9*16)(inp), D0; PXOR D0, B2; MOVOU B2, (9*16)(oup)
MOVOU (10*16)(inp), D0; PXOR D0, C2; MOVOU C2, (10*16)(oup)
MOVOU (11*16)(inp), D0; PXOR D0, D2; MOVOU D2, (11*16)(oup)
MOVOU (12*16)(inp), D0; PXOR D0, A3; MOVOU A3, (12*16)(oup)
MOVOU (13*16)(inp), D0; PXOR D0, B3; MOVOU B3, (13*16)(oup)
MOVOU (14*16)(inp), D0; PXOR D0, C3; MOVOU C3, (14*16)(oup)
MOVOU (15*16)(inp), D0; PXOR tmpStore, D0; MOVOU D0, (15*16)(oup)
LEAQ 256(inp), inp
LEAQ 256(oup), oup
SUBQ $256, inl
JMP openSSEMainLoop
openSSEMainLoopDone:
// Handle the various tail sizes efficiently
TESTQ inl, inl
JE openSSEFinalize
CMPQ inl, $64
JBE openSSETail64
CMPQ inl, $128
JBE openSSETail128
CMPQ inl, $192
JBE openSSETail192
JMP openSSETail256
openSSEFinalize:
// Hash in the PT, AAD lengths
ADDQ ad_len+80(FP), acc0; ADCQ src_len+56(FP), acc1; ADCQ $1, acc2
polyMul
// Final reduce
MOVQ acc0, t0
MOVQ acc1, t1
MOVQ acc2, t2
SUBQ $-5, acc0
SBBQ $-1, acc1
SBBQ $3, acc2
CMOVQCS t0, acc0
CMOVQCS t1, acc1
CMOVQCS t2, acc2
// Add in the "s" part of the key
ADDQ 0+sStore, acc0
ADCQ 8+sStore, acc1
// Finally, constant time compare to the tag at the end of the message
XORQ AX, AX
MOVQ $1, DX
XORQ (0*8)(inp), acc0
XORQ (1*8)(inp), acc1
ORQ acc1, acc0
CMOVQEQ DX, AX
// Return true iff tags are equal
MOVB AX, ret+96(FP)
RET
// ----------------------------------------------------------------------------
// Special optimization for buffers smaller than 129 bytes
// Generates three ChaCha20 blocks at once: the first supplies the Poly1305
// key, the other two supply the key stream used by openSSE128Open.
openSSE128:
// For up to 128 bytes of ciphertext and 64 bytes for the poly key, we require to process three blocks
// Block 0 is loaded from the constants and from offsets 16/32/48 of keyp;
// blocks 1 and 2 are copies whose D row is bumped by sseIncMask each time.
MOVOU ·chacha20Constants<>(SB), A0; MOVOU (1*16)(keyp), B0; MOVOU (2*16)(keyp), C0; MOVOU (3*16)(keyp), D0
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
// Keep the initial B and C rows and block 1's D row for the post-round additions.
MOVO B0, T1; MOVO C0, T2; MOVO D1, T3
// 10 passes, each made of two chachaQR groups with the B/C/D rows rotated
// in between and rotated back afterwards.
MOVQ $10, itr2
openSSE128InnerCipherLoop:
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Left; shiftB1Left; shiftB2Left
shiftC0Left; shiftC1Left; shiftC2Left
shiftD0Left; shiftD1Left; shiftD2Left
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Right; shiftB1Right; shiftB2Right
shiftC0Right; shiftC1Right; shiftC2Right
shiftD0Right; shiftD1Right; shiftD2Right
DECQ itr2
JNE openSSE128InnerCipherLoop
// A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded
// Add the initial state back in. C0 and D0 are skipped because they are not
// used; D2 gets T3 after one more sseIncMask increment (block 2's counter).
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2
PADDL T1, B0; PADDL T1, B1; PADDL T1, B2
PADDL T2, C1; PADDL T2, C2
PADDL T3, D1; PADDL ·sseIncMask<>(SB), T3; PADDL T3, D2
// Clamp and store the key
PAND ·polyClampMask<>(SB), A0
MOVOU A0, rStore; MOVOU B0, sStore
// Hash
// polyHashADInternal is called with the AD length in itr2.
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
openSSE128Open:
// Process one 16-byte chunk per pass; fewer than 16 bytes left is handed to
// openSSETail16, which expects its key stream in A1.
CMPQ inl, $16
JB openSSETail16
SUBQ $16, inl
// Load for hashing
polyAdd(0(inp))
// Load for decryption
MOVOU (inp), T0; PXOR T0, A1; MOVOU A1, (oup)
LEAQ (1*16)(inp), inp
LEAQ (1*16)(oup), oup
polyMul
// Shift the stream "left"
// The remaining key stream moves down one register so that A1 always holds
// the next 16 bytes: A1 <- B1 <- C1 <- D1 <- A2 <- B2 <- C2 <- D2.
MOVO B1, A1
MOVO C1, B1
MOVO D1, C1
MOVO A2, D1
MOVO B2, A2
MOVO C2, B2
MOVO D2, C2
JMP openSSE128Open
openSSETail16:
// Handles a final partial chunk of fewer than 16 bytes (inl may be 0).
// Expects the key stream for this chunk in A1.
TESTQ inl, inl
JE openSSEFinalize
// We can safely load the CT from the end, because it is padded with the MAC
// itr2 = inl * 16 indexes the andMask table in 16-byte entries; the -16
// displacement selects entry inl-1, which is ANDed into the 16 bytes loaded
// from inp. inp is then advanced by inl, which is where openSSEFinalize
// reads the tag from.
MOVQ inl, itr2
SHLQ $4, itr2
LEAQ ·andMask<>(SB), t0
MOVOU (inp), T0
ADDQ inl, inp
PAND -16(t0)(itr2*1), T0
// Keep the masked ciphertext for hashing: low 8 bytes in t0, high 8 bytes
// (via tmpStore) in t1. Only then is T0 turned into plaintext.
MOVO T0, 0+tmpStore
MOVQ T0, t0
MOVQ 8+tmpStore, t1
PXOR A1, T0
// We can only store one byte at a time, since plaintext can be shorter than 16 bytes
openSSETail16Store:
// Emit the low byte of T0, shift T0 right by one byte, repeat inl times.
MOVQ T0, t3
MOVB t3, (oup)
PSRLDQ $1, T0
INCQ oup
DECQ inl
JNE openSSETail16Store
// Absorb the masked ciphertext chunk saved in t0/t1 into the accumulator.
ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2
polyMul
JMP openSSEFinalize
// ----------------------------------------------------------------------------
// Special optimization for the last 64 bytes of ciphertext
openSSETail64:
// Need to decrypt up to 64 bytes - prepare single block
// The block's D row is ctr3Store + sseIncMask, saved to ctr0Store for the
// post-round addition.
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store
// itr2 doubles as the hash offset and the round counter (16 per pass, stop
// at 160 = 10 passes); itr1 counts the bytes that still need hashing.
// With fewer than 16 bytes nothing is hashed here.
XORQ itr2, itr2
MOVQ inl, itr1
CMPQ itr1, $16
JB openSSETail64LoopB
openSSETail64LoopA:
// Perform ChaCha rounds, while hashing the remaining input
polyAdd(0(inp)(itr2*1))
polyMul
SUBQ $16, itr1
openSSETail64LoopB:
ADDQ $16, itr2
chachaQR(A0, B0, C0, D0, T0)
shiftB0Left; shiftC0Left; shiftD0Left
chachaQR(A0, B0, C0, D0, T0)
shiftB0Right; shiftC0Right; shiftD0Right
// Hash another 16 bytes before the next pass while at least 16 remain.
CMPQ itr1, $16
JAE openSSETail64LoopA
CMPQ itr2, $160
JNE openSSETail64LoopB
// Add the initial state back in to obtain the key stream block.
PADDL ·chacha20Constants<>(SB), A0; PADDL state1Store, B0; PADDL state2Store, C0; PADDL ctr0Store, D0
openSSETail64DecLoop:
// Decrypt 16 bytes at a time with the key stream in A0..D0, shifting the
// stream down (A0 <- B0 <- C0 <- D0) after each chunk. No hashing happens
// in this loop. The 128/192/256-byte tails also jump here for their last
// block.
CMPQ inl, $16
JB openSSETail64DecLoopDone
SUBQ $16, inl
MOVOU (inp), T0
PXOR T0, A0
MOVOU A0, (oup)
LEAQ 16(inp), inp
LEAQ 16(oup), oup
MOVO B0, A0
MOVO C0, B0
MOVO D0, C0
JMP openSSETail64DecLoop
openSSETail64DecLoopDone:
// openSSETail16 takes its key stream from A1.
MOVO A0, A1
JMP openSSETail16
// ----------------------------------------------------------------------------
// Special optimization for the last 128 bytes of ciphertext
openSSETail128:
// Need to decrypt up to 128 bytes - prepare two blocks
// Block 1 (A1..D1) uses ctr3Store + 1 increment (kept in ctr0Store); block 0
// (A0..D0) uses one increment more (kept in ctr1Store). Block 1 is therefore
// the one applied to the first 64 bytes below.
MOVO ·chacha20Constants<>(SB), A1; MOVO state1Store, B1; MOVO state2Store, C1; MOVO ctr3Store, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr0Store
MOVO A1, A0; MOVO B1, B0; MOVO C1, C0; MOVO D1, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr1Store
// itr2: hash offset / round counter (16 per pass, 160 = 10 passes).
// itr1: inl rounded down to a multiple of 16, i.e. the bytes hashed here.
XORQ itr2, itr2
MOVQ inl, itr1
ANDQ $-16, itr1
openSSETail128LoopA:
// Perform ChaCha rounds, while hashing the remaining input
// NOTE(review): the first chunk is hashed unconditionally; this relies on
// inl being at least 16 on entry (the visible dispatcher only jumps here
// for inl > 64).
polyAdd(0(inp)(itr2*1))
polyMul
openSSETail128LoopB:
ADDQ $16, itr2
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0)
shiftB0Left; shiftC0Left; shiftD0Left
shiftB1Left; shiftC1Left; shiftD1Left
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0)
shiftB0Right; shiftC0Right; shiftD0Right
shiftB1Right; shiftC1Right; shiftD1Right
// Keep hashing while full chunks remain, then finish the passes without hashing.
CMPQ itr2, itr1
JB openSSETail128LoopA
CMPQ itr2, $160
JNE openSSETail128LoopB
// Add the initial state back in (each block with its own counter).
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1
PADDL state1Store, B0; PADDL state1Store, B1
PADDL state2Store, C0; PADDL state2Store, C1
PADDL ctr1Store, D0; PADDL ctr0Store, D1
// Decrypt the first 64 bytes with block 1, then let openSSETail64DecLoop
// handle the rest with block 0 in A0..D0.
MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3
PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1
MOVOU A1, (0*16)(oup); MOVOU B1, (1*16)(oup); MOVOU C1, (2*16)(oup); MOVOU D1, (3*16)(oup)
SUBQ $64, inl
LEAQ 64(inp), inp
LEAQ 64(oup), oup
JMP openSSETail64DecLoop
// ----------------------------------------------------------------------------
// Special optimization for the last 192 bytes of ciphertext
openSSETail192:
// Need to decrypt up to 192 bytes - prepare three blocks
// Counters increase from block 2 (ctr0Store) through block 1 (ctr1Store) to
// block 0 (ctr2Store), so the output below consumes blocks 2, 1, 0 in order.
MOVO ·chacha20Constants<>(SB), A2; MOVO state1Store, B2; MOVO state2Store, C2; MOVO ctr3Store, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr0Store
MOVO A2, A1; MOVO B2, B1; MOVO C2, C1; MOVO D2, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store
MOVO A1, A0; MOVO B1, B0; MOVO C1, C0; MOVO D1, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr2Store
// itr1 = min(inl, 160) rounded down to a multiple of 16: the number of bytes
// hashed inside the round loop. itr2 is borrowed to hold the constant 160
// for the conditional move and is then reset to 0 as the loop offset.
MOVQ inl, itr1
MOVQ $160, itr2
CMPQ itr1, $160
CMOVQGT itr2, itr1
ANDQ $-16, itr1
XORQ itr2, itr2
openSSLTail192LoopA:
// Perform ChaCha rounds, while hashing the remaining input
polyAdd(0(inp)(itr2*1))
polyMul
openSSLTail192LoopB:
ADDQ $16, itr2
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Left; shiftC0Left; shiftD0Left
shiftB1Left; shiftC1Left; shiftD1Left
shiftB2Left; shiftC2Left; shiftD2Left
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Right; shiftC0Right; shiftD0Right
shiftB1Right; shiftC1Right; shiftD1Right
shiftB2Right; shiftC2Right; shiftD2Right
// Hash while chunks remain below itr1; run until itr2 reaches 160 (10 passes).
CMPQ itr2, itr1
JB openSSLTail192LoopA
CMPQ itr2, $160
JNE openSSLTail192LoopB
// The round loop covers at most the first 160 bytes. Hash the chunks at
// offsets 160 and 176 here when inl is large enough to contain them fully.
CMPQ inl, $176
JB openSSLTail192Store
polyAdd(160(inp))
polyMul
CMPQ inl, $192
JB openSSLTail192Store
polyAdd(176(inp))
polyMul
openSSLTail192Store:
// Add the initial state back in (each block with its own counter).
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2
PADDL state1Store, B0; PADDL state1Store, B1; PADDL state1Store, B2
PADDL state2Store, C0; PADDL state2Store, C1; PADDL state2Store, C2
PADDL ctr2Store, D0; PADDL ctr1Store, D1; PADDL ctr0Store, D2
// Bytes 0..63 are decrypted with block 2.
MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3
PXOR T0, A2; PXOR T1, B2; PXOR T2, C2; PXOR T3, D2
MOVOU A2, (0*16)(oup); MOVOU B2, (1*16)(oup); MOVOU C2, (2*16)(oup); MOVOU D2, (3*16)(oup)
// Bytes 64..127 are decrypted with block 1.
MOVOU (4*16)(inp), T0; MOVOU (5*16)(inp), T1; MOVOU (6*16)(inp), T2; MOVOU (7*16)(inp), T3
PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1
MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup)
// The rest is handled by openSSETail64DecLoop with block 0 in A0..D0.
SUBQ $128, inl
LEAQ 128(inp), inp
LEAQ 128(oup), oup
JMP openSSETail64DecLoop
// ----------------------------------------------------------------------------
// Special optimization for the last 256 bytes of ciphertext
openSSETail256:
// Need to decrypt up to 256 bytes - prepare four blocks
// Counters increase from block 0 to block 3, each one sseIncMask above the
// previous, starting from ctr3Store + sseIncMask.
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3
// Store counters
MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store
XORQ itr2, itr2
openSSETail256Loop:
// This loop interleaves 8 ChaCha quarter rounds with 1 poly multiplication
// All sixteen state registers are live, so the scratch register handed to
// chachaQR is borrowed: C3 is spilled to tmpStore for blocks 0-2, then C1 is
// spilled for block 3.
polyAdd(0(inp)(itr2*1))
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left
shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left
shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left
polyMulStage1
polyMulStage2
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
polyMulStage3
polyMulReduceStage
shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right
shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right
shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right
// 16 bytes hashed per pass; 160 / 16 = 10 passes.
ADDQ $2*8, itr2
CMPQ itr2, $160
JB openSSETail256Loop
// Hash the remaining full 16-byte chunks, from offset 160 up to inl rounded
// down to a multiple of 16.
// NOTE(review): the loop body runs before the bound is checked, so this
// relies on inl >= 176 on entry (the visible dispatcher jumps here only for
// inl > 192).
MOVQ inl, itr1
ANDQ $-16, itr1
openSSETail256HashLoop:
polyAdd(0(inp)(itr2*1))
polyMul
ADDQ $2*8, itr2
CMPQ itr2, itr1
JB openSSETail256HashLoop
// Add in the state
PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3
PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3
PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3
PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3
// D3 is parked in tmpStore so the register can serve as the load temporary.
MOVO D3, tmpStore
// Load - xor - store
// Bytes 0..63 with block 0.
MOVOU (0*16)(inp), D3; PXOR D3, A0
MOVOU (1*16)(inp), D3; PXOR D3, B0
MOVOU (2*16)(inp), D3; PXOR D3, C0
MOVOU (3*16)(inp), D3; PXOR D3, D0
MOVOU A0, (0*16)(oup)
MOVOU B0, (1*16)(oup)
MOVOU C0, (2*16)(oup)
MOVOU D0, (3*16)(oup)
// Bytes 64..127 with block 1 and 128..191 with block 2; A0..D0 are free now
// and are reused as load temporaries.
MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0
PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1
MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup)
MOVOU (8*16)(inp), A0; MOVOU (9*16)(inp), B0; MOVOU (10*16)(inp), C0; MOVOU (11*16)(inp), D0
PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2
MOVOU A2, (8*16)(oup); MOVOU B2, (9*16)(oup); MOVOU C2, (10*16)(oup); MOVOU D2, (11*16)(oup)
LEAQ 192(inp), inp
LEAQ 192(oup), oup
SUBQ $192, inl
// Move block 3 (its D row comes back from tmpStore) into A0..D0, where
// openSSETail64DecLoop expects the key stream.
MOVO A3, A0
MOVO B3, B0
MOVO C3, C0
MOVO tmpStore, D0
JMP openSSETail64DecLoop
// ----------------------------------------------------------------------------
// ------------------------- AVX2 Code ----------------------------------------
chacha20Poly1305Open_AVX2:
// AVX2 entry point of the open routine. Builds the initial state in YMM
// registers, dispatches short inputs to dedicated paths, and otherwise
// derives the Poly1305 key, hashes the AD and processes the first 64 bytes
// before falling into openAVX2MainLoop.
VZEROUPPER
VMOVDQU ·chacha20Constants<>(SB), AA0
// The three broadcasts are emitted as raw bytes; the trailing comments give
// the encoded instruction.
// NOTE(review): presumably r8 is keyp and ymm14/ymm12/ymm4 are BB0/CC0/DD0 -
// confirm against the register #defines at the top of the file.
BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x70; BYTE $0x10 // broadcasti128 16(r8), ymm14
BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x20 // broadcasti128 32(r8), ymm12
BYTE $0xc4; BYTE $0xc2; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x30 // broadcasti128 48(r8), ymm4
VPADDD ·avx2InitMask<>(SB), DD0, DD0
// Special optimization, for very short buffers
// Unsigned compares: up to 192 bytes and up to 320 bytes have their own paths.
CMPQ inl, $192
JBE openAVX2192
CMPQ inl, $320
JBE openAVX2320
// For the general key prepare the key first - as a byproduct we have 64 bytes of cipher stream
// Save the initial B, C and D rows for the post-round addition and for
// later iterations.
VMOVDQA BB0, state1StoreAVX2
VMOVDQA CC0, state2StoreAVX2
VMOVDQA DD0, ctr3StoreAVX2
// 10 passes of chachaQR_AVX2 / lane rotation / chachaQR_AVX2 / rotate back.
MOVQ $10, itr2
openAVX2PreparePolyKey:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0
DECQ itr2
JNE openAVX2PreparePolyKey
// Add the initial state back in.
VPADDD ·chacha20Constants<>(SB), AA0, AA0
VPADDD state1StoreAVX2, BB0, BB0
VPADDD state2StoreAVX2, CC0, CC0
VPADDD ctr3StoreAVX2, DD0, DD0
// TT0 = low 128-bit lanes of AA0 and BB0: the 32 bytes used as the poly key.
VPERM2I128 $0x02, AA0, BB0, TT0
// Clamp and store poly key
VPAND ·polyClampMask<>(SB), TT0, TT0
VMOVDQA TT0, rsStoreAVX2
// Stream for the first 64 bytes
// Taken from the high 128-bit lanes: AA0 = high(AA0)|high(BB0),
// BB0 = high(CC0)|high(DD0).
VPERM2I128 $0x13, AA0, BB0, AA0
VPERM2I128 $0x13, CC0, DD0, BB0
// Hash AD + first 64 bytes
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
XORQ itr1, itr1
openAVX2InitialHash64:
// Four 16-byte chunks at offsets 0, 16, 32 and 48.
polyAdd(0(inp)(itr1*1))
polyMulAVX2
ADDQ $16, itr1
CMPQ itr1, $64
JNE openAVX2InitialHash64
// Decrypt the first 64 bytes
VPXOR (0*32)(inp), AA0, AA0
VPXOR (1*32)(inp), BB0, BB0
VMOVDQU AA0, (0*32)(oup)
VMOVDQU BB0, (1*32)(oup)
LEAQ (2*32)(inp), inp
LEAQ (2*32)(oup), oup
SUBQ $64, inl
openAVX2MainLoop:
// Bulk loop: each iteration hashes and decrypts 512 bytes using four YMM
// states (AA0..DD0 through AA3..DD3). Exits to openAVX2MainLoopDone once
// fewer than 512 bytes remain.
CMPQ inl, $512
JB openAVX2MainLoopDone
// Load state, increment counter blocks, store the incremented counters
VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3
VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3
VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2
// itr1 is the hash offset into the current 512-byte window.
XORQ itr1, itr1
openAVX2InternalLoop:
// Lets just say this spaghetti loop interleaves 2 quarter rounds with 3 poly multiplications
// Effectively per 512 bytes of stream we hash 480 bytes of ciphertext
// Each pass absorbs three 16-byte chunks (offsets itr1+0, +16, +32) and
// advances itr1 by 48; the loop ends at 480, i.e. after 10 passes.
// First half: a += b; d ^= a; d <<<= 16; c += d; b ^= c; b <<<= 12.
polyAdd(0*8(inp)(itr1*1))
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
polyMulStage1_AVX2
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
polyMulStage2_AVX2
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
polyMulStage3_AVX2
// No spare vector register: CC3 is spilled to tmpStoreAVX2 and used as the
// temporary for the shift-left / shift-right / xor rotation of each B row.
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulReduceStage
// Second half: a += b; d ^= a; d <<<= 8; c += d; b ^= c; b <<<= 7.
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
polyAdd(2*8(inp)(itr1*1))
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
polyMulStage1_AVX2
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulStage2_AVX2
// Rotate the B, C and D rows by 4, 8 and 12 bytes within each 128-bit lane
// before the second quarter round.
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
polyMulStage3_AVX2
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
polyMulReduceStage
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
// Third chunk of this pass; itr1 is advanced right after the load.
polyAdd(4*8(inp)(itr1*1))
LEAQ (6*8)(itr1), itr1
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulStage1_AVX2
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
polyMulStage2_AVX2
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
polyMulStage3_AVX2
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulReduceStage
// Undo the row rotation (12/8/4 are the inverses of 4/8/12).
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3
CMPQ itr1, $480
JNE openAVX2InternalLoop
// Add the initial state back in (each state with its own stored counter).
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3
// CC3 is about to be reused as a destination, so its value is kept in
// tmpStoreAVX2 and read back from memory for the last 128 bytes below.
VMOVDQA CC3, tmpStoreAVX2
// We only hashed 480 of the 512 bytes available - hash the remaining 32 here
polyAdd(480(inp))
polyMulAVX2
// Bytes 0..127: regroup state 0 into 32-byte pieces of contiguous key stream
// ($0x02 picks the low 128-bit lanes of the two sources, $0x13 the high
// lanes), then xor with the input and store.
VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0
VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0
VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup)
// Bytes 128..255 from state 1.
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0
VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup)
// and here
polyAdd(496(inp))
polyMulAVX2
// Bytes 256..383 from state 2.
VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0
VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0
VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup)
// Bytes 384..511 from state 3; its C row is the copy saved in tmpStoreAVX2.
VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0
VPXOR (12*32)(inp), AA0, AA0; VPXOR (13*32)(inp), BB0, BB0; VPXOR (14*32)(inp), CC0, CC0; VPXOR (15*32)(inp), DD0, DD0
VMOVDQU AA0, (12*32)(oup); VMOVDQU BB0, (13*32)(oup); VMOVDQU CC0, (14*32)(oup); VMOVDQU DD0, (15*32)(oup)
// Advance both pointers by 512 bytes and reduce the remaining length.
LEAQ (32*16)(inp), inp
LEAQ (32*16)(oup), oup
SUBQ $(32*16), inl
JMP openAVX2MainLoop
openAVX2MainLoopDone:
// Handle the various tail sizes efficiently
// Fewer than 512 bytes remain. Zero goes straight to the shared finalizer;
// otherwise the unsigned compares pick the tail routine with the smallest
// inclusive bound (128, 256, 384), and anything larger uses openAVX2Tail512.
TESTQ inl, inl
JE openSSEFinalize
CMPQ inl, $128
JBE openAVX2Tail128
CMPQ inl, $256
JBE openAVX2Tail256
CMPQ inl, $384
JBE openAVX2Tail384
JMP openAVX2Tail512
// ----------------------------------------------------------------------------
// Special optimization for buffers smaller than 193 bytes
openAVX2192:
// For up to 192 bytes of ciphertext and 64 bytes for the poly key, we process four blocks
// Two YMM states are used (AA0..DD0 and AA1..DD1). AA2/BB2/CC2/DD2 keep a
// copy of the first state's initial rows and TT3 the second state's initial
// counter row, for the addition after the rounds.
VMOVDQA AA0, AA1
VMOVDQA BB0, BB1
VMOVDQA CC0, CC1
VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA AA0, AA2
VMOVDQA BB0, BB2
VMOVDQA CC0, CC2
VMOVDQA DD0, DD2
VMOVDQA DD1, TT3
// 10 passes of chachaQR_AVX2 / lane rotation / chachaQR_AVX2 / rotate back.
MOVQ $10, itr2
openAVX2192InnerCipherLoop:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1
DECQ itr2
JNE openAVX2192InnerCipherLoop
// Add the saved initial rows back in; the two states differ only in the
// counter row (DD2 for the first, TT3 for the second).
VPADDD AA2, AA0, AA0; VPADDD AA2, AA1, AA1
VPADDD BB2, BB0, BB0; VPADDD BB2, BB1, BB1
VPADDD CC2, CC0, CC0; VPADDD CC2, CC1, CC1
VPADDD DD2, DD0, DD0; VPADDD TT3, DD1, DD1
// TT0 = low 128-bit lanes of AA0 and BB0: the 32 bytes used as the poly key.
VPERM2I128 $0x02, AA0, BB0, TT0
// Clamp and store poly key
VPAND ·polyClampMask<>(SB), TT0, TT0
VMOVDQA TT0, rsStoreAVX2
// Stream for up to 192 bytes
// Regroup the remaining lanes into six 32-byte pieces in the order
// AA0, BB0, CC0, DD0, AA1, BB1, which is the order openAVX2ShortOpenLoop
// consumes them. Execution falls through into openAVX2ShortOpen.
VPERM2I128 $0x13, AA0, BB0, AA0
VPERM2I128 $0x13, CC0, DD0, BB0
VPERM2I128 $0x02, AA1, BB1, CC0
VPERM2I128 $0x02, CC1, DD1, DD0
VPERM2I128 $0x13, AA1, BB1, AA1
VPERM2I128 $0x13, CC1, DD1, BB1
openAVX2ShortOpen:
// Shared by the short-input paths (openAVX2192 falls through, openAVX2320
// jumps here). The key stream is expected in AA0, BB0, CC0, DD0, AA1, ...
// in consumption order, and the poly key is already stored.
// Hash
// polyHashADInternal is called with the AD length in itr2.
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
openAVX2ShortOpenLoop:
// 32 bytes per pass: hash two 16-byte chunks, then decrypt with AA0.
CMPQ inl, $32
JB openAVX2ShortTail32
SUBQ $32, inl
// Load for hashing
polyAdd(0*8(inp))
polyMulAVX2
polyAdd(2*8(inp))
polyMulAVX2
// Load for decryption
VPXOR (inp), AA0, AA0
VMOVDQU AA0, (oup)
LEAQ (1*32)(inp), inp
LEAQ (1*32)(oup), oup
// Shift stream left
// AA0 <- BB0 <- CC0 <- DD0 <- AA1 <- BB1 <- CC1 <- DD1 <- AA2 <- BB2.
VMOVDQA BB0, AA0
VMOVDQA CC0, BB0
VMOVDQA DD0, CC0
VMOVDQA AA1, DD0
VMOVDQA BB1, AA1
VMOVDQA CC1, BB1
VMOVDQA DD1, CC1
VMOVDQA AA2, DD1
VMOVDQA BB2, AA2
JMP openAVX2ShortOpenLoop
openAVX2ShortTail32:
// Fewer than 32 bytes left. The vector move sits between CMPQ and JB; it
// does not modify the flags, so the branch still tests inl against 16. It
// copies A0 to A1, the register openSSETail16 takes its key stream from.
CMPQ inl, $16
VMOVDQA A0, A1
JB openAVX2ShortDone
SUBQ $16, inl
// Load for hashing
polyAdd(0*8(inp))
polyMulAVX2
// Load for decryption
VPXOR (inp), A0, T0
VMOVDQU T0, (oup)
LEAQ (1*16)(inp), inp
LEAQ (1*16)(oup), oup
// Copy the high 128-bit lane of AA0 into both lanes, then refresh A1 with
// the next 16 bytes of key stream.
// NOTE(review): assumes A0 names the low half of AA0 - confirm against the
// register #defines.
VPERM2I128 $0x11, AA0, AA0, AA0
VMOVDQA A0, A1
openAVX2ShortDone:
// Clear the upper YMM halves before continuing in SSE code.
VZEROUPPER
JMP openSSETail16
// ----------------------------------------------------------------------------
// Special optimization for buffers smaller than 321 bytes
openAVX2320:
// For up to 320 bytes of ciphertext and 64 bytes for the poly key, we process six blocks
// Three YMM states, each counter row one avx2IncMask above the previous.
// TT1/TT2/TT3 keep the initial B, C and first D rows for the addition after
// the rounds.
VMOVDQA AA0, AA1; VMOVDQA BB0, BB1; VMOVDQA CC0, CC1; VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA AA0, AA2; VMOVDQA BB0, BB2; VMOVDQA CC0, CC2; VPADDD ·avx2IncMask<>(SB), DD1, DD2
VMOVDQA BB0, TT1; VMOVDQA CC0, TT2; VMOVDQA DD0, TT3
// 10 passes of chachaQR_AVX2 / lane rotation / chachaQR_AVX2 / rotate back.
MOVQ $10, itr2
openAVX2320InnerCipherLoop:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2
DECQ itr2
JNE openAVX2320InnerCipherLoop
// Add the initial state back in. TT0 first holds the constants row, then
// the increment mask used to step TT3 to each state's own counter row.
VMOVDQA ·chacha20Constants<>(SB), TT0
VPADDD TT0, AA0, AA0; VPADDD TT0, AA1, AA1; VPADDD TT0, AA2, AA2
VPADDD TT1, BB0, BB0; VPADDD TT1, BB1, BB1; VPADDD TT1, BB2, BB2
VPADDD TT2, CC0, CC0; VPADDD TT2, CC1, CC1; VPADDD TT2, CC2, CC2
VMOVDQA ·avx2IncMask<>(SB), TT0
VPADDD TT3, DD0, DD0; VPADDD TT0, TT3, TT3
VPADDD TT3, DD1, DD1; VPADDD TT0, TT3, TT3
VPADDD TT3, DD2, DD2
// Clamp and store poly key
// TT0 = low 128-bit lanes of AA0 and BB0: the 32 bytes used as the poly key.
VPERM2I128 $0x02, AA0, BB0, TT0
VPAND ·polyClampMask<>(SB), TT0, TT0
VMOVDQA TT0, rsStoreAVX2
// Stream for up to 320 bytes
// Regroup the remaining lanes into ten 32-byte pieces in the order
// AA0, BB0, CC0, DD0, AA1, BB1, CC1, DD1, AA2, BB2, which is the order
// openAVX2ShortOpenLoop consumes them.
VPERM2I128 $0x13, AA0, BB0, AA0
VPERM2I128 $0x13, CC0, DD0, BB0
VPERM2I128 $0x02, AA1, BB1, CC0
VPERM2I128 $0x02, CC1, DD1, DD0
VPERM2I128 $0x13, AA1, BB1, AA1
VPERM2I128 $0x13, CC1, DD1, BB1
VPERM2I128 $0x02, AA2, BB2, CC1
VPERM2I128 $0x02, CC2, DD2, DD1
VPERM2I128 $0x13, AA2, BB2, AA2
VPERM2I128 $0x13, CC2, DD2, BB2
JMP openAVX2ShortOpen
// ----------------------------------------------------------------------------
// Special optimization for the last 128 bytes of ciphertext
openAVX2Tail128:
// Need to decrypt up to 128 bytes - prepare two blocks
// A single YMM state (AA1..DD1) is used. DD0 keeps a copy of its initial
// counter row for the addition after the rounds.
VMOVDQA ·chacha20Constants<>(SB), AA1
VMOVDQA state1StoreAVX2, BB1
VMOVDQA state2StoreAVX2, CC1
VMOVDQA ctr3StoreAVX2, DD1
VPADDD ·avx2IncMask<>(SB), DD1, DD1
VMOVDQA DD1, DD0
// itr2: hash offset / round counter (16 per pass, 160 = 10 passes).
// itr1: inl rounded down to a multiple of 16, i.e. the bytes hashed here.
// With fewer than 16 bytes nothing is hashed in this routine.
XORQ itr2, itr2
MOVQ inl, itr1
ANDQ $-16, itr1
TESTQ itr1, itr1
JE openAVX2Tail128LoopB
openAVX2Tail128LoopA:
// Perform ChaCha rounds, while hashing the remaining input
polyAdd(0(inp)(itr2*1))
polyMulAVX2
openAVX2Tail128LoopB:
ADDQ $16, itr2
chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $4, BB1, BB1, BB1
VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $12, DD1, DD1, DD1
chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $12, BB1, BB1, BB1
VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $4, DD1, DD1, DD1
// Keep hashing while full chunks remain, then finish the passes without hashing.
CMPQ itr2, itr1
JB openAVX2Tail128LoopA
CMPQ itr2, $160
JNE openAVX2Tail128LoopB
// Add the initial state back in.
VPADDD ·chacha20Constants<>(SB), AA1, AA1
VPADDD state1StoreAVX2, BB1, BB1
VPADDD state2StoreAVX2, CC1, CC1
VPADDD DD0, DD1, DD1
// Regroup into four 32-byte pieces (AA0, BB0, CC0, DD0): low lanes first,
// then high lanes. Execution falls through into openAVX2TailLoop.
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
openAVX2TailLoop:
// Shared decrypt-only loop for the AVX2 tails: consumes the key stream held
// in AA0, BB0, CC0, DD0 32 bytes at a time. Nothing is hashed here.
CMPQ inl, $32
JB openAVX2Tail
SUBQ $32, inl
// Load for decryption
VPXOR (inp), AA0, AA0
VMOVDQU AA0, (oup)
LEAQ (1*32)(inp), inp
LEAQ (1*32)(oup), oup
// Shift the stream down: AA0 <- BB0 <- CC0 <- DD0.
VMOVDQA BB0, AA0
VMOVDQA CC0, BB0
VMOVDQA DD0, CC0
JMP openAVX2TailLoop
openAVX2Tail:
// Fewer than 32 bytes left. The vector move sits between CMPQ and JB; it
// does not modify the flags, so the branch still tests inl against 16. It
// copies A0 to A1, the register openSSETail16 takes its key stream from.
CMPQ inl, $16
VMOVDQA A0, A1
JB openAVX2TailDone
SUBQ $16, inl
// Load for decryption
VPXOR (inp), A0, T0
VMOVDQU T0, (oup)
LEAQ (1*16)(inp), inp
LEAQ (1*16)(oup), oup
// Copy the high 128-bit lane of AA0 into both lanes, then refresh A1 with
// the next 16 bytes of key stream.
// NOTE(review): assumes A0 names the low half of AA0 - confirm against the
// register #defines.
VPERM2I128 $0x11, AA0, AA0, AA0
VMOVDQA A0, A1
openAVX2TailDone:
// Clear the upper YMM halves before continuing in SSE code.
VZEROUPPER
JMP openSSETail16
// ----------------------------------------------------------------------------
// Special optimization for the last 256 bytes of ciphertext
openAVX2Tail256:
// Need to decrypt up to 256 bytes - prepare four blocks
// Two YMM states; TT1 and TT2 keep their initial counter rows for the
// addition after the rounds.
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA DD0, TT1
VMOVDQA DD1, TT2
// Compute the number of iterations that will hash data
// itr1 = min((inl - 128) / 16, 10). The length is parked in tmpStoreAVX2
// because inl is reused below as a running input pointer for hashing.
MOVQ inl, tmpStoreAVX2
MOVQ inl, itr1
SUBQ $128, itr1
SHRQ $4, itr1
MOVQ $10, itr2
CMPQ itr1, $10
CMOVQGT itr2, itr1
MOVQ inp, inl
XORQ itr2, itr2
openAVX2Tail256LoopA:
// Entered by fall-through on the first pass, so at least one 16-byte chunk
// is always hashed here.
polyAdd(0(inl))
polyMulAVX2
LEAQ 16(inl), inl
// Perform ChaCha rounds, while hashing the remaining input
openAVX2Tail256LoopB:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1
INCQ itr2
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1
// itr2 counts passes: hash before the next pass while itr2 < itr1, then
// run the remaining passes (10 in total) without hashing.
CMPQ itr2, itr1
JB openAVX2Tail256LoopA
CMPQ itr2, $10
JNE openAVX2Tail256LoopB
// itr2 = pointer to the next unhashed byte, itr1 = number of bytes hashed
// so far; inl gets the remaining length back from tmpStoreAVX2.
MOVQ inl, itr2
SUBQ inp, inl
MOVQ inl, itr1
MOVQ tmpStoreAVX2, inl
// Hash the remainder of data (if any)
// Continues in 16-byte steps while a full chunk still fits within inl.
openAVX2Tail256Hash:
ADDQ $16, itr1
CMPQ itr1, inl
JGT openAVX2Tail256HashEnd
polyAdd (0(itr2))
polyMulAVX2
LEAQ 16(itr2), itr2
JMP openAVX2Tail256Hash
// Store 128 bytes safely, then go to store loop
openAVX2Tail256HashEnd:
// Add the initial state back in (TT1/TT2 are the saved counter rows).
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1
VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1
// State 0 is regrouped into AA2..DD2 for the first 128 bytes; state 1 is
// regrouped into AA0..DD0, where openAVX2TailLoop expects the key stream.
VPERM2I128 $0x02, AA0, BB0, AA2; VPERM2I128 $0x02, CC0, DD0, BB2; VPERM2I128 $0x13, AA0, BB0, CC2; VPERM2I128 $0x13, CC0, DD0, DD2
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (0*32)(inp), AA2, AA2; VPXOR (1*32)(inp), BB2, BB2; VPXOR (2*32)(inp), CC2, CC2; VPXOR (3*32)(inp), DD2, DD2
VMOVDQU AA2, (0*32)(oup); VMOVDQU BB2, (1*32)(oup); VMOVDQU CC2, (2*32)(oup); VMOVDQU DD2, (3*32)(oup)
LEAQ (4*32)(inp), inp
LEAQ (4*32)(oup), oup
SUBQ $4*32, inl
JMP openAVX2TailLoop
// ----------------------------------------------------------------------------
// Special optimization for the last 384 bytes of ciphertext
openAVX2Tail384:
// Need to decrypt up to 384 bytes - prepare six blocks
// Three YMM states, each counter row one avx2IncMask above the previous;
// the counter rows are saved in ctr0StoreAVX2..ctr2StoreAVX2.
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD1
VPADDD ·avx2IncMask<>(SB), DD1, DD2
VMOVDQA DD0, ctr0StoreAVX2
VMOVDQA DD1, ctr1StoreAVX2
VMOVDQA DD2, ctr2StoreAVX2
// Compute the number of iterations that will hash two blocks of data
// itr1 = min((inl - 256) / 16 + 6, 10). The length is parked in
// tmpStoreAVX2 because inl is reused below as a running input pointer.
MOVQ inl, tmpStoreAVX2
MOVQ inl, itr1
SUBQ $256, itr1
SHRQ $4, itr1
ADDQ $6, itr1
MOVQ $10, itr2
CMPQ itr1, $10
CMOVQGT itr2, itr1
MOVQ inp, inl
XORQ itr2, itr2
// Perform ChaCha rounds, while hashing the remaining input
// A pass entered at LoopB hashes two 16-byte chunks (one here, one in the
// middle of the pass); a pass entered at LoopA hashes only the middle one.
openAVX2Tail384LoopB:
polyAdd(0(inl))
polyMulAVX2
LEAQ 16(inl), inl
openAVX2Tail384LoopA:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2
polyAdd(0(inl))
polyMulAVX2
LEAQ 16(inl), inl
INCQ itr2
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2
// itr2 counts passes: two-chunk passes while itr2 < itr1, then one-chunk
// passes until 10 passes have run.
CMPQ itr2, itr1
JB openAVX2Tail384LoopB
CMPQ itr2, $10
JNE openAVX2Tail384LoopA
// itr2 = pointer to the next unhashed byte, itr1 = number of bytes hashed
// so far; inl gets the remaining length back from tmpStoreAVX2.
MOVQ inl, itr2
SUBQ inp, inl
MOVQ inl, itr1
MOVQ tmpStoreAVX2, inl
openAVX2Tail384Hash:
// Hash the remaining full 16-byte chunks, if any.
ADDQ $16, itr1
CMPQ itr1, inl
JGT openAVX2Tail384HashEnd
polyAdd(0(itr2))
polyMulAVX2
LEAQ 16(itr2), itr2
JMP openAVX2Tail384Hash
// Store 256 bytes safely, then go to store loop
openAVX2Tail384HashEnd:
// Add the initial state back in (each state with its own stored counter).
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2
// Bytes 0..127 from state 0, regrouped through TT0..TT3.
VPERM2I128 $0x02, AA0, BB0, TT0; VPERM2I128 $0x02, CC0, DD0, TT1; VPERM2I128 $0x13, AA0, BB0, TT2; VPERM2I128 $0x13, CC0, DD0, TT3
VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3
VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup)
// Bytes 128..255 from state 1.
VPERM2I128 $0x02, AA1, BB1, TT0; VPERM2I128 $0x02, CC1, DD1, TT1; VPERM2I128 $0x13, AA1, BB1, TT2; VPERM2I128 $0x13, CC1, DD1, TT3
VPXOR (4*32)(inp), TT0, TT0; VPXOR (5*32)(inp), TT1, TT1; VPXOR (6*32)(inp), TT2, TT2; VPXOR (7*32)(inp), TT3, TT3
VMOVDQU TT0, (4*32)(oup); VMOVDQU TT1, (5*32)(oup); VMOVDQU TT2, (6*32)(oup); VMOVDQU TT3, (7*32)(oup)
// State 2 is regrouped into AA0..DD0, where openAVX2TailLoop expects the
// key stream for whatever remains.
VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0
LEAQ (8*32)(inp), inp
LEAQ (8*32)(oup), oup
SUBQ $8*32, inl
JMP openAVX2TailLoop
// ----------------------------------------------------------------------------
// Special optimization for the last 512 bytes of ciphertext
openAVX2Tail512:
VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3
VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3
VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2
XORQ itr1, itr1
MOVQ inp, itr2
openAVX2Tail512LoopB:
polyAdd(0(itr2))
polyMulAVX2
LEAQ (2*8)(itr2), itr2
openAVX2Tail512LoopA:
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyAdd(0*8(itr2))
polyMulAVX2
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
polyAdd(2*8(itr2))
polyMulAVX2
LEAQ (4*8)(itr2), itr2
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3
INCQ itr1
CMPQ itr1, $4
JLT openAVX2Tail512LoopB
CMPQ itr1, $10
JNE openAVX2Tail512LoopA
MOVQ inl, itr1
SUBQ $384, itr1
ANDQ $-16, itr1
openAVX2Tail512HashLoop:
TESTQ itr1, itr1
JE openAVX2Tail512HashEnd
polyAdd(0(itr2))
polyMulAVX2
LEAQ 16(itr2), itr2
SUBQ $16, itr1
JMP openAVX2Tail512HashLoop
openAVX2Tail512HashEnd:
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3
VMOVDQA CC3, tmpStoreAVX2
VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0
VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0
VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup)
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0
VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup)
VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0
VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0
VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup)
VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0
LEAQ (12*32)(inp), inp
LEAQ (12*32)(oup), oup
SUBQ $12*32, inl
JMP openAVX2TailLoop
// ----------------------------------------------------------------------------
// ----------------------------------------------------------------------------
// func chacha20Poly1305Seal(dst, key, src, ad []byte)
//
// Encrypts src_len bytes of src into dst with the ChaCha20 key stream,
// authenticates ad followed by the produced ciphertext with Poly1305, and
// stores the 16-byte tag in dst directly after the ciphertext.
//
// Arguments are four slices (4 * 24 = 96 bytes of frame arguments):
//   dst+0(FP)  -> oup   output pointer (also used later as the hash read pointer)
//   key+24(FP) -> keyp  state rows are read at offsets 16, 32 and 48 from it;
//                       row 0 is taken from chacha20Constants instead
//   src+48(FP) -> inp   plaintext pointer, src_len+56(FP) -> inl bytes remaining
//   ad+72(FP)  -> adp   additional data pointer, ad_len+80(FP) read before hashing
//
// This entry sequence dispatches to the AVX2 implementation when useAVX2 is
// set, to sealSSE128 for inputs of at most 128 bytes, and otherwise sets up
// four SSE ChaCha20 blocks for sealSSEIntroLoop.
TEXT ·chacha20Poly1305Seal(SB), 0, $288-96
// For aligned stack access
// BP = (SP + 32) rounded down to a multiple of 32
MOVQ SP, BP
ADDQ $32, BP
ANDQ $-32, BP
MOVQ dst+0(FP), oup
MOVQ key+24(FP), keyp
MOVQ src+48(FP), inp
MOVQ src_len+56(FP), inl
MOVQ ad+72(FP), adp
CMPB ·useAVX2(SB), $1
JE chacha20Poly1305Seal_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
JBE sealSSE128 // About 15% faster
// In the seal case - prepare the poly key + 3 blocks of stream in the first iteration
// Block 0 (A0..D0) uses the counter row exactly as supplied at 48(keyp)
MOVOU ·chacha20Constants<>(SB), A0
MOVOU (1*16)(keyp), B0
MOVOU (2*16)(keyp), C0
MOVOU (3*16)(keyp), D0
// Store state on stack for future use
MOVO B0, state1Store
MOVO C0, state2Store
// Load state, increment counter blocks
// Blocks 1..3 are copies of block 0 whose D row has sseIncMask added once, twice and three times
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3
// Store counters
MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store
// itr2 counts the 10 double-round iterations of sealSSEIntroLoop
MOVQ $10, itr2
// sealSSEIntroLoop runs 10 iterations (itr2) over four blocks at once. Each
// iteration applies chachaQR to every block, rotates the B/C/D rows with the
// shift*Left macros, applies chachaQR again and rotates the rows back with
// the shift*Right macros. No hashing happens here: nothing has been
// encrypted yet.
sealSSEIntroLoop:
// C3 is spilled to tmpStore so it can be passed as the last (scratch) argument of chachaQR for blocks 0..2
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
// Block 3 needs C3 itself, so C1 is spilled and used as the scratch register instead
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left
shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left
shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right
shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right
shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right
DECQ itr2
JNE sealSSEIntroLoop
// Add in the state
// C0 and D0 are not finalized: only A0|B0 of block 0 are consumed below (as the Poly1305 key)
PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3
PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3
PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3
PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3
// Clamp and store the key
// A0 (masked with polyClampMask) becomes rStore, B0 becomes sStore
PAND ·polyClampMask<>(SB), A0
MOVO A0, rStore
MOVO B0, sStore
// Hash AAD
// polyHashADInternal is defined elsewhere in the file; itr2 carries ad_len into it
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
// Encrypt the first 128 bytes with blocks 1 and 2. Only inp advances; oup stays
// at the start of the output because it is used afterwards as the hash pointer.
MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0
PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1
MOVOU A1, (0*16)(oup); MOVOU B1, (1*16)(oup); MOVOU C1, (2*16)(oup); MOVOU D1, (3*16)(oup)
MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0
PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2
MOVOU A2, (4*16)(oup); MOVOU B2, (5*16)(oup); MOVOU C2, (6*16)(oup); MOVOU D2, (7*16)(oup)
// itr1 = 128 bytes encrypted but not yet hashed (the meaning sealSSE128SealHash expects)
MOVQ $128, itr1
SUBQ $128, inl
LEAQ 128(inp), inp
// Move block 3 into A1..D1, where sealSSE128Seal expects its key stream
MOVO A3, A1; MOVO B3, B1; MOVO C3, C1; MOVO D3, D1
// At most 64 bytes left: hash the 128 bytes written and finish with block 3's key stream
CMPQ inl, $64
JBE sealSSE128SealHash
// Otherwise consume block 3 for the next 64 bytes (output offsets 128..191)
MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0
PXOR A0, A3; PXOR B0, B3; PXOR C0, C3; PXOR D0, D3
MOVOU A3, (8*16)(oup); MOVOU B3, (9*16)(oup); MOVOU C3, (10*16)(oup); MOVOU D3, (11*16)(oup)
// NOTE(review): itr1 is overwritten by the MOVQ $2 below before it is read again, so this ADDQ appears to have no effect
ADDQ $64, itr1
SUBQ $64, inl
LEAQ 64(inp), inp
// Loop counters for the next stage: with itr1 = 2 and itr2 = 8 the main loop and
// the tail loops run 10 double rounds while hashing the 192 bytes written so far
MOVQ $2, itr1
MOVQ $8, itr2
// Pick the smallest routine that covers the remaining input; more than 192 bytes falls through to the main loop
CMPQ inl, $64
JBE sealSSETail64
CMPQ inl, $128
JBE sealSSETail128
CMPQ inl, $192
JBE sealSSETail192
// sealSSEMainLoop produces four more ChaCha20 blocks (256 bytes of key stream)
// per pass while hashing, 16 bytes at a time through oup, the ciphertext that
// earlier code wrote but did not hash yet. It is entered with
// (itr1, itr2) = (2, 8) from the intro (192 bytes pending) or (6, 4) from its
// own bottom (256 bytes pending); both give 10 double rounds in total.
sealSSEMainLoop:
// Load state, increment counter blocks
// Counters continue from ctr3Store, the last counter value used so far
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3
// Store counters
MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store
// One pass of sealSSEInnerLoop = one double round on all four blocks, interleaved
// with the Poly1305 update (polyAdd + the polyMul stages) for 16 bytes at oup
sealSSEInnerLoop:
// C3, then C1, are spilled to tmpStore to serve as the scratch argument of chachaQR
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
polyAdd(0(oup))
shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left
shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left
shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left
polyMulStage1
polyMulStage2
LEAQ (2*8)(oup), oup
MOVO C3, tmpStore
chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3)
MOVO tmpStore, C3
MOVO C1, tmpStore
polyMulStage3
chachaQR(A3, B3, C3, D3, C1)
MOVO tmpStore, C1
polyMulReduceStage
shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right
shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right
shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right
// itr2 goes negative after its initial run; from then on each itr1 step below re-enters for exactly one more pass
DECQ itr2
JGE sealSSEInnerLoop
// Extra 16-byte hash step without ChaCha work, executed itr1 times in total
polyAdd(0(oup))
polyMul
LEAQ (2*8)(oup), oup
DECQ itr1
JG sealSSEInnerLoop
// Add in the state
PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3
PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3
PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3
PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3
// D3 is spilled so it can be reused as the load register while block 0 is XORed
MOVO D3, tmpStore
// Load - xor - store
MOVOU (0*16)(inp), D3; PXOR D3, A0
MOVOU (1*16)(inp), D3; PXOR D3, B0
MOVOU (2*16)(inp), D3; PXOR D3, C0
MOVOU (3*16)(inp), D3; PXOR D3, D0
MOVOU A0, (0*16)(oup)
MOVOU B0, (1*16)(oup)
MOVOU C0, (2*16)(oup)
MOVOU D0, (3*16)(oup)
MOVO tmpStore, D3
// Blocks 1 and 2: block 0's registers are free now and hold the loaded plaintext
MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0
PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1
MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup)
MOVOU (8*16)(inp), A0; MOVOU (9*16)(inp), B0; MOVOU (10*16)(inp), C0; MOVOU (11*16)(inp), D0
PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2
MOVOU A2, (8*16)(oup); MOVOU B2, (9*16)(oup); MOVOU C2, (10*16)(oup); MOVOU D2, (11*16)(oup)
// 192 bytes were just written at oup and are not hashed yet (itr1 = 192); oup is left in place
ADDQ $192, inp
MOVQ $192, itr1
SUBQ $192, inl
// Move block 3 into A1..D1, where sealSSE128Seal expects its key stream
MOVO A3, A1
MOVO B3, B1
MOVO C3, C1
MOVO D3, D1
// At most 64 bytes left: hash the pending 192 bytes and finish with block 3's key stream
CMPQ inl, $64
JBE sealSSE128SealHash
// Otherwise use block 3 for output bytes 192..255 of this pass
MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0
PXOR A0, A3; PXOR B0, B3; PXOR C0, C3; PXOR D0, D3
MOVOU A3, (12*16)(oup); MOVOU B3, (13*16)(oup); MOVOU C3, (14*16)(oup); MOVOU D3, (15*16)(oup)
LEAQ 64(inp), inp
SUBQ $64, inl
// 256 bytes are now pending: counters (6, 4) make the next stage hash 256 bytes over 10 double rounds
MOVQ $6, itr1
MOVQ $4, itr2
CMPQ inl, $192
JG sealSSEMainLoop
// NOTE(review): inl was above 64 before the SUBQ $64, so it cannot be zero here and this JE looks unreachable - confirm
MOVQ inl, itr1
TESTQ inl, inl
JE sealSSE128SealHash
// Restore itr1 (clobbered by the check above) and pick the tail routine for the remaining 1..192 bytes
MOVQ $6, itr1
CMPQ inl, $64
JBE sealSSETail64
CMPQ inl, $128
JBE sealSSETail128
JMP sealSSETail192
// ----------------------------------------------------------------------------
// Special optimization for the last 64 bytes of plaintext
// sealSSETail64 generates one final key-stream block in A1..D1 while hashing
// the ciphertext still pending at oup, then hands over to sealSSE128Seal,
// which consumes that block 16 bytes at a time.
// Expects (itr1, itr2) = (2, 8) or (6, 4) as set by the code that jumps here.
sealSSETail64:
// Need to encrypt up to 64 bytes - prepare single block, hash 192 or 256 bytes
MOVO ·chacha20Constants<>(SB), A1
MOVO state1Store, B1
MOVO state2Store, C1
MOVO ctr3Store, D1
PADDL ·sseIncMask<>(SB), D1
// Keep the counter row so it can be added back after the rounds
MOVO D1, ctr0Store
// LoopA (hash only) runs itr1 times, always followed by LoopB; LoopB (one double
// round + one hash step) runs itr1 + itr2 = 10 times in total
sealSSETail64LoopA:
// Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealSSETail64LoopB:
chachaQR(A1, B1, C1, D1, T1)
shiftB1Left; shiftC1Left; shiftD1Left
chachaQR(A1, B1, C1, D1, T1)
shiftB1Right; shiftC1Right; shiftD1Right
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
DECQ itr1
JG sealSSETail64LoopA
DECQ itr2
JGE sealSSETail64LoopB
// Add the initial state back to obtain the key-stream block
PADDL ·chacha20Constants<>(SB), A1
PADDL state1Store, B1
PADDL state2Store, C1
PADDL ctr0Store, D1
// Everything written so far is hashed; oup now points at the next output byte
JMP sealSSE128Seal
// ----------------------------------------------------------------------------
// Special optimization for the last 128 bytes of plaintext
// sealSSETail128 generates two final key-stream blocks while hashing the
// ciphertext still pending at oup. The first block encrypts 64 bytes here;
// the second is left in A1..D1 for sealSSE128Seal.
// Expects (itr1, itr2) = (2, 8) or (6, 4) as set by the code that jumps here.
sealSSETail128:
// Need to encrypt up to 128 bytes - prepare two blocks, hash 192 or 256 bytes
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store
// LoopA (hash only) runs itr1 times, always followed by LoopB; LoopB (one double
// round + one hash step) runs itr1 + itr2 = 10 times in total
sealSSETail128LoopA:
// Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealSSETail128LoopB:
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0)
shiftB0Left; shiftC0Left; shiftD0Left
shiftB1Left; shiftC1Left; shiftD1Left
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0)
shiftB0Right; shiftC0Right; shiftD0Right
shiftB1Right; shiftC1Right; shiftD1Right
DECQ itr1
JG sealSSETail128LoopA
DECQ itr2
JGE sealSSETail128LoopB
// Add the initial state back to obtain the two key-stream blocks
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1
PADDL state1Store, B0; PADDL state1Store, B1
PADDL state2Store, C0; PADDL state2Store, C1
PADDL ctr0Store, D0; PADDL ctr1Store, D1
// Encrypt 64 bytes with the first block; oup has caught up with the output position
MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3
PXOR T0, A0; PXOR T1, B0; PXOR T2, C0; PXOR T3, D0
MOVOU A0, (0*16)(oup); MOVOU B0, (1*16)(oup); MOVOU C0, (2*16)(oup); MOVOU D0, (3*16)(oup)
// Those 64 bytes are written but not hashed yet
MOVQ $64, itr1
LEAQ 64(inp), inp
SUBQ $64, inl
JMP sealSSE128SealHash
// ----------------------------------------------------------------------------
// Special optimization for the last 192 bytes of plaintext
// sealSSETail192 generates three final key-stream blocks while hashing the
// ciphertext still pending at oup. The first two blocks encrypt 128 bytes
// here; the third is moved into A1..D1 for sealSSE128Seal.
// Expects (itr1, itr2) = (2, 8) or (6, 4) as set by the code that jumps here.
sealSSETail192:
// Need to encrypt up to 192 bytes - prepare three blocks, hash 192 or 256 bytes
MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr2Store
// LoopA (hash only) runs itr1 times, always followed by LoopB; LoopB (one double
// round + one hash step) runs itr1 + itr2 = 10 times in total
sealSSETail192LoopA:
// Perform ChaCha rounds, while hashing the previously encrypted ciphertext
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealSSETail192LoopB:
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Left; shiftC0Left; shiftD0Left
shiftB1Left; shiftC1Left; shiftD1Left
shiftB2Left; shiftC2Left; shiftD2Left
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Right; shiftC0Right; shiftD0Right
shiftB1Right; shiftC1Right; shiftD1Right
shiftB2Right; shiftC2Right; shiftD2Right
DECQ itr1
JG sealSSETail192LoopA
DECQ itr2
JGE sealSSETail192LoopB
// Add the initial state back to obtain the three key-stream blocks
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2
PADDL state1Store, B0; PADDL state1Store, B1; PADDL state1Store, B2
PADDL state2Store, C0; PADDL state2Store, C1; PADDL state2Store, C2
PADDL ctr0Store, D0; PADDL ctr1Store, D1; PADDL ctr2Store, D2
// Encrypt 128 bytes with the first two blocks; oup has caught up with the output position
MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3
PXOR T0, A0; PXOR T1, B0; PXOR T2, C0; PXOR T3, D0
MOVOU A0, (0*16)(oup); MOVOU B0, (1*16)(oup); MOVOU C0, (2*16)(oup); MOVOU D0, (3*16)(oup)
MOVOU (4*16)(inp), T0; MOVOU (5*16)(inp), T1; MOVOU (6*16)(inp), T2; MOVOU (7*16)(inp), T3
PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1
MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup)
// Move the third block into A1..D1, where sealSSE128Seal expects its key stream
MOVO A2, A1
MOVO B2, B1
MOVO C2, C1
MOVO D2, D1
// The 128 bytes just written are not hashed yet
MOVQ $128, itr1
LEAQ 128(inp), inp
SUBQ $128, inl
JMP sealSSE128SealHash
// ----------------------------------------------------------------------------
// Special seal optimization for buffers smaller than 129 bytes
// sealSSE128 is the short-input path (inl <= 128). It computes three blocks
// in registers only: block 0 yields the Poly1305 key, blocks 1 and 2 hold the
// key stream in A1..D1 and A2..D2. It then hashes the additional data and
// falls through into sealSSE128SealHash with nothing pending (itr1 = 0).
sealSSE128:
// For up to 128 bytes of ciphertext and 64 bytes for the poly key, we need to process three blocks
MOVOU ·chacha20Constants<>(SB), A0; MOVOU (1*16)(keyp), B0; MOVOU (2*16)(keyp), C0; MOVOU (3*16)(keyp), D0
MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1
MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2
// Keep the initial rows in T1/T2 and block 1's counter row in T3 (registers instead of the stack stores)
MOVO B0, T1; MOVO C0, T2; MOVO D1, T3
// 10 double rounds
MOVQ $10, itr2
sealSSE128InnerCipherLoop:
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Left; shiftB1Left; shiftB2Left
shiftC0Left; shiftC1Left; shiftC2Left
shiftD0Left; shiftD1Left; shiftD2Left
chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0)
shiftB0Right; shiftB1Right; shiftB2Right
shiftC0Right; shiftC1Right; shiftC2Right
shiftD0Right; shiftD1Right; shiftD2Right
DECQ itr2
JNE sealSSE128InnerCipherLoop
// A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded
PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2
PADDL T1, B0; PADDL T1, B1; PADDL T1, B2
PADDL T2, C1; PADDL T2, C2
// T3 is block 1's counter row; adding sseIncMask once more turns it into block 2's
PADDL T3, D1; PADDL ·sseIncMask<>(SB), T3; PADDL T3, D2
PAND ·polyClampMask<>(SB), A0
MOVOU A0, rStore
MOVOU B0, sStore
// Hash
// polyHashADInternal is defined elsewhere in the file; itr2 carries ad_len into it
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
// No ciphertext has been written yet, so nothing is pending for sealSSE128SealHash
XORQ itr1, itr1
// sealSSE128SealHash hashes the ciphertext that was already written at oup,
// 16 bytes per iteration, advancing oup until fewer than 16 pending bytes
// remain, then continues with sealSSE128Seal.
sealSSE128SealHash:
// itr1 holds the number of bytes encrypted but not yet hashed
CMPQ itr1, $16
JB sealSSE128Seal
polyAdd(0(oup))
polyMul
SUBQ $16, itr1
ADDQ $16, oup
JMP sealSSE128SealHash
// sealSSE128Seal encrypts and hashes the remaining input one 16-byte chunk at
// a time. A1 always holds the next 16 bytes of key stream; after each chunk
// the queue A1,B1,C1,D1,A2,B2,C2,D2 is shifted by one register. Fewer than 16
// bytes left are handled by sealSSETail.
sealSSE128Seal:
CMPQ inl, $16
JB sealSSETail
SUBQ $16, inl
// Load the plaintext for encryption
MOVOU (inp), T0
PXOR T0, A1
MOVOU A1, (oup)
LEAQ (1*16)(inp), inp
LEAQ (1*16)(oup), oup
// Extract for hashing
// The two 64-bit halves of the ciphertext chunk are added to the accumulator with the
// carry chain ending in ADCQ $1, acc2 - the same per-chunk update as in sealSSETail
MOVQ A1, t0
PSRLDQ $8, A1
MOVQ A1, t1
ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2
polyMul
// Shift the stream "left"
MOVO B1, A1
MOVO C1, B1
MOVO D1, C1
MOVO A2, D1
MOVO B2, A2
MOVO C2, B2
MOVO D2, C2
JMP sealSSE128Seal
// sealSSETail handles the last 1..15 bytes of input (inl < 16). The bytes are
// gathered one at a time into a 128-bit value, XORed with the key stream in
// A1, written out, masked down to inl bytes and hashed as a final chunk.
sealSSETail:
TESTQ inl, inl
JE sealSSEFinalize
// We can only load the PT one byte at a time to avoid read after end of buffer
// itr2 = inl * 16, used below to index the 16-byte entries of andMask
MOVQ inl, itr2
SHLQ $4, itr2
LEAQ ·andMask<>(SB), t0
MOVQ inl, itr1
// Start at the last input byte and walk backwards
LEAQ -1(inp)(inl*1), inp
XORQ t2, t2
XORQ t3, t3
XORQ AX, AX
sealSSETailLoadLoop:
// Shift the 128-bit pair t3:t2 left by one byte, then place the next byte in the low 8 bits of t2
SHLQ $8, t2, t3
SHLQ $8, t2
MOVB (inp), AX
XORQ AX, t2
LEAQ -1(inp), inp
DECQ itr1
JNE sealSSETailLoadLoop
// t3:t2 now holds the remaining bytes with the first one in the lowest byte
MOVQ t2, 0+tmpStore
MOVQ t3, 8+tmpStore
PXOR 0+tmpStore, A1
// A full 16 bytes are stored; the bytes past inl are overwritten by the tag written at oup+inl in sealSSEFinalize
MOVOU A1, (oup)
// Apply the andMask entry at offset (inl-1)*16 before hashing
MOVOU -16(t0)(itr2*1), T0
PAND T0, A1
MOVQ A1, t0
PSRLDQ $8, A1
MOVQ A1, t1
ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2
polyMul
ADDQ inl, oup
// sealSSEFinalize hashes the length block, fully reduces the accumulator,
// adds the "s" half of the Poly1305 key and writes the 16-byte tag at oup,
// which at this point is just past the ciphertext.
sealSSEFinalize:
// Hash in the buffer lengths
// Low 64 bits: ad_len, next 64 bits: src_len
ADDQ ad_len+80(FP), acc0
ADCQ src_len+56(FP), acc1
ADCQ $1, acc2
polyMul
// Final reduce
// Subtract 2^130 - 5 (limbs -5, -1, 3); if that borrows, the accumulator was already
// below it and the saved copy in t0..t2 is restored
MOVQ acc0, t0
MOVQ acc1, t1
MOVQ acc2, t2
SUBQ $-5, acc0
SBBQ $-1, acc1
SBBQ $3, acc2
CMOVQCS t0, acc0
CMOVQCS t1, acc1
CMOVQCS t2, acc2
// Add in the "s" part of the key
// Only the low 128 bits are kept; acc2 is not stored
ADDQ 0+sStore, acc0
ADCQ 8+sStore, acc1
// Finally store the tag at the end of the message
MOVQ acc0, (0*8)(oup)
MOVQ acc1, (1*8)(oup)
RET
// ----------------------------------------------------------------------------
// ------------------------- AVX2 Code ----------------------------------------
chacha20Poly1305Seal_AVX2:
VZEROUPPER
VMOVDQU ·chacha20Constants<>(SB), AA0
BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x70; BYTE $0x10 // broadcasti128 16(r8), ymm14
BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x20 // broadcasti128 32(r8), ymm12
BYTE $0xc4; BYTE $0xc2; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x30 // broadcasti128 48(r8), ymm4
VPADDD ·avx2InitMask<>(SB), DD0, DD0
// Special optimizations, for very short buffers
CMPQ inl, $192
JBE seal192AVX2 // 33% faster
CMPQ inl, $320
JBE seal320AVX2 // 17% faster
// For the general case, prepare the Poly1305 key first - as a byproduct we have 64 bytes of cipher stream
VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3; VMOVDQA BB0, state1StoreAVX2
VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3; VMOVDQA CC0, state2StoreAVX2
VPADDD ·avx2IncMask<>(SB), DD0, DD1; VMOVDQA DD0, ctr0StoreAVX2
VPADDD ·avx2IncMask<>(SB), DD1, DD2; VMOVDQA DD1, ctr1StoreAVX2
VPADDD ·avx2IncMask<>(SB), DD2, DD3; VMOVDQA DD2, ctr2StoreAVX2
VMOVDQA DD3, ctr3StoreAVX2
MOVQ $10, itr2
sealAVX2IntroLoop:
VMOVDQA CC3, tmpStoreAVX2
chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3)
VMOVDQA tmpStoreAVX2, CC3
VMOVDQA CC1, tmpStoreAVX2
chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1)
VMOVDQA tmpStoreAVX2, CC1
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0
VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $12, DD1, DD1, DD1
VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $12, DD2, DD2, DD2
VPALIGNR $4, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $12, DD3, DD3, DD3
VMOVDQA CC3, tmpStoreAVX2
chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3)
VMOVDQA tmpStoreAVX2, CC3
VMOVDQA CC1, tmpStoreAVX2
chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1)
VMOVDQA tmpStoreAVX2, CC1
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0
VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $4, DD1, DD1, DD1
VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $4, DD2, DD2, DD2
VPALIGNR $12, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $4, DD3, DD3, DD3
DECQ itr2
JNE sealAVX2IntroLoop
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3
VPERM2I128 $0x13, CC0, DD0, CC0 // Stream bytes 96 - 127
VPERM2I128 $0x02, AA0, BB0, DD0 // The Poly1305 key
VPERM2I128 $0x13, AA0, BB0, AA0 // Stream bytes 64 - 95
// Clamp and store poly key
VPAND ·polyClampMask<>(SB), DD0, DD0
VMOVDQA DD0, rsStoreAVX2
// Hash AD
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
// Can store at least 320 bytes
VPXOR (0*32)(inp), AA0, AA0
VPXOR (1*32)(inp), CC0, CC0
VMOVDQU AA0, (0*32)(oup)
VMOVDQU CC0, (1*32)(oup)
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (2*32)(inp), AA0, AA0; VPXOR (3*32)(inp), BB0, BB0; VPXOR (4*32)(inp), CC0, CC0; VPXOR (5*32)(inp), DD0, DD0
VMOVDQU AA0, (2*32)(oup); VMOVDQU BB0, (3*32)(oup); VMOVDQU CC0, (4*32)(oup); VMOVDQU DD0, (5*32)(oup)
VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0
VPXOR (6*32)(inp), AA0, AA0; VPXOR (7*32)(inp), BB0, BB0; VPXOR (8*32)(inp), CC0, CC0; VPXOR (9*32)(inp), DD0, DD0
VMOVDQU AA0, (6*32)(oup); VMOVDQU BB0, (7*32)(oup); VMOVDQU CC0, (8*32)(oup); VMOVDQU DD0, (9*32)(oup)
MOVQ $320, itr1
SUBQ $320, inl
LEAQ 320(inp), inp
VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, CC3, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, CC3, DD3, DD0
CMPQ inl, $128
JBE sealAVX2SealHash
VPXOR (0*32)(inp), AA0, AA0; VPXOR (1*32)(inp), BB0, BB0; VPXOR (2*32)(inp), CC0, CC0; VPXOR (3*32)(inp), DD0, DD0
VMOVDQU AA0, (10*32)(oup); VMOVDQU BB0, (11*32)(oup); VMOVDQU CC0, (12*32)(oup); VMOVDQU DD0, (13*32)(oup)
SUBQ $128, inl
LEAQ 128(inp), inp
MOVQ $8, itr1
MOVQ $2, itr2
CMPQ inl, $128
JBE sealAVX2Tail128
CMPQ inl, $256
JBE sealAVX2Tail256
CMPQ inl, $384
JBE sealAVX2Tail384
CMPQ inl, $512
JBE sealAVX2Tail512
// We have 448 bytes to hash, but main loop hashes 512 bytes at a time - perform some rounds, before the main loop
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3
VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2
VMOVDQA CC3, tmpStoreAVX2
chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3)
VMOVDQA tmpStoreAVX2, CC3
VMOVDQA CC1, tmpStoreAVX2
chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1)
VMOVDQA tmpStoreAVX2, CC1
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0
VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $12, DD1, DD1, DD1
VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $12, DD2, DD2, DD2
VPALIGNR $4, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $12, DD3, DD3, DD3
VMOVDQA CC3, tmpStoreAVX2
chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3)
VMOVDQA tmpStoreAVX2, CC3
VMOVDQA CC1, tmpStoreAVX2
chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1)
VMOVDQA tmpStoreAVX2, CC1
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0
VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $4, DD1, DD1, DD1
VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $4, DD2, DD2, DD2
VPALIGNR $12, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $4, DD3, DD3, DD3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
SUBQ $16, oup // Adjust the pointer
MOVQ $9, itr1
JMP sealAVX2InternalLoopStart
sealAVX2MainLoop:
// Load state, increment counter blocks, store the incremented counters
VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3
VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3
VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2
MOVQ $10, itr1
sealAVX2InternalLoop:
polyAdd(0*8(oup))
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
polyMulStage1_AVX2
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
polyMulStage2_AVX2
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
polyMulStage3_AVX2
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulReduceStage
sealAVX2InternalLoopStart:
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
polyAdd(2*8(oup))
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
polyMulStage1_AVX2
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulStage2_AVX2
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
polyMulStage3_AVX2
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
polyMulReduceStage
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
polyAdd(4*8(oup))
LEAQ (6*8)(oup), oup
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulStage1_AVX2
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
polyMulStage2_AVX2
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
polyMulStage3_AVX2
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyMulReduceStage
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3
DECQ itr1
JNE sealAVX2InternalLoop
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3
VMOVDQA CC3, tmpStoreAVX2
// We only hashed 480 of the 512 bytes available - hash the remaining 32 here
polyAdd(0*8(oup))
polyMulAVX2
LEAQ (4*8)(oup), oup
VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0
VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0
VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup)
VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0
VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup)
// and here
polyAdd(-2*8(oup))
polyMulAVX2
VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0
VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0
VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup)
VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0
VPXOR (12*32)(inp), AA0, AA0; VPXOR (13*32)(inp), BB0, BB0; VPXOR (14*32)(inp), CC0, CC0; VPXOR (15*32)(inp), DD0, DD0
VMOVDQU AA0, (12*32)(oup); VMOVDQU BB0, (13*32)(oup); VMOVDQU CC0, (14*32)(oup); VMOVDQU DD0, (15*32)(oup)
LEAQ (32*16)(inp), inp
SUBQ $(32*16), inl
CMPQ inl, $512
JG sealAVX2MainLoop
// Tail can only hash 480 bytes
polyAdd(0*8(oup))
polyMulAVX2
polyAdd(2*8(oup))
polyMulAVX2
LEAQ 32(oup), oup
MOVQ $10, itr1
MOVQ $0, itr2
CMPQ inl, $128
JBE sealAVX2Tail128
CMPQ inl, $256
JBE sealAVX2Tail256
CMPQ inl, $384
JBE sealAVX2Tail384
JMP sealAVX2Tail512
// ----------------------------------------------------------------------------
// Special optimization for buffers smaller than 193 bytes
seal192AVX2:
// For up to 192 bytes of ciphertext and 64 bytes for the poly key, we process four blocks
VMOVDQA AA0, AA1
VMOVDQA BB0, BB1
VMOVDQA CC0, CC1
VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA AA0, AA2
VMOVDQA BB0, BB2
VMOVDQA CC0, CC2
VMOVDQA DD0, DD2
VMOVDQA DD1, TT3
MOVQ $10, itr2
sealAVX2192InnerCipherLoop:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1
DECQ itr2
JNE sealAVX2192InnerCipherLoop
VPADDD AA2, AA0, AA0; VPADDD AA2, AA1, AA1
VPADDD BB2, BB0, BB0; VPADDD BB2, BB1, BB1
VPADDD CC2, CC0, CC0; VPADDD CC2, CC1, CC1
VPADDD DD2, DD0, DD0; VPADDD TT3, DD1, DD1
VPERM2I128 $0x02, AA0, BB0, TT0
// Clamp and store poly key
VPAND ·polyClampMask<>(SB), TT0, TT0
VMOVDQA TT0, rsStoreAVX2
// Stream for up to 192 bytes
VPERM2I128 $0x13, AA0, BB0, AA0
VPERM2I128 $0x13, CC0, DD0, BB0
VPERM2I128 $0x02, AA1, BB1, CC0
VPERM2I128 $0x02, CC1, DD1, DD0
VPERM2I128 $0x13, AA1, BB1, AA1
VPERM2I128 $0x13, CC1, DD1, BB1
sealAVX2ShortSeal:
// Hash aad
MOVQ ad_len+80(FP), itr2
CALL polyHashADInternal<>(SB)
XORQ itr1, itr1
sealAVX2SealHash:
// itr1 holds the number of bytes encrypted but not yet hashed
CMPQ itr1, $16
JB sealAVX2ShortSealLoop
polyAdd(0(oup))
polyMul
SUBQ $16, itr1
ADDQ $16, oup
JMP sealAVX2SealHash
sealAVX2ShortSealLoop:
CMPQ inl, $32
JB sealAVX2ShortTail32
SUBQ $32, inl
// Load for encryption
VPXOR (inp), AA0, AA0
VMOVDQU AA0, (oup)
LEAQ (1*32)(inp), inp
// Now can hash
polyAdd(0*8(oup))
polyMulAVX2
polyAdd(2*8(oup))
polyMulAVX2
LEAQ (1*32)(oup), oup
// Shift stream left
VMOVDQA BB0, AA0
VMOVDQA CC0, BB0
VMOVDQA DD0, CC0
VMOVDQA AA1, DD0
VMOVDQA BB1, AA1
VMOVDQA CC1, BB1
VMOVDQA DD1, CC1
VMOVDQA AA2, DD1
VMOVDQA BB2, AA2
JMP sealAVX2ShortSealLoop
sealAVX2ShortTail32:
CMPQ inl, $16
VMOVDQA A0, A1
JB sealAVX2ShortDone
SUBQ $16, inl
// Load for encryption
VPXOR (inp), A0, T0
VMOVDQU T0, (oup)
LEAQ (1*16)(inp), inp
// Hash
polyAdd(0*8(oup))
polyMulAVX2
LEAQ (1*16)(oup), oup
VPERM2I128 $0x11, AA0, AA0, AA0
VMOVDQA A0, A1
sealAVX2ShortDone:
VZEROUPPER
JMP sealSSETail
// ----------------------------------------------------------------------------
// Special optimization for buffers smaller than 321 bytes
seal320AVX2:
// For up to 320 bytes of ciphertext and 64 bytes for the poly key, we process six blocks
VMOVDQA AA0, AA1; VMOVDQA BB0, BB1; VMOVDQA CC0, CC1; VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA AA0, AA2; VMOVDQA BB0, BB2; VMOVDQA CC0, CC2; VPADDD ·avx2IncMask<>(SB), DD1, DD2
VMOVDQA BB0, TT1; VMOVDQA CC0, TT2; VMOVDQA DD0, TT3
MOVQ $10, itr2
sealAVX2320InnerCipherLoop:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2
DECQ itr2
JNE sealAVX2320InnerCipherLoop
VMOVDQA ·chacha20Constants<>(SB), TT0
VPADDD TT0, AA0, AA0; VPADDD TT0, AA1, AA1; VPADDD TT0, AA2, AA2
VPADDD TT1, BB0, BB0; VPADDD TT1, BB1, BB1; VPADDD TT1, BB2, BB2
VPADDD TT2, CC0, CC0; VPADDD TT2, CC1, CC1; VPADDD TT2, CC2, CC2
VMOVDQA ·avx2IncMask<>(SB), TT0
VPADDD TT3, DD0, DD0; VPADDD TT0, TT3, TT3
VPADDD TT3, DD1, DD1; VPADDD TT0, TT3, TT3
VPADDD TT3, DD2, DD2
// Clamp and store poly key
VPERM2I128 $0x02, AA0, BB0, TT0
VPAND ·polyClampMask<>(SB), TT0, TT0
VMOVDQA TT0, rsStoreAVX2
// Stream for up to 320 bytes
VPERM2I128 $0x13, AA0, BB0, AA0
VPERM2I128 $0x13, CC0, DD0, BB0
VPERM2I128 $0x02, AA1, BB1, CC0
VPERM2I128 $0x02, CC1, DD1, DD0
VPERM2I128 $0x13, AA1, BB1, AA1
VPERM2I128 $0x13, CC1, DD1, BB1
VPERM2I128 $0x02, AA2, BB2, CC1
VPERM2I128 $0x02, CC2, DD2, DD1
VPERM2I128 $0x13, AA2, BB2, AA2
VPERM2I128 $0x13, CC2, DD2, BB2
JMP sealAVX2ShortSeal
// ----------------------------------------------------------------------------
// Special optimization for the last 128 bytes of ciphertext
sealAVX2Tail128:
// Need to decrypt up to 128 bytes - prepare two blocks
// If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed
// If we got here before the main loop - there are 448 encrypted bytes waiting to be hashed
VMOVDQA ·chacha20Constants<>(SB), AA0
VMOVDQA state1StoreAVX2, BB0
VMOVDQA state2StoreAVX2, CC0
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0
VMOVDQA DD0, DD1
sealAVX2Tail128LoopA:
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealAVX2Tail128LoopB:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0)
polyAdd(0(oup))
polyMul
VPALIGNR $4, BB0, BB0, BB0
VPALIGNR $8, CC0, CC0, CC0
VPALIGNR $12, DD0, DD0, DD0
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0)
polyAdd(16(oup))
polyMul
LEAQ 32(oup), oup
VPALIGNR $12, BB0, BB0, BB0
VPALIGNR $8, CC0, CC0, CC0
VPALIGNR $4, DD0, DD0, DD0
DECQ itr1
JG sealAVX2Tail128LoopA
DECQ itr2
JGE sealAVX2Tail128LoopB
VPADDD ·chacha20Constants<>(SB), AA0, AA1
VPADDD state1StoreAVX2, BB0, BB1
VPADDD state2StoreAVX2, CC0, CC1
VPADDD DD1, DD0, DD1
VPERM2I128 $0x02, AA1, BB1, AA0
VPERM2I128 $0x02, CC1, DD1, BB0
VPERM2I128 $0x13, AA1, BB1, CC0
VPERM2I128 $0x13, CC1, DD1, DD0
JMP sealAVX2ShortSealLoop
// ----------------------------------------------------------------------------
// Special optimization for the last 256 bytes of ciphertext
sealAVX2Tail256:
// Need to decrypt up to 256 bytes - prepare two blocks
// If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed
// If we got here before the main loop - there are 448 encrypted bytes waiting to be hashed
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA ·chacha20Constants<>(SB), AA1
VMOVDQA state1StoreAVX2, BB0; VMOVDQA state1StoreAVX2, BB1
VMOVDQA state2StoreAVX2, CC0; VMOVDQA state2StoreAVX2, CC1
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD1
VMOVDQA DD0, TT1
VMOVDQA DD1, TT2
sealAVX2Tail256LoopA:
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealAVX2Tail256LoopB:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
polyAdd(0(oup))
polyMul
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0)
polyAdd(16(oup))
polyMul
LEAQ 32(oup), oup
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1
DECQ itr1
JG sealAVX2Tail256LoopA
DECQ itr2
JGE sealAVX2Tail256LoopB
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1
VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1
VPERM2I128 $0x02, AA0, BB0, TT0
VPERM2I128 $0x02, CC0, DD0, TT1
VPERM2I128 $0x13, AA0, BB0, TT2
VPERM2I128 $0x13, CC0, DD0, TT3
VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3
VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup)
MOVQ $128, itr1
LEAQ 128(inp), inp
SUBQ $128, inl
VPERM2I128 $0x02, AA1, BB1, AA0
VPERM2I128 $0x02, CC1, DD1, BB0
VPERM2I128 $0x13, AA1, BB1, CC0
VPERM2I128 $0x13, CC1, DD1, DD0
JMP sealAVX2SealHash
// ----------------------------------------------------------------------------
// Special optimization for the last 384 bytes of ciphertext
sealAVX2Tail384:
// Need to decrypt up to 384 bytes - prepare two blocks
// If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed
// If we got here before the main loop - there are 448 encrypted bytes waiting to be hashed
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2
VMOVDQA DD0, TT1; VMOVDQA DD1, TT2; VMOVDQA DD2, TT3
sealAVX2Tail384LoopA:
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealAVX2Tail384LoopB:
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
polyAdd(0(oup))
polyMul
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2
chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0)
polyAdd(16(oup))
polyMul
LEAQ 32(oup), oup
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2
DECQ itr1
JG sealAVX2Tail384LoopA
DECQ itr2
JGE sealAVX2Tail384LoopB
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2
VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1; VPADDD TT3, DD2, DD2
VPERM2I128 $0x02, AA0, BB0, TT0
VPERM2I128 $0x02, CC0, DD0, TT1
VPERM2I128 $0x13, AA0, BB0, TT2
VPERM2I128 $0x13, CC0, DD0, TT3
VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3
VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup)
VPERM2I128 $0x02, AA1, BB1, TT0
VPERM2I128 $0x02, CC1, DD1, TT1
VPERM2I128 $0x13, AA1, BB1, TT2
VPERM2I128 $0x13, CC1, DD1, TT3
VPXOR (4*32)(inp), TT0, TT0; VPXOR (5*32)(inp), TT1, TT1; VPXOR (6*32)(inp), TT2, TT2; VPXOR (7*32)(inp), TT3, TT3
VMOVDQU TT0, (4*32)(oup); VMOVDQU TT1, (5*32)(oup); VMOVDQU TT2, (6*32)(oup); VMOVDQU TT3, (7*32)(oup)
MOVQ $256, itr1
LEAQ 256(inp), inp
SUBQ $256, inl
VPERM2I128 $0x02, AA2, BB2, AA0
VPERM2I128 $0x02, CC2, DD2, BB0
VPERM2I128 $0x13, AA2, BB2, CC0
VPERM2I128 $0x13, CC2, DD2, DD0
JMP sealAVX2SealHash
// ----------------------------------------------------------------------------
// Special optimization for the last 512 bytes of ciphertext
sealAVX2Tail512:
// Need to decrypt up to 512 bytes - prepare two blocks
// If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed
// If we got here before the main loop - there are 448 encrypted bytes waiting to be hashed
VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3
VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3
VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3
VMOVDQA ctr3StoreAVX2, DD0
VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3
VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2
sealAVX2Tail512LoopA:
polyAdd(0(oup))
polyMul
LEAQ 16(oup), oup
sealAVX2Tail512LoopB:
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
polyAdd(0*8(oup))
polyMulAVX2
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
polyAdd(2*8(oup))
polyMulAVX2
LEAQ (4*8)(oup), oup
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3
VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3
VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3
VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3
VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3
VMOVDQA CC3, tmpStoreAVX2
VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0
VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1
VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2
VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3
VMOVDQA tmpStoreAVX2, CC3
VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3
VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3
VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3
DECQ itr1
JG sealAVX2Tail512LoopA
DECQ itr2
JGE sealAVX2Tail512LoopB
VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3
VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3
VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3
VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3
VMOVDQA CC3, tmpStoreAVX2
VPERM2I128 $0x02, AA0, BB0, CC3
VPXOR (0*32)(inp), CC3, CC3
VMOVDQU CC3, (0*32)(oup)
VPERM2I128 $0x02, CC0, DD0, CC3
VPXOR (1*32)(inp), CC3, CC3
VMOVDQU CC3, (1*32)(oup)
VPERM2I128 $0x13, AA0, BB0, CC3
VPXOR (2*32)(inp), CC3, CC3
VMOVDQU CC3, (2*32)(oup)
VPERM2I128 $0x13, CC0, DD0, CC3
VPXOR (3*32)(inp), CC3, CC3
VMOVDQU CC3, (3*32)(oup)
VPERM2I128 $0x02, AA1, BB1, AA0
VPERM2I128 $0x02, CC1, DD1, BB0
VPERM2I128 $0x13, AA1, BB1, CC0
VPERM2I128 $0x13, CC1, DD1, DD0
VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0
VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup)
VPERM2I128 $0x02, AA2, BB2, AA0
VPERM2I128 $0x02, CC2, DD2, BB0
VPERM2I128 $0x13, AA2, BB2, CC0
VPERM2I128 $0x13, CC2, DD2, DD0
VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0
VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup)
MOVQ $384, itr1
LEAQ 384(inp), inp
SUBQ $384, inl
VPERM2I128 $0x02, AA3, BB3, AA0
VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0
VPERM2I128 $0x13, AA3, BB3, CC0
VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0
JMP sealAVX2SealHash
================================================
FILE: gossh/crypto/chacha20poly1305/chacha20poly1305_generic.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chacha20poly1305
import (
"encoding/binary"
"gossh/crypto/chacha20"
"gossh/crypto/internal/alias"
"gossh/crypto/internal/poly1305"
)
// writeWithPadding writes b to p and then, when len(b) is not a multiple of
// 16, follows it with enough zero bytes to reach the next 16-byte boundary.
func writeWithPadding(p *poly1305.MAC, b []byte) {
	p.Write(b)

	tail := len(b) % 16
	if tail == 0 {
		// Already aligned: nothing to pad.
		return
	}

	var zeros [16]byte
	p.Write(zeros[:16-tail])
}
// writeUint64 writes n to p as an 8-byte little-endian unsigned integer.
func writeUint64(p *poly1305.MAC, n int) {
	var encoded [8]byte
	binary.LittleEndian.PutUint64(encoded[:], uint64(n))
	p.Write(encoded[:])
}
// sealGeneric is the pure-Go sealing path. It appends the encryption of
// plaintext followed by a poly1305.TagSize-byte tag to dst and returns the
// extended slice.
//
// It panics if the output region and plaintext overlap inexactly.
func (c *chacha20poly1305) sealGeneric(dst, nonce, plaintext, additionalData []byte) []byte {
	ret, out := sliceForAppend(dst, len(plaintext)+poly1305.TagSize)
	if alias.InexactOverlap(out, plaintext) {
		panic("chacha20poly1305: invalid buffer overlap")
	}
	ciphertext := out[:len(plaintext)]
	tag := out[len(plaintext):]

	// The one-time Poly1305 key is the first 32 bytes of keystream, obtained
	// by encrypting an all-zero buffer in place.
	// NOTE(review): the constructor error is discarded; presumably the key
	// and nonce lengths were validated by the caller — confirm.
	var polyKey [32]byte
	stream, _ := chacha20.NewUnauthenticatedCipher(c.key[:], nonce)
	stream.XORKeyStream(polyKey[:], polyKey[:])
	stream.SetCounter(1) // set the counter to 1, skipping 32 bytes
	stream.XORKeyStream(ciphertext, plaintext)

	// MAC input: padded AAD, padded ciphertext, then both lengths.
	mac := poly1305.New(&polyKey)
	writeWithPadding(mac, additionalData)
	writeWithPadding(mac, ciphertext)
	writeUint64(mac, len(additionalData))
	writeUint64(mac, len(plaintext))
	mac.Sum(tag[:0])

	return ret
}
// openGeneric is the pure-Go opening path. The final 16 bytes of ciphertext
// are treated as the tag; the tag is verified before any decryption happens.
// On success the plaintext is appended to dst and the extended slice is
// returned. On a tag mismatch the output region is zeroed and errOpen is
// returned.
//
// It panics if the output region and the ciphertext overlap inexactly.
func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	split := len(ciphertext) - 16
	tag := ciphertext[split:]
	ciphertext = ciphertext[:split]

	// Derive the one-time Poly1305 key from the first 32 keystream bytes.
	// NOTE(review): the constructor error is discarded; presumably the key
	// and nonce lengths were validated by the caller — confirm.
	var polyKey [32]byte
	stream, _ := chacha20.NewUnauthenticatedCipher(c.key[:], nonce)
	stream.XORKeyStream(polyKey[:], polyKey[:])
	stream.SetCounter(1) // set the counter to 1, skipping 32 bytes

	// MAC input: padded AAD, padded ciphertext, then both lengths.
	mac := poly1305.New(&polyKey)
	writeWithPadding(mac, additionalData)
	writeWithPadding(mac, ciphertext)
	writeUint64(mac, len(additionalData))
	writeUint64(mac, len(ciphertext))

	ret, out := sliceForAppend(dst, len(ciphertext))
	if alias.InexactOverlap(out, ciphertext) {
		panic("chacha20poly1305: invalid buffer overlap")
	}

	if mac.Verify(tag) {
		stream.XORKeyStream(out, ciphertext)
		return ret, nil
	}

	// Authentication failed: wipe the output region before reporting.
	for i := range out {
		out[i] = 0
	}
	return nil, errOpen
}
================================================
FILE: gossh/crypto/chacha20poly1305/chacha20poly1305_noasm.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
package chacha20poly1305
// seal is the seal implementation compiled in under this file's build
// constraint (!amd64 || !gc || purego). It forwards all arguments to
// sealGeneric and returns its result unchanged.
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
return c.sealGeneric(dst, nonce, plaintext, additionalData)
}
// open is the open implementation compiled in under this file's build
// constraint (!amd64 || !gc || purego). It forwards all arguments to
// openGeneric and returns its results unchanged.
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
return c.openGeneric(dst, nonce, ciphertext, additionalData)
}
================================================
FILE: gossh/crypto/chacha20poly1305/xchacha20poly1305.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chacha20poly1305
import (
"crypto/cipher"
"errors"
"gossh/crypto/chacha20"
)
// xchacha20poly1305 is the value returned by NewX; it carries only the key.
type xchacha20poly1305 struct {
// key is a KeySize-byte array; NewX fills it with a copy of the caller's key.
key [KeySize]byte
}
// NewX returns a XChaCha20-Poly1305 AEAD that uses the given 256-bit key.
//
// XChaCha20-Poly1305 is a ChaCha20-Poly1305 variant with a longer nonce, which
// makes randomly generated nonces safe from collisions in practice. Prefer it
// whenever nonce uniqueness cannot be trivially guaranteed, or whenever nonces
// are generated at random.
//
// An error is returned if len(key) is not KeySize. The key is copied, so the
// caller may reuse or clear its slice afterwards.
func NewX(key []byte) (cipher.AEAD, error) {
	if len(key) != KeySize {
		return nil, errors.New("chacha20poly1305: bad key length")
	}
	aead := &xchacha20poly1305{}
	copy(aead.key[:], key)
	return aead, nil
}
// NonceSize returns NonceSizeX, the nonce length that Seal and Open require
// (both panic on any other length).
func (*xchacha20poly1305) NonceSize() int {
return NonceSizeX
}
// Overhead returns the package-level Overhead value.
func (*xchacha20poly1305) Overhead() int {
return Overhead
}
// Seal encrypts and authenticates plaintext, authenticates additionalData,
// and appends the result to dst, returning the extended slice.
//
// It panics if len(nonce) != NonceSizeX or if plaintext exceeds (1<<38)-64
// bytes.
func (x *xchacha20poly1305) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
	if len(nonce) != NonceSizeX {
		panic("chacha20poly1305: bad nonce length passed to Seal")
	}

	// XChaCha20-Poly1305 technically supports a 64-bit counter, so there is no
	// size limit. However, since we reuse the ChaCha20-Poly1305 implementation,
	// the second half of the counter is not available. This is unlikely to be
	// an issue because the cipher.AEAD API requires the entire message to be in
	// memory, and the counter overflows at 256 GB.
	if uint64(len(plaintext)) > (1<<38)-64 {
		panic("chacha20poly1305: plaintext too large")
	}

	// Derive a per-nonce subkey from the first 16 nonce bytes and install it
	// in a fresh inner ChaCha20-Poly1305 instance.
	subKey, _ := chacha20.HChaCha20(x.key[:], nonce[:16])
	inner := &chacha20poly1305{}
	copy(inner.key[:], subKey)

	// The inner nonce is four zero bytes (unused counter space) followed by
	// nonce bytes 16 through 23.
	innerNonce := make([]byte, NonceSize)
	copy(innerNonce[4:12], nonce[16:24])

	return inner.seal(dst, innerNonce, plaintext, additionalData)
}
// Open authenticates ciphertext and additionalData and, on success, appends
// the decrypted plaintext to dst, returning the extended slice.
//
// It panics if len(nonce) != NonceSizeX or if ciphertext exceeds (1<<38)-48
// bytes. A ciphertext shorter than 16 bytes yields errOpen.
func (x *xchacha20poly1305) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
	switch {
	case len(nonce) != NonceSizeX:
		panic("chacha20poly1305: bad nonce length passed to Open")
	case len(ciphertext) < 16:
		return nil, errOpen
	case uint64(len(ciphertext)) > (1<<38)-48:
		panic("chacha20poly1305: ciphertext too large")
	}

	// Derive a per-nonce subkey from the first 16 nonce bytes and install it
	// in a fresh inner ChaCha20-Poly1305 instance.
	subKey, _ := chacha20.HChaCha20(x.key[:], nonce[:16])
	inner := &chacha20poly1305{}
	copy(inner.key[:], subKey)

	// The inner nonce is four zero bytes (unused counter space) followed by
	// nonce bytes 16 through 23.
	innerNonce := make([]byte, NonceSize)
	copy(innerNonce[4:12], nonce[16:24])

	return inner.open(dst, innerNonce, ciphertext, additionalData)
}
================================================
FILE: gossh/crypto/curve25519/curve25519.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package curve25519 provides an implementation of the X25519 function, which
// performs scalar multiplication on the elliptic curve known as Curve25519.
// See RFC 7748.
//
// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
// in the crypto/ecdh package.
package curve25519 // import "golang.org/x/crypto/curve25519"
// ScalarMult sets dst to the product scalar * point.
//
// It forwards its arguments unchanged to the unexported scalarMult.
//
// Deprecated: when provided a low-order point, ScalarMult will set dst to all
// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
// will return an error.
func ScalarMult(dst, scalar, point *[32]byte) {
scalarMult(dst, scalar, point)
}
// ScalarBaseMult sets dst to the product scalar * base where base is the
// standard generator.
//
// It forwards its arguments unchanged to the unexported scalarBaseMult.
//
// It is recommended to use the X25519 function with Basepoint instead, as
// copying into fixed size arrays can lead to unexpected bugs.
func ScalarBaseMult(dst, scalar *[32]byte) {
scalarBaseMult(dst, scalar)
}
// Input sizes for X25519, which documents its scalar and point arguments as
// slices of 32 bytes.
const (
// ScalarSize is the size of the scalar input to X25519.
ScalarSize = 32
// PointSize is the size of the point input to X25519.
PointSize = 32
)
// Basepoint is the canonical Curve25519 generator.
var Basepoint []byte

// basePoint is the fixed-size array backing Basepoint: a 9 in the first byte
// and zeroes in the remaining 31.
var basePoint = [32]byte{0: 9}

// init points the exported Basepoint slice at basePoint, so both names share
// the same storage.
func init() {
	Basepoint = basePoint[:]
}
// X25519 computes the scalar multiplication scalar * point as specified in
// RFC 7748, Section 5. The scalar, the point, and the returned value are all
// 32-byte slices.
//
// scalar may be generated at random, for instance with crypto/rand. point
// should be Basepoint or the result of an earlier X25519 call.
//
// When point is Basepoint itself (the very same slice, not merely one with
// equal contents) a precomputed implementation might be used for speed.
func X25519(scalar, point []byte) ([]byte, error) {
	// The real work lives in x25519 so that this wrapper stays small enough
	// to inline; the result buffer can then be allocated in the caller and
	// may not need to escape to the heap.
	var result [32]byte
	return x25519(&result, scalar, point)
}
================================================
FILE: gossh/crypto/curve25519/curve25519_compat.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.20
package curve25519
import (
"crypto/subtle"
"errors"
"strconv"
"gossh/crypto/curve25519/internal/field"
)
// scalarMult writes the 32-byte encoding of scalar * point to dst, using a
// bit-by-bit ladder over field elements. The caller's scalar is not modified;
// a masked local copy is used instead.
func scalarMult(dst, scalar, point *[32]byte) {
// Work on a copy of the scalar with the low three bits of the first byte
// cleared, the top bit of the last byte cleared, and the second-highest bit
// of the last byte set.
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
// Ladder state: (x2, z2) and (x3, z3) are two projective x-coordinates; x1
// is the decoded input point. z2 is left at its zero value.
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
// Process scalar bits from bit 254 down to bit 0.
for pos := 254; pos >= 0; pos-- {
// b is the current scalar bit.
b := e[pos/8] >> uint(pos&7)
b &= 1
// swap becomes 1 exactly when this bit differs from the previous one; the
// two Swap calls exchange the working pairs under that condition.
// NOTE(review): presumably field.Element.Swap is constant-time — confirm.
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
// One ladder step; the order of these operations matters because outputs
// overwrite inputs that later lines still read.
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
// Apply the swap left pending by the last processed bit.
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
// Convert from projective to affine: result = x2 * z2^-1.
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}
// scalarBaseMult sets dst to scalar * basePoint. It first calls
// checkBasepoint, which panics if the exported Basepoint slice no longer
// holds the expected generator bytes.
func scalarBaseMult(dst, scalar *[32]byte) {
checkBasepoint()
scalarMult(dst, scalar, &basePoint)
}
// x25519 computes scalar * point into dst and returns dst as a slice.
//
// Both scalar and point must be exactly 32 bytes long, otherwise an error is
// returned. If point shares its first element with Basepoint, the base-point
// path is taken and no low-order check is made. For any other point, an
// all-zero result is reported as an error.
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
	if n := len(scalar); n != 32 {
		return nil, errors.New("bad scalar length: " + strconv.Itoa(n) + ", expected 32")
	}
	if n := len(point); n != 32 {
		return nil, errors.New("bad point length: " + strconv.Itoa(n) + ", expected 32")
	}

	var k [32]byte
	copy(k[:], scalar)

	// Identity (not content) comparison: only the package's own Basepoint
	// slice selects the base-point path.
	if &point[0] == &Basepoint[0] {
		scalarBaseMult(dst, &k)
		return dst[:], nil
	}

	var u, allZero [32]byte
	copy(u[:], point)
	scalarMult(dst, &k, &u)
	if subtle.ConstantTimeCompare(dst[:], allZero[:]) == 1 {
		return nil, errors.New("bad input point: low order point")
	}
	return dst[:], nil
}
// checkBasepoint panics unless the exported Basepoint slice still holds the
// expected 32 bytes: 0x09 followed by 31 zero bytes.
func checkBasepoint() {
	expected := [32]byte{0x09}
	if subtle.ConstantTimeCompare(Basepoint, expected[:]) == 1 {
		return
	}
	panic("curve25519: global Basepoint value was modified")
}
================================================
FILE: gossh/crypto/curve25519/curve25519_go120.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.20
package curve25519
import "crypto/ecdh"
// x25519 computes scalar * point with crypto/ecdh, copies the shared secret
// into dst, and returns dst as a slice.
//
// The point is validated before the scalar, so when both are malformed the
// point's error is the one returned. Any error from the ECDH step itself is
// returned unchanged.
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
	x := ecdh.X25519()

	peer, err := x.NewPublicKey(point)
	if err != nil {
		return nil, err
	}

	key, err := x.NewPrivateKey(scalar)
	if err != nil {
		return nil, err
	}

	shared, err := key.ECDH(peer)
	if err != nil {
		return nil, err
	}

	copy(dst[:], shared)
	return dst[:], nil
}
// scalarMult sets dst to scalar * point using x25519. If x25519 reports an
// error, dst is set to all zeroes instead.
func scalarMult(dst, scalar, point *[32]byte) {
	_, err := x25519(dst, scalar[:], point[:])
	if err == nil {
		return
	}
	// The only error condition for x25519 when the inputs are 32 bytes long
	// is if the output would have been the all-zero value.
	*dst = [32]byte{}
}
// scalarBaseMult sets dst to the X25519 public key that crypto/ecdh derives
// from scalar. It panics if crypto/ecdh rejects the scalar.
func scalarBaseMult(dst, scalar *[32]byte) {
	priv, err := ecdh.X25519().NewPrivateKey(scalar[:])
	if err != nil {
		panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
	}
	pub := priv.PublicKey().Bytes()
	copy(dst[:], pub)
}
================================================
FILE: gossh/crypto/curve25519/internal/field/fe.go
================================================
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
	// The five limbs encode the integer
	//
	//	l0 + l1*2^51 + l2*2^102 + l3*2^153 + l4*2^204
	//
	// Between operations, every limb is expected to stay below 2^52.
	l0, l1, l2, l3, l4 uint64
}
// maskLow51Bits has exactly the low 51 bits set. ANDing a limb with it keeps
// the 51-bit limb value and drops any carry bits accumulated above it.
const maskLow51Bits uint64 = (1 << 51) - 1
// feZero is the field element 0 (all five limbs are zero).
var feZero = &Element{0, 0, 0, 0, 0}
// Zero resets v to the field element 0 and returns v.
func (v *Element) Zero() *Element {
	*v = Element{}
	return v
}
// feOne is the field element 1 (l0 is 1, every other limb is zero).
var feOne = &Element{1, 0, 0, 0, 0}
// One resets v to the field element 1 and returns v.
func (v *Element) One() *Element {
	*v = Element{l0: 1}
	return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
//
// The reduction happens in place. On return every limb has been masked to
// 51 bits, so v holds the canonical representative in [0, 2^255 - 19).
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
//
// The carry is computed by rippling "+19" through the limbs without storing
// the intermediate sums; only the final carry out of l4 is kept.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
//
// Each limb first receives the carry from the limb below it and is then
// masked back to 51 bits, so the statement order below must be preserved.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add stores the sum a + b in v and returns v.
func (v *Element) Add(a, b *Element) *Element {
	// All five limb sums are formed before v is overwritten, so v may alias
	// a or b.
	*v = Element{
		a.l0 + b.l0,
		a.l1 + b.l1,
		a.l2 + b.l2,
		a.l3 + b.l3,
		a.l4 + b.l4,
	}
	// The generic carry propagation is called here on purpose: it was found
	// to be faster than the assembly one, probably because this body is
	// simple enough for the compiler to inline the carry propagation and
	// optimize across it.
	return v.carryPropagateGeneric()
}
// Subtract stores the difference a - b in v and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
	// Limbs of 2*p. Adding 2*p to a before subtracting b (which can be up to
	// 2^255 + 2^13 * 19) guarantees that no limb subtraction underflows.
	const (
		twoP0    = 0xFFFFFFFFFFFDA // 2 * (2^51 - 19), the lowest limb of 2*p
		twoPRest = 0xFFFFFFFFFFFFE // 2 * (2^51 - 1), each remaining limb of 2*p
	)
	*v = Element{
		(a.l0 + twoP0) - b.l0,
		(a.l1 + twoPRest) - b.l1,
		(a.l2 + twoPRest) - b.l2,
		(a.l3 + twoPRest) - b.l3,
		(a.l4 + twoPRest) - b.l4,
	}
	return v.carryPropagate()
}
// Negate stores -a in v, computed as 0 - a, and returns v.
func (v *Element) Negate(a *Element) *Element {
	v.Subtract(feZero, a)
	return v
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
//
// The work is a fixed sequence of Square and Multiply calls; every loop
// below has a constant trip count.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p − 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
//
// The trailing comments give the exponent of z held by the destination
// after the statement (for a loop: after the loop has finished). A variable
// named z2_A_B holds z^(2^A - 2^B); z2, z9 and z11 hold z^2, z^9 and z^11.
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
// (2^255 - 32) + 11 = 2^255 - 21 = p - 2.
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set copies the five limbs of a into v and returns v.
func (v *Element) Set(a *Element) *Element {
	v.l0, v.l1, v.l2, v.l3, v.l4 = a.l0, a.l1, a.l2, a.l3, a.l4
	return v
}
// SetBytes decodes x, a 32-byte little-endian encoding, into v and returns v.
// It panics if x is not exactly 32 bytes long.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032.
func (v *Element) SetBytes(x []byte) *Element {
	if len(x) != 32 {
		panic("edwards25519: invalid field element input size")
	}

	// Each limb is cut out of an 8-byte little-endian window: the window is
	// shifted right so the limb's first bit lands on bit 0, then masked to
	// 51 bits.
	le := binary.LittleEndian

	// Limb 0 starts at bit 0: window bytes 0:8, no shift.
	v.l0 = le.Uint64(x[0:8]) & maskLow51Bits
	// Limb 1 starts at bit 51: window bytes 6:14 (bits 48:112), shift 3.
	v.l1 = (le.Uint64(x[6:14]) >> 3) & maskLow51Bits
	// Limb 2 starts at bit 102: window bytes 12:20 (bits 96:160), shift 6.
	v.l2 = (le.Uint64(x[12:20]) >> 6) & maskLow51Bits
	// Limb 3 starts at bit 153: window bytes 19:27 (bits 152:216), shift 1.
	v.l3 = (le.Uint64(x[19:27]) >> 1) & maskLow51Bits
	// Limb 4 starts at bit 204: window bytes 24:32 (bits 192:256), shift 12.
	// Bytes 25:33 with shift 4 would read past the end of x. The mask also
	// discards bit 255.
	v.l4 = (le.Uint64(x[24:32]) >> 12) & maskLow51Bits

	return v
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
	// The buffer is declared in this thin wrapper so that, once the wrapper
	// is inlined, the allocation happens in the caller instead of on the
	// heap.
	var encoded [32]byte
	return v.bytes(&encoded)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise. The two elements are
// compared through their canonical encodings with a constant-time comparison.
func (v *Element) Equal(u *Element) int {
	return subtle.ConstantTimeCompare(u.Bytes(), v.Bytes())
}
// mask64Bits returns 0xffffffffffffffff (all 64 bits set) if cond is 1, and
// 0 if cond is 0.
func mask64Bits(cond int) uint64 {
	// Two's-complement negation: -1 is all ones and -0 is zero.
	return -uint64(cond)
}
// Select sets v to a if cond == 1, and to b if cond == 0, and returns v. The
// choice is made with bit masks rather than a branch on cond.
func (v *Element) Select(a, b *Element, cond int) *Element {
	keepA := mask64Bits(cond)
	keepB := ^keepA
	*v = Element{
		(keepA & a.l0) | (keepB & b.l0),
		(keepA & a.l1) | (keepB & b.l1),
		(keepA & a.l2) | (keepB & b.l2),
		(keepA & a.l3) | (keepB & b.l3),
		(keepA & a.l4) | (keepB & b.l4),
	}
	return v
}
// Swap exchanges v and u if cond == 1 and leaves both unchanged if cond == 0.
// It has no return value. The exchange is done with a masked XOR per limb
// rather than a branch on cond.
func (v *Element) Swap(u *Element, cond int) {
	enable := mask64Bits(cond)

	// Each delta is v.li ^ u.li when swapping is enabled and 0 otherwise;
	// XORing it into both limbs swaps them or leaves them as they were.
	d0 := enable & (v.l0 ^ u.l0)
	d1 := enable & (v.l1 ^ u.l1)
	d2 := enable & (v.l2 ^ u.l2)
	d3 := enable & (v.l3 ^ u.l3)
	d4 := enable & (v.l4 ^ u.l4)

	v.l0 ^= d0
	u.l0 ^= d0
	v.l1 ^= d1
	u.l1 ^= d1
	v.l2 ^= d2
	u.l2 ^= d2
	v.l3 ^= d3
	u.l3 ^= d3
	v.l4 ^= d4
	u.l4 ^= d4
}
// IsNegative returns 1 if v is negative, and 0 otherwise. An element counts
// as negative when the lowest bit of its canonical encoding is set.
func (v *Element) IsNegative() int {
	encoded := v.Bytes()
	return int(encoded[0] & 1)
}
// Absolute sets v to |u|, and returns v: the negation of u is selected when u
// is negative, and u itself otherwise.
func (v *Element) Absolute(u *Element) *Element {
	negated := new(Element).Negate(u)
	return v.Select(negated, u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
//
// The product is computed by feMul, which is either the assembly routine or
// a wrapper around feMulGeneric depending on the build.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
//
// The square is computed by feSquare, which is either the assembly routine
// or a wrapper around feSquareGeneric depending on the build.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y for a 32-bit scalar y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
	// Split every limb product into its low 51 bits and the overflow above.
	lo0, hi0 := mul51(x.l0, y)
	lo1, hi1 := mul51(x.l1, y)
	lo2, hi2 := mul51(x.l2, y)
	lo3, hi3 := mul51(x.l3, y)
	lo4, hi4 := mul51(x.l4, y)

	// Each overflow moves up one limb; the overflow of the top limb wraps
	// around to the bottom limb multiplied by 19, per the reduction identity.
	//
	// The hi portions are only 32 bits, plus any previous excess, so the
	// carry propagation can be skipped.
	*v = Element{
		lo0 + 19*hi4,
		lo1 + hi0,
		lo2 + hi1,
		lo3 + hi2,
		lo4 + hi3,
	}
	return v
}
// mul51 splits the product a * b at bit 51: it returns lo and hi such that
// lo + hi * 2⁵¹ = a * b, with lo holding the low 51 bits of the product.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
	high, low := bits.Mul64(a, uint64(b))
	// Bits 51 and up of the 128-bit product: the top 13 bits of the low word
	// sit below the high word shifted up by 13.
	return low & maskLow51Bits, (high << 13) | (low >> 51)
}
// Pow22523 sets v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
//
// The exponentiation is a fixed sequence of Square and Multiply calls on the
// scratch elements t0, t1 and t2; every loop below has a constant trip count.
// The trailing comments give the power of x (or just its exponent) held by
// the destination after the statement, or after the loop for the comments on
// the for lines. Each loop starts at i = 1 because the first squaring of the
// run is done by the statement just before it.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
//
// The value is written as the five limbs l0 through l4 of an Element. It is
// used by SqrtRatio to turn a candidate root r into sqrtM1 * r.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
//
// a and b are scratch elements that are reused throughout. The named
// intermediates (v2, uv3, uv7, check, uNeg, rPrime) are pointers to a or b,
// not copies, so each one is only valid until its scratch element is next
// overwritten. The statement order below therefore must not be changed.
func (r *Element) SqrtRatio(u, v *Element) (rr *Element, wasSquare int) {
var a, b Element
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := a.Square(v)
uv3 := b.Multiply(u, b.Multiply(v2, v))
uv7 := a.Multiply(uv3, a.Square(v2))
r.Multiply(uv3, r.Pow22523(uv7))
check := a.Multiply(v, a.Square(r)) // check = v * r^2
uNeg := b.Negate(u)
// Compare check against u, -u and -u * sqrtM1. The third comparison
// overwrites uNeg (that is, b) with -u * sqrtM1.
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(uNeg.Multiply(uNeg, sqrtM1))
rPrime := b.Multiply(r, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
r.Select(rPrime, r, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(r) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_amd64.go
================================================
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//
// This is only the Go declaration; the body is the assembly routine ·feMul
// in fe_amd64.s.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//
// This is only the Go declaration; the body is the assembly routine
// ·feSquare in fe_amd64.s.
//
//go:noescape
func feSquare(out *Element, a *Element)
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_amd64.s
================================================
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
//
// Register use:
//   CX = a, BX = b; the five limbs of each sit at offsets 0, 8, 16, 24, 32.
//   Each MULQ leaves the 128-bit product in DX:AX, which is accumulated into
//   one register pair per output coefficient (high word : low word):
//     r0 = SI:DI, r1 = R8:R9, r2 = R10:R11, r3 = R12:R13, r4 = R14:R15
//   IMUL3Q $0x13 pre-multiplies a limb of a by 19 for the terms that wrap
//   around past the top limb (the reduction identity used by feMulGeneric).
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
// AX holds the 51-bit limb mask. The three-operand SHLQ is a double-width
// shift: the high word becomes (hi << 13) | (lo >> 51), i.e. the carry
// "coefficient >> 51", matching shiftRightBy51 in the generic code. Each low
// word is then masked to 51 bits and receives the carry of the coefficient
// below it; the carry out of r4 is multiplied by 19 and added to r0.
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
// Same scheme on the now single-word limbs: the carries are the bits above
// bit 50 (SHRQ by 51), and the high-word registers are reused to hold them.
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
//
// Register use:
//   CX = a; its five limbs l0..l4 sit at offsets 0, 8, 16, 24, 32.
//   Each MULQ leaves the 128-bit product in DX:AX, which is accumulated into
//   one register pair per output coefficient (high word : low word):
//     r0 = BX:SI, r1 = DI:R8, r2 = R9:R10, r3 = R11:R12, r4 = R13:R14
//   Symmetric cross terms are computed once with a doubled factor: SHLQ $1
//   or IMUL3Q $2 for 2×, IMUL3Q $0x26 for 2×19 = 38×, IMUL3Q $0x13 for 19×.
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
// AX holds the 51-bit limb mask. The three-operand SHLQ is a double-width
// shift: the high word becomes (hi << 13) | (lo >> 51), i.e. the carry
// "coefficient >> 51", matching shiftRightBy51 in the generic code. Each low
// word is then masked to 51 bits and receives the carry of the coefficient
// below it; the carry out of r4 is multiplied by 19 and added to r0.
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
// Same scheme on the now single-word limbs: the carries are the bits above
// bit 50 (SHRQ by 51), and the high-word registers are reused to hold them.
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_amd64_noasm.go
================================================
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
package field
// feMul sets v = x * y by delegating to the portable feMulGeneric. This file
// is built when the amd64 assembly is not used.
func feMul(v, x, y *Element) {
	feMulGeneric(v, x, y)
}
// feSquare sets v = x * x by delegating to the portable feSquareGeneric. This
// file is built when the amd64 assembly is not used.
func feSquare(v, x *Element) {
	feSquareGeneric(v, x)
}
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_arm64.go
================================================
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
package field
// carryPropagate performs the carry propagation on v in place. This is only
// the Go declaration; the body is the assembly routine ·carryPropagate in
// fe_arm64.s, which is documented there to work exactly like
// carryPropagateGeneric.
//
//go:noescape
func carryPropagate(v *Element)
// carryPropagate runs the assembly carry propagation on v in place and
// returns v, giving it the same method shape as carryPropagateGeneric.
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_arm64.s
================================================
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// Register use:
//   R20     = v
//   R0-R4   = the five limbs as loaded (offsets 0, 8, 16, 24, 32)
//   R10-R14 = the output limbs: the low 51 bits of each input limb plus the
//             carry (input >> 51) of the limb below it
//   R21     = carry out of the top limb; it is multiplied by 19 (R22) and
//             added to the lowest output limb
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
// Keep the low 51 bits of every limb.
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
// Add the carry of each limb to the limb above it.
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
// Write the five limbs back to v.
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_arm64_noasm.go
================================================
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
package field
// carryPropagate brings the limbs of v back into range using the portable
// carryPropagateGeneric and returns v. This file is built when the arm64
// assembly is not used.
func (v *Element) carryPropagate() *Element {
	v.carryPropagateGeneric()
	return v
}
================================================
FILE: gossh/crypto/curve25519/internal/field/fe_generic.go
================================================
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 is a 128-bit unsigned number stored as two 64-bit halves, shaped
// for use with the bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
	lo uint64 // low 64 bits
	hi uint64 // high 64 bits
}
// mul64 returns the full 128-bit product a * b.
func mul64(a, b uint64) uint128 {
	var product uint128
	product.hi, product.lo = bits.Mul64(a, b)
	return product
}
// addMul64 returns v + a * b as a 128-bit value. A carry out of the high
// word is discarded.
func addMul64(v uint128, a, b uint64) uint128 {
	prodHi, prodLo := bits.Mul64(a, b)
	sumLo, carry := bits.Add64(prodLo, v.lo, 0)
	sumHi, _ := bits.Add64(prodHi, v.hi, carry)
	return uint128{lo: sumLo, hi: sumHi}
}
// shiftRightBy51 returns a >> 51 as a 64-bit value. a is assumed to be at
// most 115 bits, so that the shifted result fits.
func shiftRightBy51(a uint128) uint64 {
	fromHi := a.hi << (64 - 51)
	fromLo := a.lo >> 51
	return fromHi | fromLo
}
// feMulGeneric sets v = a * b. The limbs of a and b are copied into locals
// before v is written, so v may alias either input.
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
//                          a4   a3   a2   a1   a0  x
//                          b4   b3   b2   b1   b0  =
//                         ------------------------
//                        a4b0 a3b0 a2b0 a1b0 a0b0  +
//                   a4b1 a3b1 a2b1 a1b1 a0b1       +
//              a4b2 a3b2 a2b2 a1b2 a0b2            +
//         a4b3 a3b3 a2b3 a1b3 a0b3                 +
//    a4b4 a3b4 a2b4 a1b4 a0b4                      =
//    ----------------------------------------------
//      r8   r7   r6   r5   r4   r3   r2   r1   r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
//            a4b0    a3b0    a2b0    a1b0    a0b0  +
//            a3b1    a2b1    a1b1    a0b1 19×a4b1  +
//            a2b2    a1b2    a0b2 19×a4b2 19×a3b2  +
//            a1b3    a0b3 19×a4b3 19×a3b3 19×a2b3  +
//            a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4  =
//           --------------------------------------
//              r4      r3      r2      r1      r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
//     r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
//     r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
//     r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
//     r0 < 2⁷ × 2⁵² × 2⁵²
//     r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
//     r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
//     r4 < 5 × 2⁵² × 2⁵²
//     r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as a Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// feSquareGeneric sets v = a * a. The limbs of a are copied into locals
// before v is written, so v may alias a.
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
//                          l4   l3   l2   l1   l0  x
//                          l4   l3   l2   l1   l0  =
//                         ------------------------
//                        l4l0 l3l0 l2l0 l1l0 l0l0  +
//                   l4l1 l3l1 l2l1 l1l1 l0l1       +
//              l4l2 l3l2 l2l2 l1l2 l0l2            +
//         l4l3 l3l3 l2l3 l1l3 l0l3                 +
//    l4l4 l3l4 l2l4 l1l4 l0l4                      =
//    ----------------------------------------------
//      r8   r7   r6   r5   r4   r3   r2   r1   r0
//
//            l4l0    l3l0    l2l0    l1l0    l0l0  +
//            l3l1    l2l1    l1l1    l0l1 19×l4l1  +
//            l2l2    l1l2    l0l2 19×l4l2 19×l3l2  +
//            l1l3    l0l3 19×l4l3 19×l3l3 19×l2l3  +
//            l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4  =
//           --------------------------------------
//              r4      r3      r2      r1      r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
// Reduce the wide coefficients exactly as feMulGeneric does: each carry
// (coefficient >> 51) moves to the limb above, and the top carry is
// multiplied by 19 and added to the lowest limb.
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// One last carry chain brings the limbs back within the Element bounds.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by moving the bits
// above bit 50 of each limb into the limb above it. The carry out of l4 wraps
// around into l0 multiplied by 19, per the reduction identity
// (a * 2²⁵⁵ + b = a * 19 + b). It updates v in place and returns v.
func (v *Element) carryPropagateGeneric() *Element {
	// Snapshot the limbs so that every carry is taken from the value the
	// limb had on entry, not from an already updated neighbour.
	in0, in1, in2, in3, in4 := v.l0, v.l1, v.l2, v.l3, v.l4

	v.l0 = in0&maskLow51Bits + (in4>>51)*19
	v.l1 = in1&maskLow51Bits + in0>>51
	v.l2 = in2&maskLow51Bits + in1>>51
	v.l3 = in3&maskLow51Bits + in2>>51
	v.l4 = in4&maskLow51Bits + in3>>51

	return v
}
================================================
FILE: gossh/crypto/curve25519/internal/field/sync.checkpoint
================================================
b0c49ae9f59d233526f8934262c5bbbe14d4358d
================================================
FILE: gossh/crypto/curve25519/internal/field/sync.sh
================================================
#! /bin/bash
# Bring the local copy of the field package up to date with the copy in the
# main Go repository.
#
# The commit of the last sync is read from sync.checkpoint. If the upstream
# directory changed since then, the checkpoint is overwritten with the newly
# fetched commit and the upstream diff is applied to the local directory with
# a three-way merge.
set -euo pipefail

# All paths below are relative to the root of the current git repository.
cd "$(git rev-parse --show-toplevel)"

STD_PATH=src/crypto/ed25519/internal/edwards25519/field
LOCAL_PATH=curve25519/internal/field
LAST_SYNC_REF=$(cat "$LOCAL_PATH/sync.checkpoint")

git fetch https://go.googlesource.com/go master

if git diff --quiet "$LAST_SYNC_REF:$STD_PATH" "FETCH_HEAD:$STD_PATH"; then
    echo "No changes."
else
    # tee records the new checkpoint while the same value is captured for the
    # message below.
    NEW_REF=$(git rev-parse FETCH_HEAD | tee "$LOCAL_PATH/sync.checkpoint")
    echo "Applying changes from $LAST_SYNC_REF to $NEW_REF..."
    git diff "$LAST_SYNC_REF:$STD_PATH" "FETCH_HEAD:$STD_PATH" | \
        git apply -3 --directory="$LOCAL_PATH"
fi
================================================
FILE: gossh/crypto/internal/alias/alias.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
// Package alias implements memory aliasing tests.
package alias
import "unsafe"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored, and an
// empty slice overlaps nothing.
func AnyOverlap(x, y []byte) bool {
	if len(x) == 0 || len(y) == 0 {
		return false
	}
	xFirst := uintptr(unsafe.Pointer(&x[0]))
	xLast := uintptr(unsafe.Pointer(&x[len(x)-1]))
	yFirst := uintptr(unsafe.Pointer(&y[0]))
	yLast := uintptr(unsafe.Pointer(&y[len(y)-1]))
	// Two address ranges intersect exactly when each one starts no later
	// than the other one ends.
	return xFirst <= yLast && yFirst <= xLast
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap: slices that
// start at the same address never do.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
	switch {
	case len(x) == 0, len(y) == 0:
		// An empty slice shares no memory with anything.
		return false
	case &x[0] == &y[0]:
		// Same start: every shared index corresponds exactly.
		return false
	}
	return AnyOverlap(x, y)
}
================================================
FILE: gossh/crypto/internal/alias/alias_purego.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build purego
// Package alias implements memory aliasing tests.
package alias
// This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed.
import "reflect"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored, and an
// empty slice overlaps nothing.
func AnyOverlap(x, y []byte) bool {
	if len(x) == 0 || len(y) == 0 {
		return false
	}
	// Element addresses are obtained through reflect because this variant
	// does not use the unsafe package.
	addr := func(b *byte) uintptr { return reflect.ValueOf(b).Pointer() }
	xFirst, xLast := addr(&x[0]), addr(&x[len(x)-1])
	yFirst, yLast := addr(&y[0]), addr(&y[len(y)-1])
	// Two address ranges intersect exactly when each one starts no later
	// than the other one ends.
	return xFirst <= yLast && yFirst <= xLast
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap: slices that
// start at the same address never do.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
	empty := len(x) == 0 || len(y) == 0
	if empty || &x[0] == &y[0] {
		// Either nothing is shared, or both slices begin at the same byte
		// and so every shared index corresponds exactly.
		return false
	}
	return AnyOverlap(x, y)
}
================================================
FILE: gossh/crypto/internal/poly1305/mac_noasm.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
package poly1305
// mac is the implementation type embedded in MAC. In builds selected by this
// file's build constraint it simply wraps macGeneric.
type mac struct{ macGeneric }
================================================
FILE: gossh/crypto/internal/poly1305/poly1305.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package poly1305 implements Poly1305 one-time message authentication code as
// specified in https://cr.yp.to/mac/poly1305-20050329.pdf.
//
// Poly1305 is a fast, one-time authentication function. It is infeasible for an
// attacker to generate an authenticator for a message without the key. However, a
// key must only be used for a single message. Authenticating two different
// messages with the same key allows an attacker to forge authenticators for other
// messages with the same key.
//
// Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was
// used with a fixed key in order to generate one-time keys from an nonce.
// However, in this package AES isn't used and the one-time key is specified
// directly.
package poly1305
import "crypto/subtle"
// TagSize is the size, in bytes, of a poly1305 authenticator. It is the
// number of bytes that MAC.Sum appends and that MAC.Size reports.
const TagSize = 16
// Sum computes the authenticator of m under the one-time key and stores the
// 16-byte result in out. Authenticating two different messages with the same
// key allows an attacker to forge messages at will.
func Sum(out *[16]byte, m []byte, key *[32]byte) {
	authenticator := New(key)
	authenticator.Write(m)
	// Appending to the zero-length prefix of out writes the tag directly
	// into out's backing array.
	authenticator.Sum(out[:0])
}
// Verify reports whether mac is the correct authenticator for m under the
// given key. The tags are compared in constant time.
func Verify(mac *[16]byte, m []byte, key *[32]byte) bool {
	var expected [16]byte
	Sum(&expected, m, key)
	match := subtle.ConstantTimeCompare(expected[:], mac[:])
	return match == 1
}
// New creates a MAC keyed with key. Data is then fed to it incrementally
// through Write, which is useful when the message is not available as a
// single slice; most callers should simply use the Sum function.
//
// A key must be used for one message only: authenticating two different
// messages with the same key allows an attacker to forge messages at will.
func New(key *[32]byte) *MAC {
	var m MAC
	initialize(key, &m.macState)
	return &m
}
// MAC is an io.Writer computing an authentication tag
// of the data written to it.
//
// MAC cannot be used like common hash.Hash implementations,
// because using a poly1305 key twice breaks its security.
// Therefore writing data to a running MAC after calling
// Sum or Verify causes it to panic.
type MAC struct {
mac // platform-dependent implementation
// finalized is set by Sum and Verify; once it is true, Write panics.
finalized bool
}
// Size reports the length, in bytes, of the tag that Sum appends.
func (h *MAC) Size() int {
	return TagSize
}
// Write feeds p into the running message authentication code and always
// reports len(p) bytes written with a nil error.
//
// Calling Write after the first call of Sum or Verify panics.
func (h *MAC) Write(p []byte) (n int, err error) {
	if !h.finalized {
		return h.mac.Write(p)
	}
	panic("poly1305: write to MAC after Sum or Verify")
}
// Sum appends the authenticator of everything written so far to b and
// returns the extended slice. It marks the MAC as finalized, so any later
// Write panics.
func (h *MAC) Sum(b []byte) []byte {
	var tag [TagSize]byte
	h.mac.Sum(&tag)
	h.finalized = true
	return append(b, tag[:]...)
}
// Verify reports whether expected equals the authenticator of everything
// written so far, comparing in constant time. It marks the MAC as finalized,
// so any later Write panics.
func (h *MAC) Verify(expected []byte) bool {
	var computed [TagSize]byte
	h.mac.Sum(&computed)
	h.finalized = true
	equal := subtle.ConstantTimeCompare(expected, computed[:])
	return equal == 1
}
================================================
FILE: gossh/crypto/internal/poly1305/sum_amd64.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package poly1305
// update absorbs msg into the accumulator held in state. It has no Go body;
// the implementation is the assembly routine ·update in sum_amd64.s.
//
//go:noescape
func update(state *macState, msg []byte)
// mac is a wrapper for macGeneric that redirects calls that would have gone to
// updateGeneric to update.
//
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
// using function pointers would carry a major performance cost.
//
// The embedded macGeneric supplies the macState, buffer and offset fields that
// those methods operate on.
type mac struct{ macGeneric }
// Write splits p into TagSize chunks and passes them to the assembly update
// routine, buffering any incomplete chunk. It always reports len(p) bytes
// written and a nil error.
func (h *mac) Write(p []byte) (int, error) {
	total := len(p)

	// Top up a partially filled buffer before touching whole chunks.
	if h.offset > 0 {
		filled := copy(h.buffer[h.offset:], p)
		h.offset += filled
		if h.offset < TagSize {
			return total, nil
		}
		update(&h.macState, h.buffer[:])
		h.offset = 0
		p = p[filled:]
	}

	// Process every complete chunk directly from p.
	if whole := len(p) / TagSize * TagSize; whole > 0 {
		update(&h.macState, p[:whole])
		p = p[whole:]
	}

	// Keep the incomplete tail for the next Write or for Sum.
	if len(p) > 0 {
		h.offset += copy(h.buffer[h.offset:], p)
	}
	return total, nil
}
// Sum writes the authenticator of all data written so far to out. It works
// on a copy of the state, so h itself is left unmodified.
func (h *mac) Sum(out *[16]byte) {
	snapshot := h.macState
	if n := h.offset; n > 0 {
		// Flush the buffered incomplete chunk into the copy only.
		update(&snapshot, h.buffer[:n])
	}
	finalize(out, &snapshot.h, &snapshot.s)
}
================================================
FILE: gossh/crypto/internal/poly1305/sum_amd64.s
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// POLY1305_ADD adds the 16 bytes at msg, together with the 2¹²⁸ padding bit
// (the $1 added with carry into h2), to the accumulator h2:h1:h0 and then
// advances msg by 16 bytes.
#define POLY1305_ADD(msg, h0, h1, h2) \
ADDQ 0(msg), h0; \
ADCQ 8(msg), h1; \
ADCQ $1, h2; \
LEAQ 16(msg), msg
// POLY1305_MUL multiplies the accumulator h2:h1:h0 by the key limbs r1:r0 and
// partially reduces the product modulo 2¹³⁰ - 5, leaving the result in
// h2:h1:h0.
//
// The first two groups of instructions build the 4-limb product t3:t2:t1:t0
// (h0 is reused as a temporary for a carry limb). The last group keeps the low
// 130 bits in h and adds the bits above them back in twice: once still shifted
// left by two (c * 4) and once shifted right by two (c), i.e. c * 5 in total.
//
// t0-t3 are scratch registers; AX and DX are clobbered as the implicit MULQ
// operands.
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \
MOVQ r0, AX; \
MULQ h0; \
MOVQ AX, t0; \
MOVQ DX, t1; \
MOVQ r0, AX; \
MULQ h1; \
ADDQ AX, t1; \
ADCQ $0, DX; \
MOVQ r0, t2; \
IMULQ h2, t2; \
ADDQ DX, t2; \
\
MOVQ r1, AX; \
MULQ h0; \
ADDQ AX, t1; \
ADCQ $0, DX; \
MOVQ DX, h0; \
MOVQ r1, t3; \
IMULQ h2, t3; \
MOVQ r1, AX; \
MULQ h1; \
ADDQ AX, t2; \
ADCQ DX, t3; \
ADDQ h0, t2; \
ADCQ $0, t3; \
\
MOVQ t0, h0; \
MOVQ t1, h1; \
MOVQ t2, h2; \
ANDQ $3, h2; \
MOVQ t2, t0; \
ANDQ $0xFFFFFFFFFFFFFFFC, t0; \
ADDQ t0, h0; \
ADCQ t3, h1; \
ADCQ $0, h2; \
SHRQ $2, t3, t2; \
SHRQ $2, t3; \
ADDQ t2, h0; \
ADCQ t3, h1; \
ADCQ $0, h2
// func update(state *[7]uint64, msg []byte)
//
// update absorbs msg into the accumulator h held in the first three words of
// state, using the key limbs r0 and r1 held in the next two words. Whole
// 16-byte blocks are handled by the main loop; a trailing partial block is
// padded with a single 1 byte and multiplied once.
TEXT ·update(SB), $0-32
MOVQ state+0(FP), DI
MOVQ msg_base+8(FP), SI
MOVQ msg_len+16(FP), R15
MOVQ 0(DI), R8 // h0
MOVQ 8(DI), R9 // h1
MOVQ 16(DI), R10 // h2
MOVQ 24(DI), R11 // r0
MOVQ 32(DI), R12 // r1
// Fewer than 16 bytes: skip straight to the partial-block path.
CMPQ R15, $16
JB bytes_between_0_and_15
loop:
POLY1305_ADD(SI, R8, R9, R10)
multiply:
POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14)
SUBQ $16, R15
CMPQ R15, $16
JAE loop
bytes_between_0_and_15:
TESTQ R15, R15
JZ done
// Assemble the final partial block in CX:BX. BX starts as the padding
// byte 1; each flush_buffer iteration shifts CX:BX left by one byte and
// inserts the next message byte, walking from the last byte of msg to the
// first.
MOVQ $1, BX
XORQ CX, CX
XORQ R13, R13
ADDQ R15, SI
flush_buffer:
SHLQ $8, BX, CX
SHLQ $8, BX
MOVB -1(SI), R13
XORQ R13, BX
DECQ SI
DECQ R15
JNZ flush_buffer
// Add the padded block to h. The 2¹²⁸ bit is not added here; the padding
// byte inserted above takes its place.
ADDQ BX, R8
ADCQ CX, R9
ADCQ $0, R10
// With R15 = 16 the SUBQ after multiply leaves zero, so control falls
// through bytes_between_0_and_15 to done.
MOVQ $16, R15
JMP multiply
done:
// Only h is written back; r is left as loaded.
MOVQ R8, 0(DI)
MOVQ R9, 8(DI)
MOVQ R10, 16(DI)
RET
================================================
FILE: gossh/crypto/internal/poly1305/sum_generic.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file provides the generic implementation of Sum and MAC. Other files
// might provide optimized assembly implementations of some of this code.
package poly1305
import (
"encoding/binary"
"math/bits"
)
// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
// for a 64 bytes message is approximately
//
// s + m[0:16] * r⁴ + m[16:32] * r³ + m[32:48] * r² + m[48:64] * r mod 2¹³⁰ - 5
//
// for some secret r and s. It can be computed sequentially like
//
// for len(msg) > 0:
// h += read(msg, 16)
// h *= r
// h %= 2¹³⁰ - 5
// return h + s
//
// All the complexity is about doing performant constant-time math on numbers
// larger than any available numeric type.
func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) {
	// One-shot use of the generic MAC: key it, absorb the whole message and
	// write the tag into out.
	var m macGeneric
	initialize(key, &m.macState)
	m.Write(msg)
	m.Sum(out)
}
// newMACGeneric returns a macGeneric whose r and s have been loaded from key
// and whose accumulator and buffer are zero.
func newMACGeneric(key *[32]byte) macGeneric {
	var m macGeneric
	initialize(key, &m.macState)
	return m
}
// macState holds numbers in saturated 64-bit little-endian limbs. That is,
// the value of [x0, x1, x2] is x[0] + x[1] * 2⁶⁴ + x[2] * 2¹²⁸.
//
// NOTE: the assembly update routines in this package read h and r at fixed
// byte offsets from the start of this struct (h at 0, r at 24), so the field
// order and types must not change.
type macState struct {
// h is the main accumulator. It is to be interpreted modulo 2¹³⁰ - 5, but
// can grow larger during and after rounds. It must, however, remain below
// 2 * (2¹³⁰ - 5).
h [3]uint64
// r and s are the private key components.
r [2]uint64
s [2]uint64
}
// macGeneric is the portable MAC state: the key and accumulator plus a
// one-chunk buffer for input that does not yet fill a whole TagSize chunk.
type macGeneric struct {
macState
// buffer holds the bytes of an incomplete chunk carried over between Writes.
buffer [TagSize]byte
// offset is the number of valid bytes at the start of buffer.
offset int
}
// Write cuts the incoming message into TagSize chunks and hands them to
// updateGeneric, holding back any incomplete chunk in the buffer. It always
// reports len(p) bytes written and a nil error.
func (h *macGeneric) Write(p []byte) (int, error) {
	total := len(p)

	// Complete a previously buffered chunk first.
	if h.offset > 0 {
		filled := copy(h.buffer[h.offset:], p)
		h.offset += filled
		if h.offset < TagSize {
			return total, nil
		}
		updateGeneric(&h.macState, h.buffer[:])
		h.offset = 0
		p = p[filled:]
	}

	// Absorb all whole chunks straight from p.
	if whole := len(p) / TagSize * TagSize; whole > 0 {
		updateGeneric(&h.macState, p[:whole])
		p = p[whole:]
	}

	// Save whatever is left over.
	if len(p) > 0 {
		h.offset += copy(h.buffer[h.offset:], p)
	}
	return total, nil
}
// Sum absorbs the buffered incomplete chunk, if there is one, and writes the
// resulting MAC to out. All work happens on a copy of the state, so Sum may
// be called repeatedly even though no Write is allowed after Sum.
func (h *macGeneric) Sum(out *[TagSize]byte) {
	snapshot := h.macState
	if n := h.offset; n > 0 {
		updateGeneric(&snapshot, h.buffer[:n])
	}
	finalize(out, &snapshot.h, &snapshot.s)
}
// [rMask0, rMask1] is the specified Poly1305 clamping mask in little-endian. It
// clears some bits of the secret coefficient to make it possible to implement
// multiplication more efficiently.
const (
rMask0 = 0x0FFFFFFC0FFFFFFF
rMask1 = 0x0FFFFFFC0FFFFFFC
)
// initialize splits the 256-bit key into the two 128-bit secrets: r, taken
// from the first 16 bytes and clamped with rMask0/rMask1, and s, taken from
// the last 16 bytes. Both are stored in m as little-endian 64-bit limbs.
func initialize(key *[32]byte, m *macState) {
	le := binary.LittleEndian
	m.r = [2]uint64{
		le.Uint64(key[0:8]) & rMask0,
		le.Uint64(key[8:16]) & rMask1,
	}
	m.s = [2]uint64{
		le.Uint64(key[16:24]),
		le.Uint64(key[24:32]),
	}
}
// uint128 is a 128-bit unsigned integer split into a low and a high 64-bit
// limb, the shape produced and consumed by the bits.Mul64 and bits.Add64
// intrinsics.
type uint128 struct {
	lo, hi uint64
}

// mul64 returns the full 128-bit product of a and b.
func mul64(a, b uint64) uint128 {
	var product uint128
	product.hi, product.lo = bits.Mul64(a, b)
	return product
}

// add128 returns a + b. It panics if the sum does not fit in 128 bits.
func add128(a, b uint128) uint128 {
	var (
		sum   uint128
		carry uint64
	)
	sum.lo, carry = bits.Add64(a.lo, b.lo, 0)
	sum.hi, carry = bits.Add64(a.hi, b.hi, carry)
	if carry != 0 {
		panic("poly1305: unexpected overflow")
	}
	return sum
}

// shiftRightBy2 returns a >> 2; the two low bits of the high limb move into
// the top of the low limb.
func shiftRightBy2(a uint128) uint128 {
	return uint128{
		lo: a.lo>>2 | (a.hi&3)<<62,
		hi: a.hi >> 2,
	}
}
// updateGeneric absorbs msg into the state.h accumulator. For each chunk m of
// 128 bits of message, it computes
//
//	h₊ = (h + m) * r  mod  2¹³⁰ - 5
//
// If the msg length is not a multiple of TagSize, it assumes the last
// incomplete chunk is the final one.
//
// state.r is only read; state.h is written back once, after all of msg has
// been processed. The function panics if one of its internal overflow
// invariants is violated.
func updateGeneric(state *macState, msg []byte) {
h0, h1, h2 := state.h[0], state.h[1], state.h[2]
r0, r1 := state.r[0], state.r[1]
for len(msg) > 0 {
var c uint64
// For the first step, h + m, we use a chain of bits.Add64 intrinsics.
// The resulting value of h might exceed 2¹³⁰ - 5, but will be partially
// reduced at the end of the multiplication below.
//
// The spec requires us to set a bit just above the message size, not to
// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
// add 1 to the most significant (2¹²⁸) limb, h2.
if len(msg) >= TagSize {
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
h2 += c + 1
msg = msg[TagSize:]
} else {
// Partial final chunk: zero-pad it to TagSize bytes and place the
// marker byte directly after the last message byte. No 2¹²⁸ bit is
// added to h2 in this case.
var buf [TagSize]byte
copy(buf[:], msg)
buf[len(msg)] = 1
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
h2 += c
msg = nil
}
// Multiplication of big number limbs is similar to elementary school
// columnar multiplication. Instead of digits, there are 64-bit limbs.
//
// We are multiplying a 3 limbs number, h, by a 2 limbs number, r.
//
//                        h2   h1   h0  x
//                             r1   r0  =
//                       ----------------
//                      h2r0 h1r0 h0r0     <-- individual 128-bit products
//            +    h2r1 h1r1 h0r1
//               ------------------------
//                 m3   m2   m1   m0       <-- result in 128-bit overlapping limbs
//               ------------------------
//         m3.hi m2.hi m1.hi m0.hi         <-- carry propagation
//     +         m3.lo m2.lo m1.lo m0.lo
//       -------------------------------
//           t4    t3    t2    t1    t0    <-- final result in 64-bit limbs
//
// The main difference from pen-and-paper multiplication is that we do
// carry propagation in a separate step, as if we wrote two digit sums
// at first (the 128-bit limbs), and then carried the tens all at once.
h0r0 := mul64(h0, r0)
h1r0 := mul64(h1, r0)
h2r0 := mul64(h2, r0)
h0r1 := mul64(h0, r1)
h1r1 := mul64(h1, r1)
h2r1 := mul64(h2, r1)
// Since h2 is known to be at most 7 (5 + 1 + 1), and r0 and r1 have their
// top 4 bits cleared by rMask{0,1}, we know that their product is not going
// to overflow 64 bits, so we can ignore the high part of the products.
//
// This also means that the product doesn't have a fifth limb (t4).
if h2r0.hi != 0 {
panic("poly1305: unexpected overflow")
}
if h2r1.hi != 0 {
panic("poly1305: unexpected overflow")
}
m0 := h0r0
m1 := add128(h1r0, h0r1) // These two additions don't overflow thanks again
m2 := add128(h2r0, h1r1) // to the 4 masked bits at the top of r0 and r1.
m3 := h2r1
// Carry propagation: the high half of each 128-bit limb is added to the
// low half of the next one. The carry out of the last addition is
// discarded.
t0 := m0.lo
t1, c := bits.Add64(m1.lo, m0.hi, 0)
t2, c := bits.Add64(m2.lo, m1.hi, c)
t3, _ := bits.Add64(m3.lo, m2.hi, c)
// Now we have the result as 4 64-bit limbs, and we need to reduce it
// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
// a cheap partial reduction according to the reduction identity
//
//	c * 2¹³⁰ + n  =  c * 5 + n  mod  2¹³⁰ - 5
//
// because 2¹³⁰ = 5 mod 2¹³⁰ - 5. Partial reduction since the result is
// likely to be larger than 2¹³⁰ - 5, but still small enough to fit the
// assumptions we make about h in the rest of the code.
//
// See also https://speakerdeck.com/gtank/engineering-prime-numbers?slide=23
// We split the final result at the 2¹³⁰ mark into h and cc, the carry.
// Note that the carry bits are effectively shifted left by 2, in other
// words, cc = c * 4 for the c in the reduction identity.
h0, h1, h2 = t0, t1, t2&maskLow2Bits
cc := uint128{t2 & maskNotLow2Bits, t3}
// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bits.Add64(h1, cc.hi, c)
h2 += c
cc = shiftRightBy2(cc)
h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bits.Add64(h1, cc.hi, c)
h2 += c
// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
//
//	5 * 2¹²⁸ + (2¹²⁸ - 1) = 6 * 2¹²⁸ - 1
}
state.h[0], state.h[1], state.h[2] = h0, h1, h2
}
// maskLow2Bits selects the two least significant bits of a limb and
// maskNotLow2Bits selects all the others. updateGeneric applies them to the
// third product limb to split the product at the 2¹³⁰ mark.
const (
maskLow2Bits uint64 = 0x0000000000000003
maskNotLow2Bits uint64 = ^maskLow2Bits
)
// select64 returns x if v == 1 and y if v == 0, in constant time.
func select64(v, x, y uint64) uint64 {
	// mask is all zeros when v == 1 and all ones when v == 0.
	mask := v - 1
	return x&^mask | y&mask
}
// [p0, p1, p2] is 2¹³⁰ - 5 in little endian order.
const (
p0 = 0xFFFFFFFFFFFFFFFB
p1 = 0xFFFFFFFFFFFFFFFF
p2 = 0x0000000000000003
)
// finalize completes the modular reduction of h and computes
//
//	out = h + s  mod  2¹²⁸
//
// storing the result in out as 16 little-endian bytes. Neither h nor s is
// modified.
func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
	lo, mid, top := h[0], h[1], h[2]

	// After the partial reduction in updateGeneric, h might be more than
	// 2¹³⁰ - 5, but will be less than 2 * (2¹³⁰ - 5). Do a trial subtraction
	// t = h - (2¹³⁰ - 5); the final borrow is set when it underflows.
	subLo, borrow := bits.Sub64(lo, p0, 0)
	subMid, borrow := bits.Sub64(mid, p1, borrow)
	_, borrow = bits.Sub64(top, p2, borrow)

	// Constant-time choice: keep h if the subtraction underflowed, otherwise
	// take t. Only the low 128 bits are needed from here on.
	lo = select64(borrow, lo, subLo)
	mid = select64(borrow, mid, subMid)

	// Last Poly1305 step, tag = h + s mod 2¹²⁸: a wide addition on the low
	// 128 bits whose final carry is discarded.
	var carry uint64
	lo, carry = bits.Add64(lo, s[0], 0)
	mid, _ = bits.Add64(mid, s[1], carry)

	binary.LittleEndian.PutUint64(out[0:8], lo)
	binary.LittleEndian.PutUint64(out[8:16], mid)
}
================================================
FILE: gossh/crypto/internal/poly1305/sum_ppc64le.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package poly1305
// update absorbs msg into the accumulator held in state. It has no Go body;
// the implementation is the assembly routine ·update in sum_ppc64le.s.
//
//go:noescape
func update(state *macState, msg []byte)
// mac is a wrapper for macGeneric that redirects calls that would have gone to
// updateGeneric to update.
//
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
// using function pointers would carry a major performance cost.
//
// The embedded macGeneric supplies the macState, buffer and offset fields that
// those methods operate on.
type mac struct{ macGeneric }
// Write splits p into TagSize chunks and passes them to the assembly update
// routine, buffering any incomplete chunk. It always reports len(p) bytes
// written and a nil error.
func (h *mac) Write(p []byte) (int, error) {
	total := len(p)

	// Top up a partially filled buffer before touching whole chunks.
	if h.offset > 0 {
		filled := copy(h.buffer[h.offset:], p)
		h.offset += filled
		if h.offset < TagSize {
			return total, nil
		}
		update(&h.macState, h.buffer[:])
		h.offset = 0
		p = p[filled:]
	}

	// Process every complete chunk directly from p.
	if whole := len(p) / TagSize * TagSize; whole > 0 {
		update(&h.macState, p[:whole])
		p = p[whole:]
	}

	// Keep the incomplete tail for the next Write or for Sum.
	if len(p) > 0 {
		h.offset += copy(h.buffer[h.offset:], p)
	}
	return total, nil
}
// Sum writes the authenticator of all data written so far to out. It works
// on a copy of the state, so h itself is left unmodified.
func (h *mac) Sum(out *[16]byte) {
	snapshot := h.macState
	if n := h.offset; n > 0 {
		// Flush the buffered incomplete chunk into the copy only.
		update(&snapshot, h.buffer[:n])
	}
	finalize(out, &snapshot.h, &snapshot.s)
}
================================================
FILE: gossh/crypto/internal/poly1305/sum_ppc64le.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// This was ported from the amd64 implementation.
// POLY1305_ADD adds the 16 bytes at msg, plus the 2¹²⁸ padding bit (t2 = 1,
// added with carry into h2), to the accumulator h2:h1:h0 and then advances
// msg by 16 bytes. t0-t2 are scratch registers.
#define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
MOVD (msg), t0; \
MOVD 8(msg), t1; \
MOVD $1, t2; \
ADDC t0, h0, h0; \
ADDE t1, h1, h1; \
ADDE t2, h2; \
ADD $16, msg
// POLY1305_MUL multiplies the accumulator h2:h1:h0 by the key limbs r1:r0,
// building the product in t3:t2:t1:t0, and folds the bits at and above 2¹³⁰
// back into the low limbs (added once still shifted left by two and once
// shifted right by two), leaving the partially reduced result in h2:h1:h0.
// t0-t5 are scratch registers.
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
MULLD r0, h0, t0; \
MULHDU r0, h0, t1; \
MULLD r0, h1, t4; \
MULHDU r0, h1, t5; \
ADDC t4, t1, t1; \
MULLD r0, h2, t2; \
MULHDU r1, h0, t4; \
MULLD r1, h0, h0; \
ADDE t5, t2, t2; \
ADDC h0, t1, t1; \
MULLD h2, r1, t3; \
ADDZE t4, h0; \
MULHDU r1, h1, t5; \
MULLD r1, h1, t4; \
ADDC t4, t2, t2; \
ADDE t5, t3, t3; \
ADDC h0, t2, t2; \
MOVD $-4, t4; \
ADDZE t3; \
RLDICL $0, t2, $62, h2; \
AND t2, t4, h0; \
ADDC t0, h0, h0; \
ADDE t3, t1, h1; \
SLD $62, t3, t4; \
SRD $2, t2; \
ADDZE h2; \
OR t4, t2, t2; \
SRD $2, t3; \
ADDC t2, h0, h0; \
ADDE t3, h1, h1; \
ADDZE h2
// poly1305Mask holds the two 64-bit words of the Poly1305 clamping mask for r
// (the same values as rMask0 and rMask1 in sum_generic.go).
// NOTE(review): nothing in this file references this symbol - confirm whether
// it is still needed.
DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
GLOBL ·poly1305Mask<>(SB), RODATA, $16
// func update(state *[7]uint64, msg []byte)
//
// update absorbs msg into the accumulator h (state words 0-2) using the key
// limbs r0 and r1 (state words 3-4). Whole 16-byte blocks go through the main
// loop. A trailing partial block is assembled in R17:R16 with a 1 bit placed
// directly above its last byte, added to h, and multiplied once.
TEXT ·update(SB), $0-32
MOVD state+0(FP), R3
MOVD msg_base+8(FP), R4
MOVD msg_len+16(FP), R5
MOVD 0(R3), R8 // h0
MOVD 8(R3), R9 // h1
MOVD 16(R3), R10 // h2
MOVD 24(R3), R11 // r0
MOVD 32(R3), R12 // r1
// Fewer than 16 bytes: skip straight to the partial-block path.
CMP R5, $16
BLT bytes_between_0_and_15
loop:
POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
PCALIGN $16
multiply:
POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
ADD $-16, R5
CMP R5, $16
BGE loop
bytes_between_0_and_15:
CMP R5, $0
BEQ done
// R16 and R17 collect the low and high 64 bits of the padded final block;
// they are added to h0 and h1 at the carry label.
MOVD $0, R16 // h0
MOVD $0, R17 // h1
flush_buffer:
CMP R5, $8
BLE just1
MOVD $8, R21
SUB R21, R5, R21
// Greater than 8 -- load the rightmost remaining bytes in msg
// and put into R17 (h1)
MOVD (R4)(R21), R17
MOVD $16, R22
// Find the offset to those bytes
SUB R5, R22, R22
SLD $3, R22
// Shift to get only the bytes in msg
SRD R22, R17, R17
// Put 1 at high end
MOVD $1, R23
SLD $3, R21
SLD R21, R23, R23
OR R23, R17, R17
// Remainder is 8
MOVD $8, R5
just1:
CMP R5, $8
BLT less8
// Exactly 8
MOVD (R4), R16
CMP R17, $0
// Check if we've already set R17; if not
// set 1 to indicate end of msg.
BNE carry
MOVD $1, R17
BR carry
less8:
// Fewer than 8 bytes remain: gather them into R16 with 4-, 2- and 1-byte
// loads, tracking in R22 the bit position at which the next piece goes.
MOVD $0, R16 // h0
MOVD $0, R22 // shift count
CMP R5, $4
BLT less4
MOVWZ (R4), R16
ADD $4, R4
ADD $-4, R5
MOVD $32, R22
less4:
CMP R5, $2
BLT less2
MOVHZ (R4), R21
SLD R22, R21, R21
OR R16, R21, R16
ADD $16, R22
ADD $-2, R5
ADD $2, R4
less2:
CMP R5, $0
BEQ insert1
MOVBZ (R4), R21
SLD R22, R21, R21
OR R16, R21, R16
ADD $8, R22
insert1:
// Insert 1 at end of msg
MOVD $1, R21
SLD R22, R21, R21
OR R16, R21, R16
carry:
// Add new values to h0, h1, h2
ADDC R16, R8
ADDE R17, R9
ADDZE R10, R10
// With R5 = 16 the ADD $-16 after multiply leaves zero, so control falls
// through bytes_between_0_and_15 to done.
MOVD $16, R5
ADD R5, R4
BR multiply
done:
// Save h0, h1, h2 in state
MOVD R8, 0(R3)
MOVD R9, 8(R3)
MOVD R10, 16(R3)
RET
================================================
FILE: gossh/crypto/internal/poly1305/sum_s390x.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
package poly1305
import (
"gossh/sys/cpu"
)
// updateVX is an assembly implementation of Poly1305 that uses vector
// instructions. It must only be called if the vector facility (vx) is
// available.
//
// It has no Go body; the implementation is ·updateVX in sum_s390x.s.
//
//go:noescape
func updateVX(state *macState, msg []byte)
// mac is a replacement for macGeneric that uses a larger buffer and redirects
// calls that would have gone to updateGeneric to updateVX if the vector
// facility is installed.
//
// A larger buffer is required for good performance because the vector
// implementation has a higher fixed cost per call than the generic
// implementation.
type mac struct {
macState
// buffer holds input carried over between Writes until it is full.
buffer [16 * TagSize]byte // size must be a multiple of block size (16)
// offset is the number of valid bytes at the start of buffer.
offset int
}
// Write absorbs p into the MAC. Input is processed in multiples of the
// buffer size, by updateVX when the vector facility is available and by
// updateGeneric otherwise; any remainder is kept in the buffer. It always
// reports len(p) bytes written and a nil error.
func (h *mac) Write(p []byte) (int, error) {
	written := len(p)
	useVX := cpu.S390X.HasVX

	// Fill up a partially used buffer first and flush it once it is full.
	if h.offset > 0 {
		copied := copy(h.buffer[h.offset:], p)
		h.offset += copied
		if h.offset < len(h.buffer) {
			return written, nil
		}
		if useVX {
			updateVX(&h.macState, h.buffer[:])
		} else {
			updateGeneric(&h.macState, h.buffer[:])
		}
		h.offset = 0
		p = p[copied:]
	}

	// Process the part of p that is a whole multiple of the buffer size.
	rest := len(p) % len(h.buffer)
	bulk := p[:len(p)-rest]
	if len(bulk) > 0 {
		if useVX {
			updateVX(&h.macState, bulk)
		} else {
			updateGeneric(&h.macState, bulk)
		}
	}

	// Buffer the remaining bytes; there may be none.
	h.offset = copy(h.buffer[:], p[len(bulk):])
	return written, nil
}
// Sum writes the authenticator of all data written so far to out. It works
// on a copy of the state, so h itself is left unmodified.
func (h *mac) Sum(out *[TagSize]byte) {
	snapshot := h.macState
	pending := h.buffer[:h.offset]

	switch {
	case len(pending) == 0:
		// Nothing is buffered.
	case cpu.S390X.HasVX && len(pending) > 2*TagSize:
		// The vector implementation has a higher startup time, so it is only
		// used when more than 2 blocks are left to sum.
		updateVX(&snapshot, pending)
	default:
		updateGeneric(&snapshot, pending)
	}

	finalize(out, &snapshot.h, &snapshot.s)
}
================================================
FILE: gossh/crypto/internal/poly1305/sum_s390x.s
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc && !purego
#include "textflag.h"
// This implementation of Poly1305 uses the vector facility (vx)
// to process up to 2 blocks (32 bytes) per iteration using an
// algorithm based on the one described in:
//
// NEON crypto, Daniel J. Bernstein & Peter Schwabe
// https://cryptojedi.org/papers/neoncrypto-20120320.pdf
//
// This algorithm uses 5 26-bit limbs to represent a 130-bit
// value. These limbs are, for the most part, zero extended and
// placed into 64-bit vector register elements. Each vector
// register is 128-bits wide and so holds 2 of these elements.
// Using 26-bit limbs allows us plenty of headroom to accommodate
// accumulations before and after multiplication without
// overflowing either 32-bits (before multiplication) or 64-bits
// (after multiplication).
//
// In order to parallelise the operations required to calculate
// the sum we use two separate accumulators and then sum those
// in an extra final step. For compatibility with the generic
// implementation we perform this summation at the end of every
// updateVX call.
//
// To use two accumulators we must multiply the message blocks
// by r² rather than r. Only the final message block should be
// multiplied by r.
//
// Example:
//
// We want to calculate the sum (h) for a 64 byte message (m):
//
// h = m[0:16]r⁴ + m[16:32]r³ + m[32:48]r² + m[48:64]r
//
// To do this we split the calculation into the even indices
// and odd indices of the message. These form our SIMD 'lanes':
//
// h = m[ 0:16]r⁴ + m[32:48]r² + <- lane 0
// m[16:32]r³ + m[48:64]r <- lane 1
//
// To calculate this iteratively we refactor so that both lanes
// are written in terms of r² and r:
//
// h = (m[ 0:16]r² + m[32:48])r² + <- lane 0
// (m[16:32]r² + m[48:64])r <- lane 1
// ^ ^
// | coefficients for second iteration
// coefficients for first iteration
//
// So in this case we would have two iterations. In the first
// both lanes are multiplied by r². In the second only the
// first lane is multiplied by r² and the second lane is
// instead multiplied by r. This gives us the odd and even
// powers of r that we need from the original equation.
//
// Notation:
//
// h - accumulator
// r - key
// m - message
//
// [a, b] - SIMD register holding two 64-bit values
// [a, b, c, d] - SIMD register holding four 32-bit values
// xᵢ[n] - limb n of variable x with bit width i
//
// Limbs are expressed in little endian order, so for 26-bit
// limbs x₂₆[4] will be the most significant limb and x₂₆[0]
// will be the least significant limb.
// masking constants
#define MOD24 V0 // [0x0000000000ffffff, 0x0000000000ffffff] - mask low 24-bits
#define MOD26 V1 // [0x0000000003ffffff, 0x0000000003ffffff] - mask low 26-bits
// expansion constants (see EXPAND macro)
#define EX0 V2
#define EX1 V3
#define EX2 V4
// key (r², r or 1 depending on context)
#define R_0 V5
#define R_1 V6
#define R_2 V7
#define R_3 V8
#define R_4 V9
// precalculated coefficients (5r², 5r or 0 depending on context)
#define R5_1 V10
#define R5_2 V11
#define R5_3 V12
#define R5_4 V13
// message block (m)
#define M_0 V14
#define M_1 V15
#define M_2 V16
#define M_3 V17
#define M_4 V18
// accumulator (h)
#define H_0 V19
#define H_1 V20
#define H_2 V21
#define H_3 V22
#define H_4 V23
// temporary registers (for short-lived values)
#define T_0 V24
#define T_1 V25
#define T_2 V26
#define T_3 V27
#define T_4 V28
GLOBL ·constants<>(SB), RODATA, $0x30
// EX0
DATA ·constants<>+0x00(SB)/8, $0x0006050403020100
DATA ·constants<>+0x08(SB)/8, $0x1016151413121110
// EX1
DATA ·constants<>+0x10(SB)/8, $0x060c0b0a09080706
DATA ·constants<>+0x18(SB)/8, $0x161c1b1a19181716
// EX2
DATA ·constants<>+0x20(SB)/8, $0x0d0d0d0d0d0f0e0d
DATA ·constants<>+0x28(SB)/8, $0x1d1d1d1d1d1f1e1d
// MULTIPLY multiplies each lane of f and g, partially reduced
// modulo 2¹³⁰ - 5. The result, h, consists of partial products
// in each lane that need to be reduced further to produce the
// final result.
//
// h₁₃₀ = (f₁₃₀g₁₃₀) % 2¹³⁰ + (5f₁₃₀g₁₃₀) / 2¹³⁰
//
// Note that the multiplication by 5 of the high bits is
// achieved by precalculating the multiplication of four of the
// g coefficients by 5. These are g51-g54.
//
// The products involving f0, f2 and f4 are accumulated directly
// in h0-h4 and those involving f1 and f3 in T_0-T_4; the two sets
// are added together by the final VAG instructions. T_0-T_4 are
// therefore clobbered.
#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \
VMLOF f0, g0, h0 \
VMLOF f0, g3, h3 \
VMLOF f0, g1, h1 \
VMLOF f0, g4, h4 \
VMLOF f0, g2, h2 \
VMLOF f1, g54, T_0 \
VMLOF f1, g2, T_3 \
VMLOF f1, g0, T_1 \
VMLOF f1, g3, T_4 \
VMLOF f1, g1, T_2 \
VMALOF f2, g53, h0, h0 \
VMALOF f2, g1, h3, h3 \
VMALOF f2, g54, h1, h1 \
VMALOF f2, g2, h4, h4 \
VMALOF f2, g0, h2, h2 \
VMALOF f3, g52, T_0, T_0 \
VMALOF f3, g0, T_3, T_3 \
VMALOF f3, g53, T_1, T_1 \
VMALOF f3, g1, T_4, T_4 \
VMALOF f3, g54, T_2, T_2 \
VMALOF f4, g51, h0, h0 \
VMALOF f4, g54, h3, h3 \
VMALOF f4, g52, h1, h1 \
VMALOF f4, g0, h4, h4 \
VMALOF f4, g53, h2, h2 \
VAG T_0, h0, h0 \
VAG T_3, h3, h3 \
VAG T_1, h1, h1 \
VAG T_4, h4, h4 \
VAG T_2, h2, h2
// REDUCE performs the following carry operations in four
// stages, as specified in Bernstein & Schwabe:
//
// 1: h₂₆[0]->h₂₆[1] h₂₆[3]->h₂₆[4]
// 2: h₂₆[1]->h₂₆[2] h₂₆[4]->h₂₆[0]
// 3: h₂₆[0]->h₂₆[1] h₂₆[2]->h₂₆[3]
// 4: h₂₆[3]->h₂₆[4]
//
// The result is that all of the limbs are limited to 26-bits
// except for h₂₆[1] and h₂₆[4] which are limited to 27-bits.
//
// Note that although each limb is aligned at 26-bit intervals
// they may contain values that exceed 2²⁶ - 1, hence the need
// to carry the excess bits in each limb.
//
// The carry out of h₂₆[4] is multiplied by 5 (computed as
// 4 * carry + carry) before it is added to h₂₆[0]. The macro
// relies on MOD26 holding the 26-bit mask and clobbers T_0-T_4.
#define REDUCE(h0, h1, h2, h3, h4) \
VESRLG $26, h0, T_0 \
VESRLG $26, h3, T_1 \
VN MOD26, h0, h0 \
VN MOD26, h3, h3 \
VAG T_0, h1, h1 \
VAG T_1, h4, h4 \
VESRLG $26, h1, T_2 \
VESRLG $26, h4, T_3 \
VN MOD26, h1, h1 \
VN MOD26, h4, h4 \
VESLG $2, T_3, T_4 \
VAG T_3, T_4, T_4 \
VAG T_2, h2, h2 \
VAG T_4, h0, h0 \
VESRLG $26, h2, T_0 \
VESRLG $26, h0, T_1 \
VN MOD26, h2, h2 \
VN MOD26, h0, h0 \
VAG T_0, h3, h3 \
VAG T_1, h1, h1 \
VESRLG $26, h3, T_2 \
VN MOD26, h3, h3 \
VAG T_2, h4, h4
// EXPAND splits the 128-bit little-endian values in0 and in1
// into 26-bit big-endian limbs and places the results into
// the first and second lane of d₂₆[0:4] respectively.
//
// The EX0, EX1 and EX2 constants are arrays of byte indices
// for permutation. The permutation both reverses the bytes
// in the input and ensures the bytes are copied into the
// destination limb ready to be shifted into their final
// position.
//
// EX0-EX2, MOD24 and MOD26 must already hold their constants
// when the macro is used. d1 and d3 are derived from d0 and d2
// by shifting, before all five limbs are masked.
#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \
VPERM in0, in1, EX0, d0 \
VPERM in0, in1, EX1, d2 \
VPERM in0, in1, EX2, d4 \
VESRLG $26, d0, d1 \
VESRLG $30, d2, d3 \
VESRLG $4, d2, d2 \
VN MOD26, d0, d0 \ // [in0₂₆[0], in1₂₆[0]]
VN MOD26, d3, d3 \ // [in0₂₆[3], in1₂₆[3]]
VN MOD26, d1, d1 \ // [in0₂₆[1], in1₂₆[1]]
VN MOD24, d4, d4 \ // [in0₂₆[4], in1₂₆[4]]
VN MOD26, d2, d2 // [in0₂₆[2], in1₂₆[2]]
// func updateVX(state *macState, msg []byte)
TEXT ·updateVX(SB), NOSPLIT, $0
MOVD state+0(FP), R1
LMG msg+8(FP), R2, R3 // R2=msg_base, R3=msg_len
// load EX0, EX1 and EX2
MOVD $·constants<>(SB), R5
VLM (R5), EX0, EX2
// generate masks
VGMG $(64-24), $63, MOD24 // [0x00ffffff, 0x00ffffff]
VGMG $(64-26), $63, MOD26 // [0x03ffffff, 0x03ffffff]
// load h (accumulator) and r (key) from state
VZERO T_1 // [0, 0]
VL 0(R1), T_0 // [h₆₄[0], h₆₄[1]]
VLEG $0, 16(R1), T_1 // [h₆₄[2], 0]
VL 24(R1), T_2 // [r₆₄[0], r₆₄[1]]
VPDI $0, T_0, T_2, T_3 // [h₆₄[0], r₆₄[0]]
VPDI $5, T_0, T_2, T_4 // [h₆₄[1], r₆₄[1]]
// unpack h and r into 26-bit limbs
// note: h₆₄[2] may have the low 3 bits set, so h₂₆[4] is a 27-bit value
VN MOD26, T_3, H_0 // [h₂₆[0], r₂₆[0]]
VZERO H_1 // [0, 0]
VZERO H_3 // [0, 0]
VGMG $(64-12-14), $(63-12), T_0 // [0x03fff000, 0x03fff000] - 26-bit mask with low 12 bits masked out
VESLG $24, T_1, T_1 // [h₆₄[2]<<24, 0]
VERIMG $-26&63, T_3, MOD26, H_1 // [h₂₆[1], r₂₆[1]]
VESRLG $+52&63, T_3, H_2 // [h₂₆[2], r₂₆[2]] - low 12 bits only
VERIMG $-14&63, T_4, MOD26, H_3 // [h₂₆[1], r₂₆[1]]
VESRLG $40, T_4, H_4 // [h₂₆[4], r₂₆[4]] - low 24 bits only
VERIMG $+12&63, T_4, T_0, H_2 // [h₂₆[2], r₂₆[2]] - complete
VO T_1, H_4, H_4 // [h₂₆[4], r₂₆[4]] - complete
// replicate r across all 4 vector elements
VREPF $3, H_0, R_0 // [r₂₆[0], r₂₆[0], r₂₆[0], r₂₆[0]]
VREPF $3, H_1, R_1 // [r₂₆[1], r₂₆[1], r₂₆[1], r₂₆[1]]
VREPF $3, H_2, R_2 // [r₂₆[2], r₂₆[2], r₂₆[2], r₂₆[2]]
VREPF $3, H_3, R_3 // [r₂₆[3], r₂₆[3], r₂₆[3], r₂₆[3]]
VREPF $3, H_4, R_4 // [r₂₆[4], r₂₆[4], r₂₆[4], r₂₆[4]]
// zero out lane 1 of h
VLEIG $1, $0, H_0 // [h₂₆[0], 0]
VLEIG $1, $0, H_1 // [h₂₆[1], 0]
VLEIG $1, $0, H_2 // [h₂₆[2], 0]
VLEIG $1, $0, H_3 // [h₂₆[3], 0]
VLEIG $1, $0, H_4 // [h₂₆[4], 0]
// calculate 5r (ignore least significant limb)
VREPIF $5, T_0
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r₂₆[1], 5r₂₆[1], 5r₂₆[1]]
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r₂₆[2], 5r₂₆[2], 5r₂₆[2]]
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r₂₆[3], 5r₂₆[3], 5r₂₆[3]]
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r₂₆[4], 5r₂₆[4], 5r₂₆[4]]
// skip r² calculation if we are only calculating one block
CMPBLE R3, $16, skip
// calculate r²
MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, M_0, M_1, M_2, M_3, M_4)
REDUCE(M_0, M_1, M_2, M_3, M_4)
VGBM $0x0f0f, T_0
VERIMG $0, M_0, T_0, R_0 // [r₂₆[0], r²₂₆[0], r₂₆[0], r²₂₆[0]]
VERIMG $0, M_1, T_0, R_1 // [r₂₆[1], r²₂₆[1], r₂₆[1], r²₂₆[1]]
VERIMG $0, M_2, T_0, R_2 // [r₂₆[2], r²₂₆[2], r₂₆[2], r²₂₆[2]]
VERIMG $0, M_3, T_0, R_3 // [r₂₆[3], r²₂₆[3], r₂₆[3], r²₂₆[3]]
VERIMG $0, M_4, T_0, R_4 // [r₂₆[4], r²₂₆[4], r₂₆[4], r²₂₆[4]]
// calculate 5r² (ignore least significant limb)
VREPIF $5, T_0
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r²₂₆[1], 5r₂₆[1], 5r²₂₆[1]]
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r²₂₆[2], 5r₂₆[2], 5r²₂₆[2]]
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r²₂₆[3], 5r₂₆[3], 5r²₂₆[3]]
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r²₂₆[4], 5r₂₆[4], 5r²₂₆[4]]
loop:
CMPBLE R3, $32, b2 // 2 or fewer blocks remaining, need to change key coefficients
// load next 2 blocks from message
VLM (R2), T_0, T_1
// update message slice
SUB $32, R3
MOVD $32(R2), R2
// unpack message blocks into 26-bit big-endian limbs
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
// add 2¹²⁸ to each message block value
VLEIB $4, $1, M_4
VLEIB $12, $1, M_4
multiply:
// accumulate the incoming message
VAG H_0, M_0, M_0
VAG H_3, M_3, M_3
VAG H_1, M_1, M_1
VAG H_4, M_4, M_4
VAG H_2, M_2, M_2
// multiply the accumulator by the key coefficient
MULTIPLY(M_0, M_1, M_2, M_3, M_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4)
// carry and partially reduce the partial products
REDUCE(H_0, H_1, H_2, H_3, H_4)
CMPBNE R3, $0, loop
finish:
// sum lane 0 and lane 1 and put the result in lane 1
VZERO T_0
VSUMQG H_0, T_0, H_0
VSUMQG H_3, T_0, H_3
VSUMQG H_1, T_0, H_1
VSUMQG H_4, T_0, H_4
VSUMQG H_2, T_0, H_2
// reduce again after summation
// TODO(mundaym): there might be a more efficient way to do this
// now that we only have 1 active lane. For example, we could
// simultaneously pack the values as we reduce them.
REDUCE(H_0, H_1, H_2, H_3, H_4)
// carry h[1] through to h[4] so that only h[4] can exceed 2²⁶ - 1
// TODO(mundaym): in testing this final carry was unnecessary.
// Needs a proof before it can be removed though.
VESRLG $26, H_1, T_1
VN MOD26, H_1, H_1
VAQ T_1, H_2, H_2
VESRLG $26, H_2, T_2
VN MOD26, H_2, H_2
VAQ T_2, H_3, H_3
VESRLG $26, H_3, T_3
VN MOD26, H_3, H_3
VAQ T_3, H_4, H_4
// h is now < 2(2¹³⁰ - 5)
// Pack each lane in h₂₆[0:4] into h₁₂₈[0:1].
VESLG $26, H_1, H_1
VESLG $26, H_3, H_3
VO H_0, H_1, H_0
VO H_2, H_3, H_2
VESLG $4, H_2, H_2
VLEIB $7, $48, H_1
VSLB H_1, H_2, H_2
VO H_0, H_2, H_0
VLEIB $7, $104, H_1
VSLB H_1, H_4, H_3
VO H_3, H_0, H_0
VLEIB $7, $24, H_1
VSRLB H_1, H_4, H_1
// update state
VSTEG $1, H_0, 0(R1)
VSTEG $0, H_0, 8(R1)
VSTEG $1, H_1, 16(R1)
RET
b2: // 2 or fewer blocks remaining
CMPBLE R3, $16, b1
// Load the 2 remaining blocks (17-32 bytes remaining).
MOVD $-17(R3), R0 // index of final byte to load modulo 16
VL (R2), T_0 // load full 16 byte block
VLL R0, 16(R2), T_1 // load final (possibly partial) block and pad with zeros to 16 bytes
// The Poly1305 algorithm requires that a 1 bit be appended to
// each message block. If the final block is less than 16 bytes
// long then it is easiest to insert the 1 before the message
// block is split into 26-bit limbs. If, on the other hand, the
// final message block is 16 bytes long then we append the 1 bit
// after expansion as normal.
MOVBZ $1, R0
MOVD $-16(R3), R3 // index of byte in last block to insert 1 at (could be 16)
CMPBEQ R3, $16, 2(PC) // skip the insertion if the final block is 16 bytes long
VLVGB R3, R0, T_1 // insert 1 into the byte at index R3
// Split both blocks into 26-bit limbs in the appropriate lanes.
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
// Append a 1 byte to the end of the second to last block.
VLEIB $4, $1, M_4
// Append a 1 byte to the end of the last block only if it is a
// full 16 byte block.
CMPBNE R3, $16, 2(PC)
VLEIB $12, $1, M_4
// Finally, set up the coefficients for the final multiplication.
// We have previously saved r and 5r in the 32-bit even indexes
// of the R_[0-4] and R5_[1-4] coefficient registers.
//
// We want lane 0 to be multiplied by r² so that can be kept the
// same. We want lane 1 to be multiplied by r so we need to move
// the saved r value into the 32-bit odd index in lane 1 by
// rotating the 64-bit lane by 32.
VGBM $0x00ff, T_0 // [0, 0xffffffffffffffff] - mask lane 1 only
VERIMG $32, R_0, T_0, R_0 // [_, r²₂₆[0], _, r₂₆[0]]
VERIMG $32, R_1, T_0, R_1 // [_, r²₂₆[1], _, r₂₆[1]]
VERIMG $32, R_2, T_0, R_2 // [_, r²₂₆[2], _, r₂₆[2]]
VERIMG $32, R_3, T_0, R_3 // [_, r²₂₆[3], _, r₂₆[3]]
VERIMG $32, R_4, T_0, R_4 // [_, r²₂₆[4], _, r₂₆[4]]
VERIMG $32, R5_1, T_0, R5_1 // [_, 5r²₂₆[1], _, 5r₂₆[1]]
VERIMG $32, R5_2, T_0, R5_2 // [_, 5r²₂₆[2], _, 5r₂₆[2]]
VERIMG $32, R5_3, T_0, R5_3 // [_, 5r²₂₆[3], _, 5r₂₆[3]]
VERIMG $32, R5_4, T_0, R5_4 // [_, 5r²₂₆[4], _, 5r₂₆[4]]
MOVD $0, R3
BR multiply
skip:
CMPBEQ R3, $0, finish
b1: // 1 block remaining
// Load the final block (1-16 bytes). This will be placed into
// lane 0.
MOVD $-1(R3), R0
VLL R0, (R2), T_0 // pad to 16 bytes with zeros
// The Poly1305 algorithm requires that a 1 bit be appended to
// each message block. If the final block is less than 16 bytes
// long then it is easiest to insert the 1 before the message
// block is split into 26-bit limbs. If, on the other hand, the
// final message block is 16 bytes long then we append the 1 bit
// after expansion as normal.
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, T_0
// Set the message block in lane 1 to the value 0 so that it
// can be accumulated without affecting the final result.
VZERO T_1
// Split the final message block into 26-bit limbs in lane 0.
// Lane 1 will contain 0.
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
// Append a 1 byte to the end of the last block only if it is a
// full 16 byte block.
CMPBNE R3, $16, 2(PC)
VLEIB $4, $1, M_4
// We have previously saved r and 5r in the 32-bit even indexes
// of the R_[0-4] and R5_[1-4] coefficient registers.
//
// We want lane 0 to be multiplied by r so we need to move the
// saved r value into the 32-bit odd index in lane 0. We want
// lane 1 to be set to the value 1. This makes multiplication
// a no-op. We do this by setting lane 1 in every register to 0
// and then just setting the 32-bit index 3 in R_0 to 1.
VZERO T_0
MOVD $0, R0
MOVD $0x10111213, R12
VLVGP R12, R0, T_1 // [_, 0x10111213, _, 0x00000000]
VPERM T_0, R_0, T_1, R_0 // [_, r₂₆[0], _, 0]
VPERM T_0, R_1, T_1, R_1 // [_, r₂₆[1], _, 0]
VPERM T_0, R_2, T_1, R_2 // [_, r₂₆[2], _, 0]
VPERM T_0, R_3, T_1, R_3 // [_, r₂₆[3], _, 0]
VPERM T_0, R_4, T_1, R_4 // [_, r₂₆[4], _, 0]
VPERM T_0, R5_1, T_1, R5_1 // [_, 5r₂₆[1], _, 0]
VPERM T_0, R5_2, T_1, R5_2 // [_, 5r₂₆[2], _, 0]
VPERM T_0, R5_3, T_1, R5_3 // [_, 5r₂₆[3], _, 0]
VPERM T_0, R5_4, T_1, R5_4 // [_, 5r₂₆[4], _, 0]
// Set the value of lane 1 to be 1.
VLEIF $3, $1, R_0 // [_, r₂₆[0], _, 1]
MOVD $0, R3
BR multiply
================================================
FILE: gossh/crypto/internal/testenv/exec.go
================================================
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testenv
import (
"context"
"os"
"os/exec"
"reflect"
"strconv"
"testing"
"time"
)
// CommandContext is like exec.CommandContext, but:
// - skips t if the platform does not support os/exec,
// - sends SIGQUIT (if supported by the platform) instead of SIGKILL
// in its Cancel function
// - if the test has a deadline, adds a Context timeout and WaitDelay
// for an arbitrary grace period before the test's deadline expires,
// - fails the test if the command does not complete before the test's deadline, and
// - sets a Cleanup function that verifies that the test did not leak a subprocess.
//
// NOTE(review): no skip for platforms lacking os/exec is performed inside this
// function body; confirm whether callers are expected to do that themselves.
//
// The returned command has not been started; the caller still has to run it
// and wait for it.
func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd {
t.Helper()
var (
cancelCtx context.CancelFunc
gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging)
)
// Only testing.TB implementations that also expose Deadline (such as
// *testing.T) can have a timeout derived from the test deadline. For
// everything else cancelCtx stays nil and gracePeriod stays zero.
if t, ok := t.(interface {
testing.TB
Deadline() (time.Time, bool)
}); ok {
if td, ok := t.Deadline(); ok {
// Start with a minimum grace period, just long enough to consume the
// output of a reasonable program after it terminates.
gracePeriod = 100 * time.Millisecond
// GO_TEST_TIMEOUT_SCALE, when set, must be an integer; it multiplies
// the minimum grace period.
if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" {
scale, err := strconv.Atoi(s)
if err != nil {
t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err)
}
gracePeriod *= time.Duration(scale)
}
// If time allows, increase the termination grace period to 5% of the
// test's remaining time.
testTimeout := time.Until(td)
if gp := testTimeout / 20; gp > gracePeriod {
gracePeriod = gp
}
// When we run commands that execute subprocesses, we want to reserve two
// grace periods to clean up: one for the delay between the first
// termination signal being sent (via the Cancel callback when the Context
// expires) and the process being forcibly terminated (via the WaitDelay
// field), and a second one for the delay between the process being
// terminated and the test logging its output for debugging.
//
// (We want to ensure that the test process itself has enough time to
// log the output before it is also terminated.)
cmdTimeout := testTimeout - 2*gracePeriod
if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout {
// Either ctx doesn't have a deadline, or its deadline would expire
// after (or too close before) the test has already timed out.
// Add a shorter timeout so that the test will produce useful output.
ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout)
}
}
}
// ctx is now either the caller's context or the shortened one created above.
cmd := exec.CommandContext(ctx, name, args...)
// Set the Cancel and WaitDelay fields only if present (go 1.20 and later).
// TODO: When Go 1.19 is no longer supported, remove this use of reflection
// and instead set the fields directly.
if cmdCancel := reflect.ValueOf(cmd).Elem().FieldByName("Cancel"); cmdCancel.IsValid() {
cmdCancel.Set(reflect.ValueOf(func() error {
// cancelCtx is non-nil only when this function added its own timeout.
if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded {
// The command timed out due to running too close to the test's deadline.
// There is no way the test did that intentionally — it's too close to the
// wire! — so mark it as a test failure. That way, if the test expects the
// command to fail for some other reason, it doesn't have to distinguish
// between that reason and a timeout.
t.Errorf("test timed out while running command: %v", cmd)
} else {
// The command is being terminated due to ctx being canceled, but
// apparently not due to an explicit test deadline that we added.
// Log that information in case it is useful for diagnosing a failure,
// but don't actually fail the test because of it.
t.Logf("%v: terminating command: %v", ctx.Err(), cmd)
}
return cmd.Process.Signal(Sigquit)
}))
}
// WaitDelay receives gracePeriod, which is zero when no test deadline was found.
if cmdWaitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); cmdWaitDelay.IsValid() {
cmdWaitDelay.Set(reflect.ValueOf(gracePeriod))
}
// At test cleanup: release the timeout context (if one was created) and
// report commands that were started but never waited on.
t.Cleanup(func() {
if cancelCtx != nil {
cancelCtx()
}
if cmd.Process != nil && cmd.ProcessState == nil {
t.Errorf("command was started, but test did not wait for it to complete: %v", cmd)
}
})
return cmd
}
// Command is the context-free counterpart of CommandContext: it builds the
// command under context.Background() and otherwise applies exactly the same
// test-aware adjustments as CommandContext.
func Command(t testing.TB, name string, args ...string) *exec.Cmd {
	t.Helper()

	background := context.Background()
	cmd := CommandContext(t, background, name, args...)
	return cmd
}
================================================
FILE: gossh/crypto/internal/testenv/testenv_notunix.go
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows || plan9 || (js && wasm) || wasip1
package testenv
import (
"os"
)
// Sigquit is the signal to send to kill a hanging subprocess.
// On Unix we send SIGQUIT, but on non-Unix we only have os.Kill.
//
// This definition is the one selected by the build constraint above
// (windows, plan9, js/wasm and wasip1). CommandContext delivers it from the
// command's Cancel hook.
var Sigquit = os.Kill
================================================
FILE: gossh/crypto/internal/testenv/testenv_unix.go
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package testenv
import (
"syscall"
)
// Sigquit is the signal to send to kill a hanging subprocess.
// Send SIGQUIT to get a stack trace.
//
// This definition is the one selected by the "unix" build constraint above.
// CommandContext delivers it from the command's Cancel hook.
var Sigquit = syscall.SIGQUIT
================================================
FILE: gossh/crypto/ssh/agent/client.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package agent implements the ssh-agent protocol, and provides both
// a client and a server. The client can talk to a standard ssh-agent
// that uses UNIX sockets, and one could implement an alternative
// ssh-agent process using the sample server.
//
// References:
//
// [PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
package agent // import "golang.org/x/crypto/ssh/agent"
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"sync"
"gossh/crypto/ssh"
)
// SignatureFlags represent additional flags that can be passed to the signature
// requests as defined in [PROTOCOL.agent] section 4.5.1.
type SignatureFlags uint32
// SignatureFlag values as defined in [PROTOCOL.agent] section 5.3.
// Each constant occupies its own bit (1 << iota), so the values are 1, 2 and 4.
const (
SignatureFlagReserved SignatureFlags = 1 << iota // 1
SignatureFlagRsaSha256 // 2
SignatureFlagRsaSha512 // 4
)
// Agent represents the capabilities of an ssh-agent.
type Agent interface {
// List returns the identities known to the agent.
List() ([]*Key, error)
// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
// Add adds a private key to the agent.
Add(key AddedKey) error
// Remove removes all identities with the given public key.
Remove(key ssh.PublicKey) error
// RemoveAll removes all identities.
RemoveAll() error
// Lock locks the agent. Sign and Remove will fail, and List will return an empty list.
Lock(passphrase []byte) error
// Unlock undoes the effect of Lock.
Unlock(passphrase []byte) error
// Signers returns signers for all the known keys.
Signers() ([]ssh.Signer, error)
}
// ExtendedAgent is an Agent that additionally supports signing with
// SignatureFlags and issuing extension requests.
type ExtendedAgent interface {
Agent
// SignWithFlags signs like Sign, but allows for additional flags to be sent/received
SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error)
// Extension processes a custom extension request. Standard-compliant agents are not
// required to support any extensions, but this method allows agents to implement
// vendor-specific methods or add experimental features. See [PROTOCOL.agent] section 4.7.
// If agent extensions are unsupported entirely this method MUST return an
// ErrExtensionUnsupported error. Similarly, if just the specific extensionType in
// the request is unsupported by the agent then ErrExtensionUnsupported MUST be
// returned.
//
// In the case of success, since [PROTOCOL.agent] section 4.7 specifies that the contents
// of the response are unspecified (including the type of the message), the complete
// response will be returned as a []byte slice, including the "type" byte of the message.
Extension(extensionType string, contents []byte) ([]byte, error)
}
// ConstraintExtension describes an optional constraint defined by users.
type ConstraintExtension struct {
// ExtensionName consists of a UTF-8 string suffixed by the
// implementation domain following the naming scheme defined
// in Section 4.2 of RFC 4251, e.g. "foo@example.com".
ExtensionName string
// ExtensionDetails contains the actual content of the extended
// constraint.
ExtensionDetails []byte
}
// AddedKey describes an SSH key to be added to an Agent.
type AddedKey struct {
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey,
// ed25519.PrivateKey or *ecdsa.PrivateKey, which will be inserted into the
// agent.
PrivateKey any
// Certificate, if not nil, is communicated to the agent and will be
// stored with the key.
Certificate *ssh.Certificate
// Comment is an optional, free-form string.
Comment string
// LifetimeSecs, if not zero, is the number of seconds that the
// agent will store the key for.
LifetimeSecs uint32
// ConfirmBeforeUse, if true, requests that the agent confirm with the
// user before each use of this key.
ConfirmBeforeUse bool
// ConstraintExtensions are the experimental or private-use constraints
// defined by users.
//
// NOTE(review): the client's Add method in this file encodes only
// LifetimeSecs and ConfirmBeforeUse; it never reads this field.
ConstraintExtensions []ConstraintExtension
}
// Message-type and constraint identifiers of the agent protocol.
// See [PROTOCOL.agent], section 3.
const (
agentRequestV1Identities = 1
agentRemoveAllV1Identities = 9
// 3.2 Requests from client to agent for protocol 2 key operations
agentAddIdentity = 17
agentRemoveIdentity = 18
agentRemoveAllIdentities = 19
// agentAddIDConstrained replaces the leading type byte of an add request
// whenever constraints accompany the key (see insertKey / insertCert).
agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20
agentRemoveSmartcardKey = 21
agentLock = 22
agentUnlock = 23
agentAddSmartcardKeyConstrained = 26
// 3.7 Key constraint identifiers
agentConstrainLifetime = 1
agentConstrainConfirm = 2
// Constraint extension identifier up to version 2 of the protocol. A
// backward incompatible change will be required if we want to add support
// for SSH_AGENT_CONSTRAIN_MAXSIGN which uses the same ID.
agentConstrainExtensionV00 = 3
// Constraint extension identifier in version 3 and later of the protocol.
agentConstrainExtension = 255
)
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
// is a sanity check, not a limit in the spec.
//
// The value is 16 << 20 bytes (16 MiB). callRaw rejects larger length
// prefixes, and List derives its key-count limit from it.
const maxAgentResponseBytes = 16 << 20
// Agent messages:
// These structures mirror the wire format of the corresponding ssh agent
// messages found in [PROTOCOL.agent].
//
// The number in each `sshtype` tag matches the message-type constant declared
// next to the struct.
// 3.4 Generic replies from agent to client
const agentFailure = 5
// failureAgentMsg is the body-less failure reply.
type failureAgentMsg struct{}
const agentSuccess = 6
// successAgentMsg is the body-less success reply.
type successAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
const agentRequestIdentities = 11
// requestIdentitiesAgentMsg is the body-less request for the identity list.
type requestIdentitiesAgentMsg struct{}
// See [PROTOCOL.agent], section 2.5.2.
const agentIdentitiesAnswer = 12
// identitiesAnswerAgentMsg is the reply to an identities request. Keys holds
// the still-encoded key records; List decodes NumKeys of them with parseKey.
type identitiesAnswerAgentMsg struct {
NumKeys uint32 `sshtype:"12"`
Keys []byte `ssh:"rest"`
}
// See [PROTOCOL.agent], section 2.6.2.
const agentSignRequest = 13
// signRequestAgentMsg asks the agent to sign Data with the key identified by
// KeyBlob; Flags carries the SignatureFlags value as a plain uint32.
type signRequestAgentMsg struct {
KeyBlob []byte `sshtype:"13"`
Data []byte
Flags uint32
}
// See [PROTOCOL.agent], section 2.6.2.
// 3.6 Replies from agent to client for protocol 2 key operations
const agentSignResponse = 14
// signResponseAgentMsg carries the encoded signature; SignWithFlags unmarshals
// SigBlob into an ssh.Signature.
type signResponseAgentMsg struct {
SigBlob []byte `sshtype:"14"`
}
// publicKey is a key blob split into its format name and the remaining bytes.
type publicKey struct {
Format string
Rest []byte `ssh:"rest"`
}
// 3.7 Key constraint identifiers
// constrainLifetimeAgentMsg encodes a lifetime constraint; its tag value 1
// equals agentConstrainLifetime.
type constrainLifetimeAgentMsg struct {
LifetimeSecs uint32 `sshtype:"1"`
}
// constrainExtensionAgentMsg encodes an extension constraint; its tag lists
// both agentConstrainExtension (255) and agentConstrainExtensionV00 (3).
type constrainExtensionAgentMsg struct {
ExtensionName string `sshtype:"255|3"`
ExtensionDetails []byte
// Rest is a field used for parsing, not part of message
Rest []byte `ssh:"rest"`
}
// Message types used by extension requests.
// See [PROTOCOL.agent], section 4.7
const agentExtension = 27
// agentExtensionFailure is the reply type that the client's Extension method
// reports as a generic extension failure.
const agentExtensionFailure = 28
// ErrExtensionUnsupported indicates that an extension defined in
// [PROTOCOL.agent] section 4.7 is unsupported by the agent. Specifically this
// error indicates that the agent returned a standard SSH_AGENT_FAILURE message
// as the result of a SSH_AGENTC_EXTENSION request. Note that the protocol
// specification (and therefore this error) does not distinguish between a
// specific extension being unsupported and extensions being unsupported entirely.
var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
// extensionAgentMsg is the request sent by the client's Extension method; its
// tag value 27 equals agentExtension.
type extensionAgentMsg struct {
ExtensionType string `sshtype:"27"`
// NOTE: this matches OpenSSH's PROTOCOL.agent, not the IETF draft [PROTOCOL.agent],
// so that it matches what OpenSSH actually implements in the wild.
Contents []byte `ssh:"rest"`
}
// Key represents a protocol 2 public key as defined in
// [PROTOCOL.agent], section 2.5.2.
type Key struct {
// Format is the key type name; parseKey takes it from the leading string
// of Blob.
Format string
// Blob is the serialized public key as received from the agent.
Blob []byte
// Comment is the free-form comment stored with the key; it may be empty.
Comment string
}
// clientErr prefixes err with "agent: client error: " so that failures which
// originate on the client side of the agent connection are recognisable.
//
// The cause is wrapped with %w, so callers can still reach it through
// errors.Is / errors.As (for example to detect io.EOF on a closed agent
// socket). err must be non-nil.
func clientErr(err error) error {
	return fmt.Errorf("agent: client error: %w", err)
}
// String renders the key in its storage form: the format name, a space, the
// standard-base64 encoding of the blob and, only when a comment is present,
// another space followed by the comment.
func (k *Key) String() string {
	encoded := base64.StdEncoding.EncodeToString(k.Blob)
	if k.Comment == "" {
		return k.Format + " " + encoded
	}
	return k.Format + " " + encoded + " " + k.Comment
}
// Type returns the public key type, i.e. the Format field unchanged.
func (k *Key) Type() string {
return k.Format
}
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
// The returned slice is k.Blob itself, not a copy.
func (k *Key) Marshal() []byte {
return k.Blob
}
// Verify satisfies the ssh.PublicKey interface. The blob is parsed into a
// concrete public key on every call and verification is delegated to it; a
// blob that cannot be parsed yields an "agent: bad public key" error.
func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
	parsed, parseErr := ssh.ParsePublicKey(k.Blob)
	if parseErr == nil {
		return parsed.Verify(data, sig)
	}
	return fmt.Errorf("agent: bad public key: %v", parseErr)
}
// wireKey is the view of a key blob that parseKey needs: the leading format
// string, with everything after it collected in Rest.
type wireKey struct {
Format string
Rest []byte `ssh:"rest"`
}
// parseKey decodes one key record (blob followed by comment) from the front
// of in. It returns the decoded Key together with the bytes that follow the
// record, so callers can decode consecutive records in a loop. The key's
// Format is read from the leading string of the blob. On error both the key
// and the remainder are nil.
func parseKey(in []byte) (out *Key, rest []byte, err error) {
	var entry struct {
		Blob    []byte
		Comment string
		Rest    []byte `ssh:"rest"`
	}
	if err = ssh.Unmarshal(in, &entry); err != nil {
		return nil, nil, err
	}

	var header wireKey
	if err = ssh.Unmarshal(entry.Blob, &header); err != nil {
		return nil, nil, err
	}

	out = &Key{
		Format:  header.Format,
		Blob:    entry.Blob,
		Comment: entry.Comment,
	}
	return out, entry.Rest, nil
}
// client is a client for an ssh-agent process.
type client struct {
// conn is typically a *net.UnixConn
conn io.ReadWriter
// mu is used to prevent concurrent access to the agent.
// callRaw holds it for the whole request/response exchange.
mu sync.Mutex
}
// NewClient returns an Agent that talks to an ssh-agent process over
// the given connection.
//
// The result is an ExtendedAgent backed by rw. Nothing is read from or
// written to rw until one of its methods is called.
func NewClient(rw io.ReadWriter) ExtendedAgent {
return &client{conn: rw}
}
// call performs one request/response round trip and decodes the response.
//
// The raw exchange is delegated to callRaw and its errors are returned
// unchanged. A response that cannot be decoded is reported through clientErr.
// On success reply is the typed message produced by unmarshal.
func (c *client) call(req []byte) (reply any, err error) {
	raw, err := c.callRaw(req)
	if err != nil {
		return nil, err
	}

	parsed, perr := unmarshal(raw)
	if perr != nil {
		return nil, clientErr(perr)
	}
	return parsed, nil
}
// callRaw performs one request/response round trip without decoding the
// response.
//
// The request is sent as a single write of a 4-byte big-endian length prefix
// followed by req. The response is read as a 4-byte big-endian length
// followed by that many bytes, which are returned as-is. Lengths above
// maxAgentResponseBytes are rejected before any buffer is allocated. The
// client mutex is held for the whole exchange, and every failure is reported
// through clientErr.
func (c *client) callRaw(req []byte) (reply []byte, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Build the frame: length prefix first, payload appended after it.
	frame := make([]byte, 4, 4+len(req))
	binary.BigEndian.PutUint32(frame, uint32(len(req)))
	frame = append(frame, req...)
	if _, err = c.conn.Write(frame); err != nil {
		return nil, clientErr(err)
	}

	header := make([]byte, 4)
	if _, err = io.ReadFull(c.conn, header); err != nil {
		return nil, clientErr(err)
	}
	size := binary.BigEndian.Uint32(header)
	if size > maxAgentResponseBytes {
		return nil, clientErr(errors.New("response too large"))
	}

	body := make([]byte, size)
	if _, err = io.ReadFull(c.conn, body); err != nil {
		return nil, clientErr(err)
	}
	return body, nil
}
// simpleCall sends req and expects a bare success reply. Transport and
// decoding errors from call are returned as-is; any decoded reply other than
// a success message is reported as "agent: failure".
func (c *client) simpleCall(req []byte) error {
	reply, err := c.call(req)
	if err != nil {
		return err
	}
	switch reply.(type) {
	case *successAgentMsg:
		return nil
	default:
		return errors.New("agent: failure")
	}
}
// RemoveAll asks the agent to remove all identities. The request consists of
// the single type byte agentRemoveAllIdentities.
func (c *client) RemoveAll() error {
return c.simpleCall([]byte{agentRemoveAllIdentities})
}
// Remove asks the agent to remove the identity whose public key blob equals
// key.Marshal().
func (c *client) Remove(key ssh.PublicKey) error {
	msg := agentRemoveIdentityMsg{KeyBlob: key.Marshal()}
	return c.simpleCall(ssh.Marshal(&msg))
}
// Lock sends a lock request carrying passphrase to the agent.
func (c *client) Lock(passphrase []byte) error {
	msg := agentLockMsg{Passphrase: passphrase}
	return c.simpleCall(ssh.Marshal(&msg))
}
// Unlock sends an unlock request carrying passphrase to the agent.
func (c *client) Unlock(passphrase []byte) error {
	msg := agentUnlockMsg{Passphrase: passphrase}
	return c.simpleCall(ssh.Marshal(&msg))
}
// List returns the identities known to the agent.
//
// The agent's answer announces a key count followed by that many encoded key
// records. The count is capped at maxAgentResponseBytes/8 before the result
// slice is allocated, so a bogus count cannot force a huge allocation. A
// failure reply and any other unexpected reply type are both returned as
// errors: the reply comes from the peer, so it must never be able to crash
// the client.
func (c *client) List() ([]*Key, error) {
	// see [PROTOCOL.agent] section 2.5.2.
	req := []byte{agentRequestIdentities}

	msg, err := c.call(req)
	if err != nil {
		return nil, err
	}

	switch msg := msg.(type) {
	case *identitiesAnswerAgentMsg:
		if msg.NumKeys > maxAgentResponseBytes/8 {
			return nil, errors.New("agent: too many keys in agent reply")
		}
		keys := make([]*Key, msg.NumKeys)
		data := msg.Keys
		for i := uint32(0); i < msg.NumKeys; i++ {
			var key *Key
			var err error
			// parseKey consumes one record and hands back the remainder.
			if key, data, err = parseKey(data); err != nil {
				return nil, err
			}
			keys[i] = key
		}
		return keys, nil
	case *failureAgentMsg:
		return nil, errors.New("agent: failed to list keys")
	default:
		// unmarshal can also yield other well-formed replies (e.g. a bare
		// success message); report them instead of panicking.
		return nil, fmt.Errorf("agent: failed to list keys, unexpected message type %T", msg)
	}
}
// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
//
// It is SignWithFlags with no flags set.
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
return c.SignWithFlags(key, data, 0)
}
// SignWithFlags has the agent sign data with the key identified by
// key.Marshal(), passing flags along in the request.
//
// The signature blob of a sign response is decoded into an ssh.Signature. A
// failure reply and any other unexpected reply type are both returned as
// errors: the reply comes from the peer, so it must never be able to crash
// the client.
func (c *client) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
	req := ssh.Marshal(signRequestAgentMsg{
		KeyBlob: key.Marshal(),
		Data:    data,
		Flags:   uint32(flags),
	})

	msg, err := c.call(req)
	if err != nil {
		return nil, err
	}

	switch msg := msg.(type) {
	case *signResponseAgentMsg:
		var sig ssh.Signature
		if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
			return nil, err
		}
		return &sig, nil
	case *failureAgentMsg:
		return nil, errors.New("agent: failed to sign challenge")
	default:
		// unmarshal can also yield other well-formed replies (e.g. a bare
		// success message); report them instead of panicking.
		return nil, fmt.Errorf("agent: failed to sign challenge, unexpected message type %T", msg)
	}
}
// unmarshal decodes an agent reply. The first byte of packet selects the
// message type: failure and success replies are returned as empty messages
// without further parsing, identities answers, sign responses and V1 identity
// answers are decoded with ssh.Unmarshal, and any other tag is an error. An
// empty packet is rejected.
func unmarshal(packet []byte) (any, error) {
	if len(packet) == 0 {
		return nil, errors.New("agent: empty packet")
	}

	var target any
	switch tag := packet[0]; tag {
	case agentFailure:
		return &failureAgentMsg{}, nil
	case agentSuccess:
		return &successAgentMsg{}, nil
	case agentIdentitiesAnswer:
		target = &identitiesAnswerAgentMsg{}
	case agentSignResponse:
		target = &signResponseAgentMsg{}
	case agentV1IdentitiesAnswer:
		target = &agentV1IdentityMsg{}
	default:
		return nil, fmt.Errorf("agent: unknown type tag %d", tag)
	}

	if err := ssh.Unmarshal(packet, target); err != nil {
		return nil, err
	}
	return target, nil
}
// The *KeyMsg types below are the requests that insertKey marshals when a
// bare private key (without a certificate) is added. Each one carries the key
// type name, the key material, the comment, and finally the already-encoded
// constraints.

// rsaKeyMsg adds an RSA private key.
type rsaKeyMsg struct {
Type string `sshtype:"17|25"`
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// dsaKeyMsg adds a DSA private key.
type dsaKeyMsg struct {
Type string `sshtype:"17|25"`
P *big.Int
Q *big.Int
G *big.Int
Y *big.Int
X *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// ecdsaKeyMsg adds an ECDSA private key. insertKey fills KeyBytes with the
// elliptic.Marshal encoding of the public point.
type ecdsaKeyMsg struct {
Type string `sshtype:"17|25"`
Curve string
KeyBytes []byte
D *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// ed25519KeyMsg adds an Ed25519 private key. insertKey sets Priv to the whole
// private key and Pub to the bytes from offset 32 onward.
type ed25519KeyMsg struct {
Type string `sshtype:"17|25"`
Pub []byte
Priv []byte
Comments string
Constraints []byte `ssh:"rest"`
}
// insertKey adds the bare private key s (no certificate) to the agent.
//
// s must be a *rsa.PrivateKey (with exactly two primes), *dsa.PrivateKey,
// *ecdsa.PrivateKey, ed25519.PrivateKey or *ed25519.PrivateKey; any other
// type is rejected with an error. Ed25519 keys must be exactly
// ed25519.PrivateKeySize bytes long and the pointer form must be non-nil;
// malformed keys are rejected with an error rather than panicking while the
// public half is sliced out.
//
// comment is stored alongside the key. constraints holds already-encoded key
// constraints; when it is non-empty the leading type byte of the request is
// overwritten with agentAddIDConstrained.
func (c *client) insertKey(s any, comment string, constraints []byte) error {
	var req []byte
	switch k := s.(type) {
	case *rsa.PrivateKey:
		if len(k.Primes) != 2 {
			return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
		}
		// Precompute populates k.Precomputed, which supplies Qinv below.
		k.Precompute()
		req = ssh.Marshal(rsaKeyMsg{
			Type:        ssh.KeyAlgoRSA,
			N:           k.N,
			E:           big.NewInt(int64(k.E)),
			D:           k.D,
			Iqmp:        k.Precomputed.Qinv,
			P:           k.Primes[0],
			Q:           k.Primes[1],
			Comments:    comment,
			Constraints: constraints,
		})
	case *dsa.PrivateKey:
		req = ssh.Marshal(dsaKeyMsg{
			Type:        ssh.KeyAlgoDSA,
			P:           k.P,
			Q:           k.Q,
			G:           k.G,
			Y:           k.Y,
			X:           k.X,
			Comments:    comment,
			Constraints: constraints,
		})
	case *ecdsa.PrivateKey:
		nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
		req = ssh.Marshal(ecdsaKeyMsg{
			Type:        "ecdsa-sha2-" + nistID,
			Curve:       nistID,
			KeyBytes:    elliptic.Marshal(k.Curve, k.X, k.Y),
			D:           k.D,
			Comments:    comment,
			Constraints: constraints,
		})
	case ed25519.PrivateKey:
		if len(k) != ed25519.PrivateKeySize {
			return fmt.Errorf("agent: invalid ed25519 private key length %d", len(k))
		}
		req = ssh.Marshal(ed25519KeyMsg{
			Type:        ssh.KeyAlgoED25519,
			Pub:         []byte(k)[32:],
			Priv:        []byte(k),
			Comments:    comment,
			Constraints: constraints,
		})
	// This function originally supported only *ed25519.PrivateKey, however the
	// general idiom is to pass ed25519.PrivateKey by value, not by pointer.
	// We still support the pointer variant for backwards compatibility.
	case *ed25519.PrivateKey:
		if k == nil {
			return errors.New("agent: nil ed25519 private key")
		}
		if len(*k) != ed25519.PrivateKeySize {
			return fmt.Errorf("agent: invalid ed25519 private key length %d", len(*k))
		}
		req = ssh.Marshal(ed25519KeyMsg{
			Type:        ssh.KeyAlgoED25519,
			Pub:         []byte(*k)[32:],
			Priv:        []byte(*k),
			Comments:    comment,
			Constraints: constraints,
		})
	default:
		return fmt.Errorf("agent: unsupported key type %T", s)
	}

	// if constraints are present then the message type needs to be changed.
	if len(constraints) != 0 {
		req[0] = agentAddIDConstrained
	}

	// simpleCall maps anything but a success reply to "agent: failure".
	return c.simpleCall(req)
}
// The *CertMsg types below are the requests that insertCert marshals when a
// private key is added together with a certificate. Type and CertBytes come
// from the certificate; the remaining fields carry the private key material,
// the comment, and finally the already-encoded constraints.

// rsaCertMsg adds an RSA private key with its certificate.
type rsaCertMsg struct {
Type string `sshtype:"17|25"`
CertBytes []byte
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// dsaCertMsg adds a DSA private key with its certificate.
type dsaCertMsg struct {
Type string `sshtype:"17|25"`
CertBytes []byte
X *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// ecdsaCertMsg adds an ECDSA private key with its certificate.
type ecdsaCertMsg struct {
Type string `sshtype:"17|25"`
CertBytes []byte
D *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// ed25519CertMsg adds an Ed25519 private key with its certificate.
type ed25519CertMsg struct {
Type string `sshtype:"17|25"`
CertBytes []byte
Pub []byte
Priv []byte
Comments string
Constraints []byte `ssh:"rest"`
}
// Add adds a private key to the agent. When key.Certificate is set, the
// certificate is registered as the public half instead of the bare key.
//
// Two constraints are encoded, in this order: the lifetime (only when
// LifetimeSecs is non-zero) and confirm-before-use. key.ConstraintExtensions
// is not consulted here.
func (c *client) Add(key AddedKey) error {
	var constraints []byte

	if key.LifetimeSecs != 0 {
		lifetime := constrainLifetimeAgentMsg{LifetimeSecs: key.LifetimeSecs}
		constraints = append(constraints, ssh.Marshal(lifetime)...)
	}
	if key.ConfirmBeforeUse {
		constraints = append(constraints, agentConstrainConfirm)
	}

	if key.Certificate != nil {
		return c.insertCert(key.PrivateKey, key.Certificate, key.Comment, constraints)
	}
	return c.insertKey(key.PrivateKey, key.Comment, constraints)
}
// insertCert adds the private key s together with its certificate cert to the
// agent.
//
// s must be a *rsa.PrivateKey (with exactly two primes), *dsa.PrivateKey,
// *ecdsa.PrivateKey, ed25519.PrivateKey or *ed25519.PrivateKey; any other
// type is rejected with an error. Ed25519 keys must be exactly
// ed25519.PrivateKeySize bytes long and the pointer form must be non-nil;
// malformed keys are rejected with an error rather than panicking while the
// public half is sliced out.
//
// Before anything is sent, the public key derived from s is compared with
// the key embedded in cert; a mismatch is an error. comment is stored
// alongside the key. constraints holds already-encoded key constraints; when
// it is non-empty the leading type byte of the request is overwritten with
// agentAddIDConstrained.
func (c *client) insertCert(s any, cert *ssh.Certificate, comment string, constraints []byte) error {
	var req []byte
	switch k := s.(type) {
	case *rsa.PrivateKey:
		if len(k.Primes) != 2 {
			return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
		}
		// Precompute populates k.Precomputed, which supplies Qinv below.
		k.Precompute()
		req = ssh.Marshal(rsaCertMsg{
			Type:        cert.Type(),
			CertBytes:   cert.Marshal(),
			D:           k.D,
			Iqmp:        k.Precomputed.Qinv,
			P:           k.Primes[0],
			Q:           k.Primes[1],
			Comments:    comment,
			Constraints: constraints,
		})
	case *dsa.PrivateKey:
		req = ssh.Marshal(dsaCertMsg{
			Type:        cert.Type(),
			CertBytes:   cert.Marshal(),
			X:           k.X,
			Comments:    comment,
			Constraints: constraints,
		})
	case *ecdsa.PrivateKey:
		req = ssh.Marshal(ecdsaCertMsg{
			Type:        cert.Type(),
			CertBytes:   cert.Marshal(),
			D:           k.D,
			Comments:    comment,
			Constraints: constraints,
		})
	case ed25519.PrivateKey:
		if len(k) != ed25519.PrivateKeySize {
			return fmt.Errorf("agent: invalid ed25519 private key length %d", len(k))
		}
		req = ssh.Marshal(ed25519CertMsg{
			Type:        cert.Type(),
			CertBytes:   cert.Marshal(),
			Pub:         []byte(k)[32:],
			Priv:        []byte(k),
			Comments:    comment,
			Constraints: constraints,
		})
	// This function originally supported only *ed25519.PrivateKey, however the
	// general idiom is to pass ed25519.PrivateKey by value, not by pointer.
	// We still support the pointer variant for backwards compatibility.
	case *ed25519.PrivateKey:
		if k == nil {
			return errors.New("agent: nil ed25519 private key")
		}
		if len(*k) != ed25519.PrivateKeySize {
			return fmt.Errorf("agent: invalid ed25519 private key length %d", len(*k))
		}
		req = ssh.Marshal(ed25519CertMsg{
			Type:        cert.Type(),
			CertBytes:   cert.Marshal(),
			Pub:         []byte(*k)[32:],
			Priv:        []byte(*k),
			Comments:    comment,
			Constraints: constraints,
		})
	default:
		return fmt.Errorf("agent: unsupported key type %T", s)
	}

	// if constraints are present then the message type needs to be changed.
	if len(constraints) != 0 {
		req[0] = agentAddIDConstrained
	}

	// Refuse to upload a certificate that does not belong to the private key.
	signer, err := ssh.NewSignerFromKey(s)
	if err != nil {
		return err
	}
	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
		return errors.New("agent: signer and cert have different public key")
	}

	// simpleCall maps anything but a success reply to "agent: failure".
	return c.simpleCall(req)
}
// Signers provides a callback for client authentication: it lists the agent's
// keys and wraps each one in a signer that forwards signing requests back to
// this client. The result is nil when the agent holds no keys.
func (c *client) Signers() ([]ssh.Signer, error) {
	keys, err := c.List()
	if err != nil {
		return nil, err
	}

	var signers []ssh.Signer
	for i := range keys {
		signers = append(signers, &agentKeyringSigner{agent: c, pub: keys[i]})
	}
	return signers, nil
}
// agentKeyringSigner is the signer type handed out by the client's Signers
// method: it pairs one public key with the client that can ask the agent to
// sign with it.
type agentKeyringSigner struct {
agent *client
pub ssh.PublicKey
}
// PublicKey returns the public key this signer was created for.
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
return s.pub
}
// Sign asks the agent to sign data with this signer's key.
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
// The agent has its own entropy source, so the rand argument is ignored.
return s.agent.Sign(s.pub, data)
}
// SignWithAlgorithm asks the agent to sign data using the requested signature
// algorithm.
//
// An empty algorithm, or one equal to the key's own underlying algorithm,
// results in a plain Sign. The two RSA SHA-2 algorithms are requested from
// the agent through the matching signature flag. Every other algorithm is
// rejected with an error.
func (s *agentKeyringSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	// Cases are tried top to bottom, so the key's native algorithm wins over
	// the explicit RSA SHA-2 cases, and the key type is not looked up at all
	// for an empty algorithm.
	switch algorithm {
	case "", underlyingAlgo(s.pub.Type()):
		return s.Sign(rand, data)
	case ssh.KeyAlgoRSASHA256:
		return s.agent.SignWithFlags(s.pub, data, SignatureFlagRsaSha256)
	case ssh.KeyAlgoRSASHA512:
		return s.agent.SignWithFlags(s.pub, data, SignatureFlagRsaSha512)
	default:
		return nil, fmt.Errorf("agent: unsupported algorithm %q", algorithm)
	}
}
// Compile-time check that *agentKeyringSigner implements ssh.AlgorithmSigner.
var _ ssh.AlgorithmSigner = &agentKeyringSigner{}
// certKeyAlgoNames is a mapping from known certificate algorithm names to the
// corresponding public key signature algorithm.
//
// This map must be kept in sync with the one in certs.go.
//
// underlyingAlgo consults it; names that are absent are treated as already
// being plain key algorithms.
var certKeyAlgoNames = map[string]string{
ssh.CertAlgoRSAv01: ssh.KeyAlgoRSA,
ssh.CertAlgoRSASHA256v01: ssh.KeyAlgoRSASHA256,
ssh.CertAlgoRSASHA512v01: ssh.KeyAlgoRSASHA512,
ssh.CertAlgoDSAv01: ssh.KeyAlgoDSA,
ssh.CertAlgoECDSA256v01: ssh.KeyAlgoECDSA256,
ssh.CertAlgoECDSA384v01: ssh.KeyAlgoECDSA384,
ssh.CertAlgoECDSA521v01: ssh.KeyAlgoECDSA521,
ssh.CertAlgoSKECDSA256v01: ssh.KeyAlgoSKECDSA256,
ssh.CertAlgoED25519v01: ssh.KeyAlgoED25519,
ssh.CertAlgoSKED25519v01: ssh.KeyAlgoSKED25519,
}
// underlyingAlgo maps an advertised or negotiated public key / host key
// algorithm to the signature algorithm behind it. Certificate algorithms
// listed in certKeyAlgoNames are translated to their key algorithm; every
// other name is returned unchanged.
func underlyingAlgo(algo string) string {
	mapped, isCert := certKeyAlgoNames[algo]
	if !isCert {
		return algo
	}
	return mapped
}
// Extension calls an extension method. It is up to the agent implementation
// whether any particular extension is supported, and it may always return an
// error. Because the type of the response is up to the implementation, the
// raw response bytes (including the leading type byte) are returned and no
// unmarshalling is attempted.
//
// An SSH_AGENT_FAILURE reply is reported as ErrExtensionUnsupported, an
// extension-failure reply as a generic extension error, and an empty reply
// as a failure.
func (c *client) Extension(extensionType string, contents []byte) ([]byte, error) {
	request := extensionAgentMsg{
		ExtensionType: extensionType,
		Contents:      contents,
	}
	reply, err := c.callRaw(ssh.Marshal(request))
	if err != nil {
		return nil, err
	}
	if len(reply) == 0 {
		return nil, errors.New("agent: failure; empty response")
	}

	switch reply[0] {
	case agentFailure:
		// [PROTOCOL.agent] section 4.7 indicates that an SSH_AGENT_FAILURE message
		// represents an agent that does not support the extension
		return nil, ErrExtensionUnsupported
	case agentExtensionFailure:
		return nil, errors.New("agent: generic extension failure")
	}
	return reply, nil
}
================================================
FILE: gossh/crypto/ssh/agent/forward.go
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"errors"
"io"
"net"
"sync"
"gossh/crypto/ssh"
)
// RequestAgentForwarding asks the remote side to enable agent forwarding for
// the session by sending an "auth-agent-req@openssh.com" request that wants a
// reply. It returns the send error, or an error when the request is denied.
// ForwardToAgent or ForwardToRemote should be called to route the
// authentication requests.
func RequestAgentForwarding(session *ssh.Session) error {
	accepted, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
	switch {
	case err != nil:
		return err
	case !accepted:
		return errors.New("forwarding request denied")
	}
	return nil
}
// ForwardToAgent routes authentication requests to the given keyring.
//
// It registers a handler for the agent channel type on client and returns an
// error if one is already registered. Each accepted channel is served by
// ServeAgent in its own goroutine and closed once ServeAgent returns. Channels
// that fail to be accepted are skipped.
func ForwardToAgent(client *ssh.Client, keyring Agent) error {
	incoming := client.HandleChannelOpen(channelType)
	if incoming == nil {
		return errors.New("agent: already have handler for " + channelType)
	}

	go func() {
		for newCh := range incoming {
			channel, reqs, err := newCh.Accept()
			if err == nil {
				go ssh.DiscardRequests(reqs)
				go func() {
					ServeAgent(keyring, channel)
					channel.Close()
				}()
			}
		}
	}()
	return nil
}
// channelType is the SSH channel type for which ForwardToAgent and
// ForwardToRemote register their channel-open handlers.
const channelType = "auth-agent@openssh.com"
// ForwardToRemote routes authentication requests to the ssh-agent
// process serving on the given unix socket.
//
// It registers a handler for the agent channel type on client and returns an
// error if one is already registered. The socket is dialed once up front and
// closed again immediately, so an unreachable addr is reported to the caller.
// Afterwards every accepted channel gets its own connection to addr.
func ForwardToRemote(client *ssh.Client, addr string) error {
	incoming := client.HandleChannelOpen(channelType)
	if incoming == nil {
		return errors.New("agent: already have handler for " + channelType)
	}

	probe, err := net.Dial("unix", addr)
	if err != nil {
		return err
	}
	probe.Close()

	go func() {
		for newCh := range incoming {
			channel, reqs, err := newCh.Accept()
			if err == nil {
				go ssh.DiscardRequests(reqs)
				go forwardUnixSocket(channel, addr)
			}
		}
	}()
	return nil
}
// forwardUnixSocket connects channel to the unix socket at addr and copies
// data in both directions until both directions have finished, then closes
// the socket connection and the channel.
//
// This function owns channel: if the socket cannot be dialed, the channel is
// closed before returning so the accepted SSH channel is not leaked.
func forwardUnixSocket(channel ssh.Channel, addr string) {
	conn, err := net.Dial("unix", addr)
	if err != nil {
		// No other code holds a reference to the accepted channel, so it
		// must be released here; the close error is of no use to anyone.
		channel.Close()
		return
	}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		io.Copy(conn, channel)
		// Half-close so the agent sees EOF while its replies can still flow.
		conn.(*net.UnixConn).CloseWrite()
	}()
	go func() {
		defer wg.Done()
		io.Copy(channel, conn)
		channel.CloseWrite()
	}()

	wg.Wait()
	conn.Close()
	channel.Close()
}
================================================
FILE: gossh/crypto/ssh/agent/keyring.go
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"bytes"
"crypto/rand"
"crypto/subtle"
"errors"
"fmt"
"sync"
"time"
"gossh/crypto/ssh"
)
// privKey is one entry of a keyring.
type privKey struct {
// signer performs signatures and exposes the entry's public key.
signer ssh.Signer
// comment is the comment supplied when the key was added.
comment string
// expire, when non-nil, is the time after which the entry is treated as
// expired.
expire *time.Time
}
// keyring is the in-memory Agent implementation returned by NewKeyring.
type keyring struct {
// mu is held by the keyring's methods while they read or modify the
// fields below.
mu sync.Mutex
// keys are the identities currently held.
keys []privKey
// locked is set by Lock and cleared by Unlock.
locked bool
// passphrase is the value given to Lock; Unlock compares against it.
passphrase []byte
}
// errLocked is returned by keyring operations that are refused while the
// keyring is locked.
var errLocked = errors.New("agent: locked")
// NewKeyring returns an empty, unlocked Agent that holds keys in memory. It is
// safe for concurrent use by multiple goroutines.
func NewKeyring() Agent {
	return new(keyring)
}
// RemoveAll removes all identities. It fails with errLocked while the keyring
// is locked.
func (r *keyring) RemoveAll() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.locked {
		r.keys = nil
		return nil
	}
	return errLocked
}
// removeLocked deletes every entry whose marshaled public key equals want and
// reports an error if no entry matched. Entries are deleted by moving the last
// entry into the vacated slot, so the order of the remaining keys is not
// preserved. The caller must already be holding the keyring mutex.
func (r *keyring) removeLocked(want []byte) error {
	removed := false
	idx := 0
	for idx < len(r.keys) {
		if !bytes.Equal(r.keys[idx].signer.PublicKey().Marshal(), want) {
			idx++
			continue
		}
		// Do not advance: the entry swapped into idx still needs checking.
		last := len(r.keys) - 1
		r.keys[idx] = r.keys[last]
		r.keys = r.keys[:last]
		removed = true
	}

	if removed {
		return nil
	}
	return errors.New("agent: key not found")
}
// Remove removes all identities with the given public key. It fails with
// errLocked while the keyring is locked, and with an error when no identity
// has that key.
func (r *keyring) Remove(key ssh.PublicKey) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.locked {
		return r.removeLocked(key.Marshal())
	}
	return errLocked
}
// Lock locks the agent and remembers passphrase for a later Unlock. While
// locked, Sign and Remove will fail, and List will return an empty list.
// Locking an already locked keyring fails with errLocked.
func (r *keyring) Lock(passphrase []byte) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.locked {
		return errLocked
	}
	r.locked, r.passphrase = true, passphrase
	return nil
}
// Unlock undoes the effect of Lock. It fails if the keyring is not locked or
// if passphrase differs from the one given to Lock; the comparison is done in
// constant time. On success the stored passphrase is discarded.
func (r *keyring) Unlock(passphrase []byte) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if !r.locked {
		return errors.New("agent: not locked")
	}
	if subtle.ConstantTimeCompare(passphrase, r.passphrase) != 1 {
		return fmt.Errorf("agent: incorrect passphrase")
	}

	r.locked, r.passphrase = false, nil
	return nil
}
// expireKeysLocked removes expired keys from the keyring. If a key was added
// with a lifetimesecs constraint and seconds >= lifetimesecs seconds have
// elapsed, it is removed. The caller *must* be holding the keyring mutex.
//
// Only the expired entries themselves are dropped: an unexpired entry that
// happens to share its public key with an expired one is kept. The relative
// order of the remaining keys is preserved.
func (r *keyring) expireKeysLocked() {
	// Sample the clock once so every key is judged against the same instant.
	now := time.Now()

	// Filter in place. kept aliases r.keys, but it never grows past the index
	// currently being read, so no unread entry is overwritten.
	kept := r.keys[:0]
	for _, k := range r.keys {
		if k.expire != nil && now.After(*k.expire) {
			continue
		}
		kept = append(kept, k)
	}

	// Zero the vacated tail so the backing array stops referencing the
	// signers of expired keys.
	for i := len(kept); i < len(r.keys); i++ {
		r.keys[i] = privKey{}
	}
	r.keys = kept
}
// List returns the identities known to the agent. Expired keys are purged
// first. A locked keyring reports no identities and no error.
func (r *keyring) List() ([]*Key, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.locked {
		// section 2.7: locked agents return empty.
		return nil, nil
	}
	r.expireKeysLocked()

	var ids []*Key
	for i := range r.keys {
		pub := r.keys[i].signer.PublicKey()
		id := &Key{
			Format:  pub.Type(),
			Blob:    pub.Marshal(),
			Comment: r.keys[i].comment,
		}
		ids = append(ids, id)
	}
	return ids, nil
}
// Add adds a private key to the keyring. If a certificate is given, the
// signer is wrapped so that the certificate is presented as the public key.
// Of the constraints carried by key, only a positive LifetimeSecs is acted
// on here (it sets the entry's expiry time); the other constraint fields are
// not consulted. Add fails with errLocked while the keyring is locked.
func (r *keyring) Add(key AddedKey) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.locked {
		return errLocked
	}

	signer, err := ssh.NewSignerFromKey(key.PrivateKey)
	if err != nil {
		return err
	}
	if key.Certificate != nil {
		if signer, err = ssh.NewCertSigner(key.Certificate, signer); err != nil {
			return err
		}
	}

	entry := privKey{signer: signer, comment: key.Comment}
	if key.LifetimeSecs > 0 {
		deadline := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
		entry.expire = &deadline
	}
	r.keys = append(r.keys, entry)
	return nil
}
// Sign returns a signature for the data. It is SignWithFlags with no flags
// set.
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
return r.SignWithFlags(key, data, 0)
}
// SignWithFlags signs data with the first held key whose marshaled public key
// equals that of key, after purging expired keys.
//
// With flags == 0 the signer's plain Sign is used. Otherwise the signer must
// implement ssh.AlgorithmSigner, and flags must be SignatureFlagRsaSha256 or
// SignatureFlagRsaSha512, which select the matching RSA SHA-2 algorithm.
// It fails with errLocked while the keyring is locked and with a "not found"
// error when no key matches.
func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.locked {
		return nil, errLocked
	}
	r.expireKeysLocked()

	wanted := key.Marshal()
	for _, k := range r.keys {
		if !bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
			continue
		}
		if flags == 0 {
			return k.signer.Sign(rand.Reader, data)
		}

		algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner)
		if !ok {
			return nil, fmt.Errorf("agent: signature does not support non-default signature algorithm: %T", k.signer)
		}
		switch flags {
		case SignatureFlagRsaSha256:
			return algorithmSigner.SignWithAlgorithm(rand.Reader, data, ssh.KeyAlgoRSASHA256)
		case SignatureFlagRsaSha512:
			return algorithmSigner.SignWithAlgorithm(rand.Reader, data, ssh.KeyAlgoRSASHA512)
		default:
			return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
		}
	}
	return nil, errors.New("not found")
}
// Signers returns signers for all the known keys, after purging expired ones.
// It fails with errLocked while the keyring is locked.
func (r *keyring) Signers() ([]ssh.Signer, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.locked {
		return nil, errLocked
	}
	r.expireKeysLocked()

	signers := make([]ssh.Signer, len(r.keys))
	for i := range r.keys {
		signers[i] = r.keys[i].signer
	}
	return signers, nil
}
// Extension always returns ErrExtensionUnsupported: the keyring does not
// support any extensions. Both arguments are ignored.
func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) {
return nil, ErrExtensionUnsupported
}
================================================
FILE: gossh/crypto/ssh/agent/server.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"math/big"
"gossh/crypto/ssh"
)
// server wraps an Agent and uses it to implement the agent side of
// the SSH-agent wire protocol.
type server struct {
// agent is the implementation that decoded requests are dispatched to.
agent Agent
}
// processRequestBytes handles one raw request and returns the raw reply.
//
// Any processing error becomes a single agentFailure byte; errors other than
// errLocked are also logged together with the request's type byte. A nil
// reply without error becomes a single agentSuccess byte. Otherwise the reply
// message is marshaled. reqData must be non-empty.
func (s *server) processRequestBytes(reqData []byte) []byte {
	rep, err := s.processRequest(reqData)
	switch {
	case err != nil:
		if err != errLocked {
			// TODO(hanwen): provide better logging interface?
			log.Printf("agent %d: %v", reqData[0], err)
		}
		return []byte{agentFailure}
	case rep == nil:
		return []byte{agentSuccess}
	default:
		return ssh.Marshal(rep)
	}
}
// marshalKey encodes k as two consecutive fields, the marshaled key blob
// followed by the comment, for inclusion in an identities answer.
func marshalKey(k *Key) []byte {
	record := struct {
		Blob    []byte
		Comment string
	}{
		Blob:    k.Marshal(),
		Comment: k.Comment,
	}
	return ssh.Marshal(&record)
}
// agentV1IdentitiesAnswer is the message number of the reply to a protocol 1
// identities request; it matches the sshtype tag of agentV1IdentityMsg.
// See [PROTOCOL.agent], section 2.5.1.
const agentV1IdentitiesAnswer = 2
// agentV1IdentityMsg is the reply sent for agentRequestV1Identities. The
// server in this file always sends it with Numkeys set to 0.
type agentV1IdentityMsg struct {
Numkeys uint32 `sshtype:"2"`
}
// agentRemoveIdentityMsg is the request decoded for agentRemoveIdentity.
// KeyBlob is the wire encoding of the public key to remove.
type agentRemoveIdentityMsg struct {
KeyBlob []byte `sshtype:"18"`
}
// agentLockMsg is the request decoded for agentLock. Passphrase is handed to
// the agent's Lock method.
type agentLockMsg struct {
Passphrase []byte `sshtype:"22"`
}
// agentUnlockMsg is the request decoded for agentUnlock. Passphrase is handed
// to the agent's Unlock method.
type agentUnlockMsg struct {
Passphrase []byte `sshtype:"23"`
}
// processRequest decodes one agent request and dispatches it on its first
// byte, the message type. data must be non-empty because data[0] is read
// unconditionally.
//
// It returns the reply message to marshal. A nil reply together with a nil
// error means the request succeeded and needs no reply body. Unknown message
// types and decoding or agent failures are returned as errors.
func (s *server) processRequest(data []byte) (any, error) {
switch data[0] {
case agentRequestV1Identities:
// Protocol 1 identities are not held; always answer with zero keys.
return &agentV1IdentityMsg{0}, nil
case agentRemoveAllV1Identities:
// Nothing to remove; report success.
return nil, nil
case agentRemoveIdentity:
var req agentRemoveIdentityMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
// Decode the blob only to learn the key format; the agent receives the
// original blob bytes.
var wk wireKey
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
return nil, err
}
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
case agentRemoveAllIdentities:
return nil, s.agent.RemoveAll()
case agentLock:
var req agentLockMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
return nil, s.agent.Lock(req.Passphrase)
case agentUnlock:
var req agentUnlockMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
return nil, s.agent.Unlock(req.Passphrase)
case agentSignRequest:
var req signRequestAgentMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
var wk wireKey
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
return nil, err
}
k := &Key{
Format: wk.Format,
Blob: req.KeyBlob,
}
var sig *ssh.Signature
var err error
// Only an ExtendedAgent can honor the request flags; a plain Agent signs
// without them.
if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
} else {
sig, err = s.agent.Sign(k, req.Data)
}
if err != nil {
return nil, err
}
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
case agentRequestIdentities:
keys, err := s.agent.List()
if err != nil {
return nil, err
}
rep := identitiesAnswerAgentMsg{
NumKeys: uint32(len(keys)),
}
// The encoded keys are concatenated into a single byte slice.
for _, k := range keys {
rep.Keys = append(rep.Keys, marshalKey(k)...)
}
return rep, nil
case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data)
case agentExtension:
// Return a stub object where the whole contents of the response gets marshaled.
var responseStub struct {
Rest []byte `ssh:"rest"`
}
if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
// requires that we return a standard SSH_AGENT_FAILURE message.
responseStub.Rest = []byte{agentFailure}
} else {
var req extensionAgentMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
if err != nil {
// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
// message as required by [PROTOCOL.agent] section 4.7.
if err == ErrExtensionUnsupported {
responseStub.Rest = []byte{agentFailure}
} else {
// As the result of any other error processing an extension request,
// [PROTOCOL.agent] section 4.7 requires that we return a
// SSH_AGENT_EXTENSION_FAILURE code.
responseStub.Rest = []byte{agentExtensionFailure}
}
} else {
// An empty extension result is reported as a plain success.
if len(res) == 0 {
return nil, nil
}
responseStub.Rest = res
}
}
return responseStub, nil
}
return nil, fmt.Errorf("unknown opcode %d", data[0])
}
// parseConstraints decodes the key constraints that trail an add-identity
// request.
//
// constraints is a sequence of records, each introduced by a type byte:
//   - agentConstrainLifetime is followed by a 4-byte big-endian number of
//     seconds, returned as lifetimeSecs;
//   - agentConstrainConfirm has no payload and sets confirmBeforeUse;
//   - agentConstrainExtension and agentConstrainExtensionV00 are decoded as a
//     constrainExtensionAgentMsg and appended to extensions.
//
// The input comes from the peer and is not trusted: a truncated lifetime
// record yields io.ErrUnexpectedEOF rather than a slice-bounds panic, and an
// unknown type byte or a malformed extension record yields an error. On error
// the other results are zero.
func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
	// One type byte plus the 4-byte lifetime value.
	const lifetimeRecordLen = 5

	for len(constraints) != 0 {
		switch constraints[0] {
		case agentConstrainLifetime:
			if len(constraints) < lifetimeRecordLen {
				return 0, false, nil, io.ErrUnexpectedEOF
			}
			lifetimeSecs = binary.BigEndian.Uint32(constraints[1:lifetimeRecordLen])
			constraints = constraints[lifetimeRecordLen:]
		case agentConstrainConfirm:
			confirmBeforeUse = true
			constraints = constraints[1:]
		case agentConstrainExtension, agentConstrainExtensionV00:
			var msg constrainExtensionAgentMsg
			if err = ssh.Unmarshal(constraints, &msg); err != nil {
				return 0, false, nil, err
			}
			extensions = append(extensions, ConstraintExtension{
				ExtensionName:    msg.ExtensionName,
				ExtensionDetails: msg.ExtensionDetails,
			})
			// Continue with whatever follows the extension record.
			constraints = msg.Rest
		default:
			return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
		}
	}
	return lifetimeSecs, confirmBeforeUse, extensions, nil
}
// setConstraints parses constraintBytes and stores the resulting lifetime,
// confirmation flag and constraint extensions on key. If parsing fails, key
// is left untouched and the parse error is returned.
func setConstraints(key *AddedKey, constraintBytes []byte) error {
	lifetime, confirm, exts, err := parseConstraints(constraintBytes)
	if err == nil {
		key.LifetimeSecs = lifetime
		key.ConfirmBeforeUse = confirm
		key.ConstraintExtensions = exts
	}
	return err
}
// parseRSAKey decodes an add-identity request carrying a plain RSA key into
// an AddedKey, including any trailing constraints. Public exponents wider
// than 30 bits are rejected. The private key is precomputed before use.
func parseRSAKey(req []byte) (*AddedKey, error) {
	var msg rsaKeyMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}
	if msg.E.BitLen() > 30 {
		return nil, errors.New("agent: RSA public exponent too large")
	}

	pub := rsa.PublicKey{N: msg.N, E: int(msg.E.Int64())}
	key := &rsa.PrivateKey{
		PublicKey: pub,
		D:         msg.D,
		Primes:    []*big.Int{msg.P, msg.Q},
	}
	key.Precompute()

	added := &AddedKey{PrivateKey: key, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseEd25519Key decodes an add-identity request carrying a plain Ed25519
// key into an AddedKey, including any trailing constraints. The private key
// is stored as a pointer to an ed25519.PrivateKey.
func parseEd25519Key(req []byte) (*AddedKey, error) {
	var msg ed25519KeyMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	key := ed25519.PrivateKey(msg.Priv)
	added := &AddedKey{PrivateKey: &key, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseDSAKey decodes an add-identity request carrying a plain DSA key into
// an AddedKey, including any trailing constraints.
func parseDSAKey(req []byte) (*AddedKey, error) {
	var msg dsaKeyMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	params := dsa.Parameters{P: msg.P, Q: msg.Q, G: msg.G}
	key := &dsa.PrivateKey{
		PublicKey: dsa.PublicKey{Parameters: params, Y: msg.Y},
		X:         msg.X,
	}

	added := &AddedKey{PrivateKey: key, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// unmarshalECDSA builds an ECDSA private key from its parts.
//
// curveName selects the curve and must be "nistp256", "nistp384" or
// "nistp521". keyBytes is the encoded public point as understood by
// elliptic.Unmarshal, and privScalar becomes the private scalar D as given.
// An unknown curve name or a point that fails to decode yields an error and
// a nil key.
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
	var curve elliptic.Curve
	switch curveName {
	case "nistp256":
		curve = elliptic.P256()
	case "nistp384":
		curve = elliptic.P384()
	case "nistp521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
	}

	x, y := elliptic.Unmarshal(curve, keyBytes)
	if x == nil || y == nil {
		return nil, errors.New("agent: point not on curve")
	}

	return &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{Curve: curve, X: x, Y: y},
		D:         privScalar,
	}, nil
}
// parseEd25519Cert decodes an add-identity request carrying an Ed25519
// certificate and its private key into an AddedKey, including any trailing
// constraints. The certificate bytes must parse to an *ssh.Certificate.
func parseEd25519Cert(req []byte) (*AddedKey, error) {
	var msg ed25519CertMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	parsed, err := ssh.ParsePublicKey(msg.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, isCert := parsed.(*ssh.Certificate)
	if !isCert {
		return nil, errors.New("agent: bad ED25519 certificate")
	}

	key := ed25519.PrivateKey(msg.Priv)
	added := &AddedKey{PrivateKey: &key, Certificate: cert, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseECDSAKey decodes an add-identity request carrying a plain ECDSA key
// into an AddedKey, including any trailing constraints. The key itself is
// assembled by unmarshalECDSA.
func parseECDSAKey(req []byte) (*AddedKey, error) {
	var msg ecdsaKeyMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	key, err := unmarshalECDSA(msg.Curve, msg.KeyBytes, msg.D)
	if err != nil {
		return nil, err
	}

	added := &AddedKey{PrivateKey: key, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseRSACert decodes an add-identity request carrying an RSA certificate
// and its private key into an AddedKey, including any trailing constraints.
//
// The public half (E and N) is taken from the key embedded in the
// certificate, while the private exponent and primes come from the request.
// Public exponents wider than 30 bits are rejected.
func parseRSACert(req []byte) (*AddedKey, error) {
	var msg rsaCertMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	parsed, err := ssh.ParsePublicKey(msg.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, isCert := parsed.(*ssh.Certificate)
	if !isCert {
		return nil, errors.New("agent: bad RSA certificate")
	}

	// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
	var wirePub struct {
		Name string
		E    *big.Int
		N    *big.Int
	}
	if err = ssh.Unmarshal(cert.Key.Marshal(), &wirePub); err != nil {
		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
	}
	if wirePub.E.BitLen() > 30 {
		return nil, errors.New("agent: RSA public exponent too large")
	}

	key := &rsa.PrivateKey{
		PublicKey: rsa.PublicKey{N: wirePub.N, E: int(wirePub.E.Int64())},
		D:         msg.D,
		Primes:    []*big.Int{msg.Q, msg.P},
	}
	key.Precompute()

	added := &AddedKey{PrivateKey: key, Certificate: cert, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseDSACert decodes an add-identity request carrying a DSA certificate and
// its private key into an AddedKey, including any trailing constraints.
//
// The DSA parameters and public value are taken from the key embedded in the
// certificate, while the private value X comes from the request.
func parseDSACert(req []byte) (*AddedKey, error) {
	var msg dsaCertMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	parsed, err := ssh.ParsePublicKey(msg.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, isCert := parsed.(*ssh.Certificate)
	if !isCert {
		return nil, errors.New("agent: bad DSA certificate")
	}

	// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
	var wirePub struct {
		Name       string
		P, Q, G, Y *big.Int
	}
	if err = ssh.Unmarshal(cert.Key.Marshal(), &wirePub); err != nil {
		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
	}

	params := dsa.Parameters{P: wirePub.P, Q: wirePub.Q, G: wirePub.G}
	key := &dsa.PrivateKey{
		PublicKey: dsa.PublicKey{Parameters: params, Y: wirePub.Y},
		X:         msg.X,
	}

	added := &AddedKey{PrivateKey: key, Certificate: cert, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// parseECDSACert decodes an add-identity request carrying an ECDSA
// certificate and its private key into an AddedKey, including any trailing
// constraints.
//
// The curve identifier and public point are taken from the key embedded in
// the certificate, while the private scalar comes from the request.
func parseECDSACert(req []byte) (*AddedKey, error) {
	var msg ecdsaCertMsg
	err := ssh.Unmarshal(req, &msg)
	if err != nil {
		return nil, err
	}

	parsed, err := ssh.ParsePublicKey(msg.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, isCert := parsed.(*ssh.Certificate)
	if !isCert {
		return nil, errors.New("agent: bad ECDSA certificate")
	}

	// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
	var wirePub struct {
		Name string
		ID   string
		Key  []byte
	}
	if err = ssh.Unmarshal(cert.Key.Marshal(), &wirePub); err != nil {
		return nil, err
	}

	key, err := unmarshalECDSA(wirePub.ID, wirePub.Key, msg.D)
	if err != nil {
		return nil, err
	}

	added := &AddedKey{PrivateKey: key, Certificate: cert, Comment: msg.Comments}
	if err = setConstraints(added, msg.Constraints); err != nil {
		return nil, err
	}
	return added, nil
}
// insertIdentity handles an add-identity request. It peeks at the key type
// string at the start of req, picks the parser for that key or certificate
// type, parses the whole request with it and hands the resulting key to the
// agent's Add method. Unsupported key types and parse failures are returned
// as errors.
func (s *server) insertIdentity(req []byte) error {
	var record struct {
		Type string `sshtype:"17|25"`
		Rest []byte `ssh:"rest"`
	}
	if err := ssh.Unmarshal(req, &record); err != nil {
		return err
	}

	var parse func([]byte) (*AddedKey, error)
	switch record.Type {
	case ssh.KeyAlgoRSA:
		parse = parseRSAKey
	case ssh.KeyAlgoDSA:
		parse = parseDSAKey
	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
		parse = parseECDSAKey
	case ssh.KeyAlgoED25519:
		parse = parseEd25519Key
	case ssh.CertAlgoRSAv01:
		parse = parseRSACert
	case ssh.CertAlgoDSAv01:
		parse = parseDSACert
	case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
		parse = parseECDSACert
	case ssh.CertAlgoED25519v01:
		parse = parseEd25519Cert
	default:
		return fmt.Errorf("agent: not implemented: %q", record.Type)
	}

	// Each parser decodes the full request, not just the remainder.
	addedKey, err := parse(req)
	if err != nil {
		return err
	}
	return s.agent.Add(*addedKey)
}
// ServeAgent serves the agent protocol on the given connection. It
// returns when an I/O error occurs.
//
// Each request and each reply is framed by a 4-byte big-endian length. A
// request of length zero, a request longer than maxAgentResponseBytes, or a
// reply longer than maxAgentResponseBytes also ends the loop with an error.
func ServeAgent(agent Agent, c io.ReadWriter) error {
	srv := &server{agent}
	var hdr [4]byte
	for {
		if _, err := io.ReadFull(c, hdr[:]); err != nil {
			return err
		}
		size := binary.BigEndian.Uint32(hdr[:])
		switch {
		case size == 0:
			return fmt.Errorf("agent: request size is 0")
		case size > maxAgentResponseBytes:
			// We also cap requests.
			return fmt.Errorf("agent: request too large: %d", size)
		}

		payload := make([]byte, size)
		if _, err := io.ReadFull(c, payload); err != nil {
			return err
		}

		reply := srv.processRequestBytes(payload)
		if len(reply) > maxAgentResponseBytes {
			return fmt.Errorf("agent: reply too large: %d bytes", len(reply))
		}

		// Reuse the header buffer for the reply length.
		binary.BigEndian.PutUint32(hdr[:], uint32(len(reply)))
		if _, err := c.Write(hdr[:]); err != nil {
			return err
		}
		if _, err := c.Write(reply); err != nil {
			return err
		}
	}
}
================================================
FILE: gossh/crypto/ssh/buffer.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
)
// buffer provides a linked list buffer for data exchange
// between producer and consumer. Theoretically the buffer is
// of unlimited capacity as it does no allocation of its own.
//
// The producer side uses write and eof; the consumer side uses Read.
type buffer struct {
// protects concurrent access to head, tail and closed
*sync.Cond
head *element // the buffer that will be read first
tail *element // the buffer that will be read last
closed bool // set by eof; Read reports io.EOF once the data is drained
}
// An element represents a single link in a linked list.
type element struct {
// buf is the not-yet-consumed data of this link; Read advances it as
// bytes are copied out.
buf []byte
// next is the following link, or nil for the tail.
next *element
}
// newBuffer returns an empty buffer that is not closed. The list starts out
// with a single empty element serving as both head and tail.
func newBuffer() *buffer {
	sentinel := &element{}
	return &buffer{
		Cond: newCond(),
		head: sentinel,
		tail: sentinel,
	}
}
// write makes buf available for Read to receive by appending it as a new
// tail element and signalling a waiting reader.
// buf must not be modified after the call to write.
func (b *buffer) write(buf []byte) {
	b.Cond.L.Lock()
	defer b.Cond.L.Unlock()

	node := &element{buf: buf}
	b.tail.next = node
	b.tail = node
	b.Cond.Signal()
}
// eof closes the buffer and signals a waiting reader. Reads from the buffer
// once all the data has been consumed will receive io.EOF.
func (b *buffer) eof() {
	b.Cond.L.Lock()
	defer b.Cond.L.Unlock()

	b.closed = true
	b.Cond.Signal()
}
// Read reads data from the internal buffer in buf. Reads will block
// if no data is available, or until the buffer is closed.
//
// It returns as soon as at least one byte has been copied and no further
// data is immediately available, or once buf is full. It returns 0 and
// io.EOF only when the buffer is closed and nothing was copied. A zero-length
// buf returns immediately with 0 and a nil error.
func (b *buffer) Read(buf []byte) (n int, err error) {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
for len(buf) > 0 {
// if there is data in b.head, copy it
if len(b.head.buf) > 0 {
r := copy(buf, b.head.buf)
buf, b.head.buf = buf[r:], b.head.buf[r:]
n += r
continue
}
// if there is a next buffer, make it the head
if len(b.head.buf) == 0 && b.head != b.tail {
b.head = b.head.next
continue
}
// if at least one byte has been copied, return
if n > 0 {
break
}
// if nothing was read, and there is nothing outstanding
// check to see if the buffer is closed.
if b.closed {
err = io.EOF
break
}
// out of buffers, wait for producer
b.Cond.Wait()
}
return
}
================================================
FILE: gossh/crypto/ssh/certs.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"sort"
"time"
)
// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
// field.
//
// Every name ends in the "-cert-v01@openssh.com" suffix.
const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
// Certificate.Type (or PublicKey.Type), but only in
// ClientConfig.HostKeyAlgorithms.
CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com"
CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com"
)
// Older names for three of the certificate algorithm constants, kept as
// aliases with identical values.
const (
// CertSigAlgoRSAv01 is an alias of CertAlgoRSAv01.
//
// Deprecated: use CertAlgoRSAv01.
CertSigAlgoRSAv01 = CertAlgoRSAv01
// CertSigAlgoRSASHA2256v01 is an alias of CertAlgoRSASHA256v01.
//
// Deprecated: use CertAlgoRSASHA256v01.
CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01
// CertSigAlgoRSASHA2512v01 is an alias of CertAlgoRSASHA512v01.
//
// Deprecated: use CertAlgoRSASHA512v01.
CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01
)
// Certificate types distinguish between host and user
// certificates. The values can be set in the CertType field of
// Certificate.
const (
// UserCert marks a user certificate; CertChecker.Authenticate requires it.
UserCert = 1
// HostCert marks a host certificate; CertChecker.CheckHostKey requires it.
HostCert = 2
)
// Signature represents a cryptographic signature.
type Signature struct {
// Format names the signature algorithm.
Format string
// Blob holds the signature data.
Blob []byte
// Rest carries the `ssh:"rest"` tag; presumably it receives whatever
// bytes follow Blob when unmarshaling (confirm against Unmarshal).
Rest []byte `ssh:"rest"`
}
// CertTimeInfinity can be used for Certificate.ValidBefore to indicate that
// a certificate does not expire. It is the largest value a uint64 can hold.
const CertTimeInfinity = 1<<64 - 1
// A Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
// PublicKey interface, so it can be unmarshaled using
// ParsePublicKey.
type Certificate struct {
Nonce []byte
// Key is the public key being certified.
Key PublicKey
Serial uint64
// CertType is UserCert or HostCert.
CertType uint32
KeyId string
ValidPrincipals []string
// ValidAfter and ValidBefore bound the validity period; ValidBefore may
// be CertTimeInfinity.
ValidAfter uint64
ValidBefore uint64
// Permissions is embedded; parseCert fills its CriticalOptions and
// Extensions from the certificate.
Permissions
Reserved []byte
// SignatureKey is the public key of the signing authority.
SignatureKey PublicKey
Signature *Signature
}
// genericCertData holds the key-independent part of the certificate data.
// Overall, certificates contain a nonce, public key fields and
// key-independent fields.
//
// The byte-slice fields hold still-encoded data that parseCert decodes
// further (principals, options, extensions, signature key and signature).
type genericCertData struct {
Serial uint64
CertType uint32
KeyId string
ValidPrincipals []byte
ValidAfter uint64
ValidBefore uint64
CriticalOptions []byte
Extensions []byte
Reserved []byte
SignatureKey []byte
Signature []byte
}
// marshalStringList encodes each name in namelist as a marshaled string and
// returns the concatenation. An empty list yields a nil slice.
func marshalStringList(namelist []string) []byte {
	var out []byte
	for i := range namelist {
		entry := struct{ N string }{N: namelist[i]}
		out = append(out, Marshal(&entry)...)
	}
	return out
}
// optionsTuple is the wire form of one critical option or extension: a name
// and its data field. marshalTuples leaves Value empty for options without a
// value and otherwise stores a marshaled optionsTupleValue in it.
type optionsTuple struct {
Key string
Value []byte
}
// optionsTupleValue wraps a non-empty option value so that marshaling it
// produces the inner length-prefixed string stored in optionsTuple.Value.
type optionsTupleValue struct {
Value string
}
// marshalTuples serializes a map of critical options or extensions. Entries
// are emitted in sorted key order. An empty value is written as an empty
// data field.
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
// we need two length prefixes for a non-empty string value
func marshalTuples(tups map[string]string) []byte {
	names := make([]string, 0, len(tups))
	for name := range tups {
		names = append(names, name)
	}
	sort.Strings(names)

	var out []byte
	for _, name := range names {
		entry := optionsTuple{Key: name}
		if v := tups[name]; v != "" {
			entry.Value = Marshal(&optionsTupleValue{Value: v})
		}
		out = append(out, Marshal(&entry)...)
	}
	return out
}
// parseTuples decodes a serialized list of critical options or extensions
// into a map. Names must appear in strictly increasing lexical order. An
// empty data field maps the name to ""; a non-empty one must contain exactly
// one embedded string with nothing after it.
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
// we need two length prefixes for a non-empty option value
func parseTuples(in []byte) (map[string]string, error) {
	options := make(map[string]string)
	prevName, seenAny := "", false

	for len(in) > 0 {
		rawName, afterName, ok := parseString(in)
		if !ok {
			return nil, errShortRead
		}
		name := string(rawName)
		// according to [PROTOCOL.certkeys], the names must be in
		// lexical order.
		if seenAny && name <= prevName {
			return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
		}
		prevName, seenAny = name, true

		// the next field is a data field, which if non-empty has a string embedded
		data, afterData, ok := parseString(afterName)
		if !ok {
			return nil, errShortRead
		}
		in = afterData

		if len(data) == 0 {
			options[name] = ""
			continue
		}
		inner, trailing, ok := parseString(data)
		if !ok {
			return nil, errShortRead
		}
		if len(trailing) > 0 {
			return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value")
		}
		options[name] = string(inner)
	}
	return options, nil
}
// parseCert decodes a certificate body from in. The body consists of a nonce,
// the certified public key (decoded by parsePubKey using privAlgo) and the
// key-independent fields described by genericCertData.
//
// It fails on short input, on malformed principals, options, extensions or
// signature key, and when the signature field does not parse cleanly or has
// trailing bytes.
func parseCert(in []byte, privAlgo string) (*Certificate, error) {
nonce, rest, ok := parseString(in)
if !ok {
return nil, errShortRead
}
key, rest, err := parsePubKey(rest, privAlgo)
if err != nil {
return nil, err
}
var g genericCertData
if err := Unmarshal(rest, &g); err != nil {
return nil, err
}
c := &Certificate{
Nonce: nonce,
Key: key,
Serial: g.Serial,
CertType: g.CertType,
KeyId: g.KeyId,
ValidAfter: g.ValidAfter,
ValidBefore: g.ValidBefore,
}
// The principals field is a concatenation of strings; peel them off one
// at a time. rest and ok here shadow the outer variables.
for principals := g.ValidPrincipals; len(principals) > 0; {
principal, rest, ok := parseString(principals)
if !ok {
return nil, errShortRead
}
c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
principals = rest
}
c.CriticalOptions, err = parseTuples(g.CriticalOptions)
if err != nil {
return nil, err
}
c.Extensions, err = parseTuples(g.Extensions)
if err != nil {
return nil, err
}
c.Reserved = g.Reserved
k, err := ParsePublicKey(g.SignatureKey)
if err != nil {
return nil, err
}
c.SignatureKey = k
// The signature must consume its field exactly.
c.Signature, rest, ok = parseSignatureBody(g.Signature)
if !ok || len(rest) > 0 {
return nil, errors.New("ssh: signature parse error")
}
return c, nil
}
// openSSHCertSigner is a Signer that presents a certificate as its public
// key while delegating the signing itself to the wrapped signer.
type openSSHCertSigner struct {
pub *Certificate
signer Signer
}
// algorithmOpenSSHCertSigner extends openSSHCertSigner with
// SignWithAlgorithm, delegated to algorithmSigner.
type algorithmOpenSSHCertSigner struct {
*openSSHCertSigner
algorithmSigner AlgorithmSigner
}
// NewCertSigner returns a Signer that signs with the given Certificate, whose
// private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer.
//
// The returned Signer keeps the capabilities of signer: it is algorithm-aware
// when signer is a MultiAlgorithmSigner or an AlgorithmSigner, and a plain
// Signer otherwise.
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
		return nil, errors.New("ssh: signer and cert have different public key")
	}

	base := &openSSHCertSigner{cert, signer}

	// MultiAlgorithmSigner must be tested first, as it is also an
	// AlgorithmSigner.
	if multi, ok := signer.(MultiAlgorithmSigner); ok {
		return &multiAlgorithmSigner{
			AlgorithmSigner:     &algorithmOpenSSHCertSigner{base, multi},
			supportedAlgorithms: multi.Algorithms(),
		}, nil
	}
	if algo, ok := signer.(AlgorithmSigner); ok {
		return &algorithmOpenSSHCertSigner{base, algo}, nil
	}
	return base, nil
}
// Sign signs data with the wrapped signer.
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.signer.Sign(rand, data)
}
// PublicKey returns the certificate, not the wrapped signer's public key.
func (s *openSSHCertSigner) PublicKey() PublicKey {
return s.pub
}
// SignWithAlgorithm signs data with the wrapped AlgorithmSigner using the
// given algorithm.
func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
}
// sourceAddressCriticalOption is the name of the "source-address" critical
// option.
const sourceAddressCriticalOption = "source-address"
// CertChecker does the work of verifying a certificate. Its methods
// can be plugged into ClientConfig.HostKeyCallback and
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
// minimally, the authority callback for the kind of certificate being
// checked (IsUserAuthority or IsHostAuthority) should be set.
type CertChecker struct {
// SupportedCriticalOptions lists the CriticalOptions that the
// server application layer understands. These are only used
// for user certificates.
SupportedCriticalOptions []string
// IsUserAuthority should return true if the key is recognized as an
// authority for the given user certificate. This allows for
// certificates to be signed by other certificates. This must be set
// if this CertChecker will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now
// is used.
Clock func() time.Time
// UserKeyFallback is called when CertChecker.Authenticate encounters a
// public key that is not a certificate. It must implement validation
// of user keys or else, if nil, all such keys are rejected.
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected.
HostKeyFallback HostKeyCallback
// IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate
// is revoked and false otherwise. If nil, no certificates are
// considered to have been revoked.
IsRevoked func(cert *Certificate) bool
}
// CheckHostKey checks a host key certificate. This method can be
// plugged into ClientConfig.HostKeyCallback.
//
// Keys that are not certificates are passed to HostKeyFallback, or rejected
// when it is nil. A certificate must have type HostCert, must be signed by a
// key accepted by IsHostAuthority for addr, and must pass CheckCert with the
// host part of addr as the principal. addr must be in host:port form.
//
// If IsHostAuthority is nil the certificate is rejected with an error rather
// than causing a nil function call panic.
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
	cert, ok := key.(*Certificate)
	if !ok {
		if c.HostKeyFallback != nil {
			return c.HostKeyFallback(addr, remote, key)
		}
		return errors.New("ssh: non-certificate host key")
	}
	if cert.CertType != HostCert {
		return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
	}
	// Fail closed: without an authority callback no signer can be trusted.
	if c.IsHostAuthority == nil {
		return errors.New("ssh: no host authority callback configured for certificate host key")
	}
	if !c.IsHostAuthority(cert.SignatureKey, addr) {
		return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
	}

	hostname, _, err := net.SplitHostPort(addr)
	if err != nil {
		return err
	}

	// Pass hostname only as principal for host certificates (consistent with OpenSSH)
	return c.CheckCert(hostname, cert)
}
// Authenticate checks a user certificate. Authenticate can be used as
// a value for ServerConfig.PublicKeyCallback.
//
// Keys that are not certificates are passed to UserKeyFallback, or rejected
// when it is nil. A certificate must have type UserCert, must be signed by a
// key accepted by IsUserAuthority, and must pass CheckCert with conn.User()
// as the principal. On success the certificate's Permissions are returned.
//
// If IsUserAuthority is nil the certificate is rejected with an error rather
// than causing a nil function call panic during authentication.
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
	cert, ok := pubKey.(*Certificate)
	if !ok {
		if c.UserKeyFallback != nil {
			return c.UserKeyFallback(conn, pubKey)
		}
		return nil, errors.New("ssh: normal key pairs not accepted")
	}

	if cert.CertType != UserCert {
		return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
	}
	// Fail closed: without an authority callback no signer can be trusted.
	if c.IsUserAuthority == nil {
		return nil, errors.New("ssh: no user authority callback configured for certificate authentication")
	}
	if !c.IsUserAuthority(cert.SignatureKey) {
		return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
	}

	if err := c.CheckCert(conn.User(), cert); err != nil {
		return nil, err
	}

	return &cert.Permissions, nil
}
// CheckCert validates cert for the given principal. The checks run in this
// order: revocation (IsRevoked), critical options, principal membership,
// validity window against Clock (or time.Now), and finally the signature made
// by cert.SignatureKey. The first failing check determines the error.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
	if c.IsRevoked != nil && c.IsRevoked(cert) {
		return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
	}

nextOption:
	for name := range cert.CriticalOptions {
		// sourceAddressCriticalOption will be enforced by
		// serverAuthenticate
		if name == sourceAddressCriticalOption {
			continue
		}
		for _, known := range c.SupportedCriticalOptions {
			if known == name {
				continue nextOption
			}
		}
		return fmt.Errorf("ssh: unsupported critical option %q in certificate", name)
	}

	// An empty principal list means the certificate is valid for all
	// users/hosts, so the membership test is skipped.
	if len(cert.ValidPrincipals) > 0 {
		matched := false
		for i := 0; i < len(cert.ValidPrincipals) && !matched; i++ {
			matched = cert.ValidPrincipals[i] == principal
		}
		if !matched {
			return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
		}
	}

	now := time.Now
	if c.Clock != nil {
		now = c.Clock
	}
	unixNow := now().Unix()

	// Timestamps that do not fit in an int64 show up as negative values
	// and are rejected.
	notBefore := int64(cert.ValidAfter)
	if notBefore < 0 || unixNow < notBefore {
		return fmt.Errorf("ssh: cert is not yet valid")
	}
	if cert.ValidBefore != uint64(CertTimeInfinity) {
		notAfter := int64(cert.ValidBefore)
		if unixNow >= notAfter || notAfter < 0 {
			return fmt.Errorf("ssh: cert has expired")
		}
	}

	if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
		return fmt.Errorf("ssh: certificate signature does not verify")
	}
	return nil
}
// SignCert signs the certificate with an authority, setting the Nonce,
// SignatureKey, and Signature fields.
//
// The algorithm is chosen as follows: a MultiAlgorithmSigner signs with the
// first entry of its Algorithms list (an empty list is an error); an
// AlgorithmSigner whose public key type is KeyAlgoRSA signs with
// KeyAlgoRSASHA512; every other authority signs via its plain Sign method.
// Signature is only assigned when signing succeeds.
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
	c.Nonce = make([]byte, 32)
	if _, err := io.ReadFull(rand, c.Nonce); err != nil {
		return err
	}
	c.SignatureKey = authority.PublicKey()

	var (
		sig *Signature
		err error
	)
	switch signer := authority.(type) {
	case MultiAlgorithmSigner:
		if len(signer.Algorithms()) == 0 {
			return errors.New("the provided authority has no signature algorithm")
		}
		// Use the first algorithm in the list.
		sig, err = signer.SignWithAlgorithm(rand, c.bytesForSigning(), signer.Algorithms()[0])
	case AlgorithmSigner:
		if signer.PublicKey().Type() == KeyAlgoRSA {
			// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
			// TODO: consider using KeyAlgoRSASHA256 as default.
			sig, err = signer.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
		} else {
			sig, err = authority.Sign(rand, c.bytesForSigning())
		}
	default:
		sig, err = authority.Sign(rand, c.bytesForSigning())
	}
	if err != nil {
		return err
	}
	c.Signature = sig
	return nil
}
// certKeyAlgoNames is a mapping from known certificate algorithm names to the
// corresponding public key signature algorithm.
//
// certificateAlgo performs the reverse lookup by scanning the values of this
// map, so every value must appear only once for that lookup to be
// deterministic.
//
// This map must be kept in sync with the one in agent/client.go.
var certKeyAlgoNames = map[string]string{
CertAlgoRSAv01: KeyAlgoRSA,
CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
CertAlgoDSAv01: KeyAlgoDSA,
CertAlgoECDSA256v01: KeyAlgoECDSA256,
CertAlgoECDSA384v01: KeyAlgoECDSA384,
CertAlgoECDSA521v01: KeyAlgoECDSA521,
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
CertAlgoED25519v01: KeyAlgoED25519,
CertAlgoSKED25519v01: KeyAlgoSKED25519,
}
// underlyingAlgo maps an advertised or negotiated public key or host key
// algorithm to the signature algorithm behind it. Certificate algorithm names
// found in certKeyAlgoNames are translated; any other name is returned as is.
func underlyingAlgo(algo string) string {
	keyAlgo, isCertAlgo := certKeyAlgoNames[algo]
	if !isCertAlgo {
		return algo
	}
	return keyAlgo
}
// certificateAlgo returns the certificate algorithm that uses the provided
// underlying signature algorithm. ok is false, and certAlgo empty, when no
// entry of certKeyAlgoNames maps to algo.
func certificateAlgo(algo string) (certAlgo string, ok bool) {
	for name, sigAlgo := range certKeyAlgoNames {
		if sigAlgo != algo {
			continue
		}
		certAlgo, ok = name, true
		break
	}
	return certAlgo, ok
}
// bytesForSigning returns the serialized certificate that the signature is
// computed over: a copy of cert is marshaled with its Signature cleared and
// the final four bytes (the trailing signature length) are cut off.
func (cert *Certificate) bytesForSigning() []byte {
	const signatureLengthSize = 4

	unsigned := *cert
	unsigned.Signature = nil
	wire := unsigned.Marshal()
	// Drop trailing signature length.
	return wire[:len(wire)-signatureLengthSize]
}
// Marshal serializes c into OpenSSH's wire format. It is part of the
// PublicKey interface.
//
// The output is the certificate type name, the nonce and the key material
// (the marshaled key with its leading string removed), followed by the
// generic certificate fields. The signature is included only when
// c.Signature is non-nil.
func (c *Certificate) Marshal() []byte {
	var generic genericCertData
	generic.Serial = c.Serial
	generic.CertType = c.CertType
	generic.KeyId = c.KeyId
	generic.ValidPrincipals = marshalStringList(c.ValidPrincipals)
	generic.ValidAfter = uint64(c.ValidAfter)
	generic.ValidBefore = uint64(c.ValidBefore)
	generic.CriticalOptions = marshalTuples(c.CriticalOptions)
	generic.Extensions = marshalTuples(c.Extensions)
	generic.Reserved = c.Reserved
	generic.SignatureKey = c.SignatureKey.Marshal()
	if c.Signature != nil {
		generic.Signature = Marshal(c.Signature)
	}
	tail := Marshal(&generic)

	// Strip the first string from the marshaled key; the certificate type
	// name written below takes its place.
	_, keyBody, _ := parseString(c.Key.Marshal())
	head := Marshal(&struct {
		Name  string
		Nonce []byte
		Key   []byte `ssh:"rest"`
	}{c.Type(), c.Nonce, keyBody})

	wire := make([]byte, 0, len(head)+len(tail))
	return append(append(wire, head...), tail...)
}
// Type returns the certificate algorithm name. It is part of the PublicKey interface.
// It panics when certificateAlgo knows no certificate algorithm for the type
// of the embedded key.
func (c *Certificate) Type() string {
	keyType := c.Key.Type()
	if certName, ok := certificateAlgo(keyType); ok {
		return certName
	}
	panic("unknown certificate type for key type " + keyType)
}
// Verify verifies a signature against the certificate's public
// key. It is part of the PublicKey interface. The check is delegated to
// c.Key; the certificate's own signature is not examined here.
func (c *Certificate) Verify(data []byte, sig *Signature) error {
	key := c.Key
	return key.Verify(data, sig)
}
// parseSignatureBody parses a signature consisting of a format string
// followed by a blob string. For the security-key (SK) formats all bytes
// after the blob are stored in the signature's Rest field and rest is nil;
// for every other format the unparsed bytes are returned as rest. ok reports
// whether both strings could be parsed.
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
	format, in, ok := parseString(in)
	if !ok {
		return nil, nil, false
	}

	sig := &Signature{Format: string(format)}
	sig.Blob, in, ok = parseString(in)
	if !ok {
		return sig, nil, false
	}

	if f := sig.Format; f == KeyAlgoSKECDSA256 || f == CertAlgoSKECDSA256v01 ||
		f == KeyAlgoSKED25519 || f == CertAlgoSKED25519v01 {
		sig.Rest = in
		return sig, nil, true
	}
	return sig, in, true
}
// parseSignature parses a signature that is wrapped in a length-prefixed
// string. The wrapped bytes must be consumed completely by
// parseSignatureBody; trailing bytes inside the wrapper make the parse fail.
// rest holds the input that follows the wrapper.
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
	var body []byte
	if body, rest, ok = parseString(in); !ok {
		return nil, rest, false
	}

	sig, extra, valid := parseSignatureBody(body)
	if !valid || len(extra) != 0 {
		return nil, nil, false
	}
	return sig, rest, true
}
================================================
FILE: gossh/crypto/ssh/channel.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"sync"
)
const (
// minPacketLength is the smallest MaxPacketSize accepted from a peer in a
// channel open confirmation (see handlePacket). The value matches the
// 9-byte header of a channel data packet; presumably that is why it was
// chosen.
minPacketLength = 9
// channelMaxPacket contains the maximum number of bytes that will be
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
// the minimum.
channelMaxPacket = 1 << 15
// channelWindowSize is the initial receive window of a new channel.
// We follow OpenSSH here.
channelWindowSize = 64 * channelMaxPacket
)
// NewChannel represents an incoming request to a channel. It must either be
// accepted for use by calling Accept, or rejected by calling Reject.
// Only one of the two may be called, and only once.
type NewChannel interface {
// Accept accepts the channel creation request. It returns the Channel
// and a Go channel containing SSH requests. The Go channel must be
// serviced otherwise the Channel will hang.
Accept() (Channel, <-chan *Request, error)
// Reject rejects the channel creation request. After calling
// this, no other methods on the Channel may be called.
// reason and message are reported to the peer.
Reject(reason RejectionReason, message string) error
// ChannelType returns the type of the channel, as supplied by the
// client.
ChannelType() string
// ExtraData returns the arbitrary payload for this channel, as supplied
// by the client. This data is specific to the channel type.
ExtraData() []byte
}
// A Channel is an ordered, reliable, flow-controlled, duplex stream
// that is multiplexed over an SSH connection.
type Channel interface {
// Read reads up to len(data) bytes from the channel.
Read(data []byte) (int, error)
// Write writes len(data) bytes to the channel.
Write(data []byte) (int, error)
// Close signals end of channel use. No data may be sent after this
// call.
Close() error
// CloseWrite signals the end of sending in-band
// data. Requests may still be sent, and the other side may
// still send data.
CloseWrite() error
// SendRequest sends a channel request. If wantReply is true,
// it will wait for a reply and return the result as a
// boolean, otherwise the return value will be false. Channel
// requests are out-of-band messages so they may be sent even
// if the data stream is closed or blocked by flow control.
// If the channel is closed before a reply is returned, io.EOF
// is returned.
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
// Stderr returns an io.ReadWriter that writes to this channel
// with the extended data type set to stderr. Stderr may
// safely be read and written from a different goroutine than
// Read and Write respectively.
Stderr() io.ReadWriter
}
// Request is a request sent outside of the normal stream of
// data. Requests can either be specific to an SSH channel, or they
// can be global.
type Request struct {
// Type is the request name.
Type string
// WantReply reports whether the sender expects Reply to be called.
WantReply bool
// Payload is the request-specific data.
Payload []byte
// ch is the channel the request arrived on. Reply treats a nil ch as a
// global request and answers through mux instead.
ch *channel
mux *mux
}
// Reply sends a response to a request. It must be called for all requests
// where WantReply is true and is a no-op otherwise. The payload argument is
// ignored for replies to channel-specific requests.
func (r *Request) Reply(ok bool, payload []byte) error {
	switch {
	case !r.WantReply:
		return nil
	case r.ch != nil:
		// Channel request: the acknowledgement carries no payload.
		return r.ch.ackRequest(ok)
	default:
		// Global request: answered through the mux.
		return r.mux.ackRequest(ok, payload)
	}
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason uint32

// Rejection reasons, numbered consecutively starting at 1.
const (
	Prohibited         RejectionReason = 1
	ConnectionFailed   RejectionReason = 2
	UnknownChannelType RejectionReason = 3
	ResourceShortage   RejectionReason = 4
)

// String converts the rejection reason to human readable form. Values
// outside the declared constants are rendered as "unknown reason N".
func (r RejectionReason) String() string {
	descriptions := [...]string{
		Prohibited:         "administratively prohibited",
		ConnectionFailed:   "connect failed",
		UnknownChannelType: "unknown channel type",
		ResourceShortage:   "resource shortage",
	}
	if r >= Prohibited && r <= ResourceShortage {
		return descriptions[r]
	}
	return fmt.Sprintf("unknown reason %d", int(r))
}
// min returns the smaller of a and b, with b converted to uint32 before the
// comparison.
func min(a uint32, b int) uint32 {
	if limit := uint32(b); limit <= a {
		return limit
	}
	return a
}
// channelDirection records which side initiated a channel.
type channelDirection uint8
const (
// channelInbound marks a channel created by the peer.
channelInbound channelDirection = iota
// channelOutbound marks a channel created locally.
channelOutbound
)
// channel is an implementation of the Channel interface that works
// with the mux class.
type channel struct {
// R/O after creation
// chanType and extraData are returned by ChannelType and ExtraData.
chanType string
extraData []byte
// localId is assigned by the mux's channel list in newChannel; remoteId
// is written into the header of every outgoing channel message.
localId, remoteId uint32
// maxIncomingPayload and maxRemotePayload are the maximum
// payload sizes of normal and extended data packets for
// receiving and sending, respectively. The wire packet will
// be 9 or 13 bytes larger (excluding encryption overhead).
maxIncomingPayload uint32
maxRemotePayload uint32
mux *mux
// decided is set to true if an accept or reject message has been sent
// (for outbound channels) or received (for inbound channels).
decided bool
// direction contains either channelOutbound, for channels created
// locally, or channelInbound, for channels created by the peer.
direction channelDirection
// Pending internal channel messages.
msg chan any
// Since requests have no ID, there can be only one request
// with WantReply=true outstanding. This lock is held by a
// goroutine that has such an outgoing request pending.
sentRequestMu sync.Mutex
// incomingRequests carries channel requests received from the peer; it
// is the Go channel handed out by Accept.
incomingRequests chan *Request
// sentEOF is set by CloseWrite; WriteExtended returns io.EOF once it is
// set.
sentEOF bool
// thread-safe data
// remoteWin is the send window; WriteExtended reserves from it before
// sending.
remoteWin window
// pending buffers received normal data, extPending received extended
// data of type 1.
pending *buffer
extPending *buffer
// windowMu protects myWindow, the flow-control window, and myConsumed,
// the number of bytes consumed since we last increased myWindow
windowMu sync.Mutex
myWindow uint32
myConsumed uint32
// writeMu serializes calls to mux.conn.writePacket() and
// protects sentClose and packetPool. This mutex must be
// different from windowMu, as writePacket can block if there
// is a key exchange pending.
writeMu sync.Mutex
sentClose bool
// packetPool has a buffer for each extended channel ID to
// save allocations during writes.
packetPool map[uint32][]byte
}
// writePacket sends a packet while holding ch.writeMu. Once a channel close
// has been sent every further call returns io.EOF without writing. Sending a
// msgChannelClose packet records that fact in sentClose, even if the write
// itself fails.
func (ch *channel) writePacket(packet []byte) error {
	ch.writeMu.Lock()
	defer ch.writeMu.Unlock()

	if ch.sentClose {
		return io.EOF
	}
	ch.sentClose = packet[0] == msgChannelClose
	return ch.mux.conn.writePacket(packet)
}
// sendMessage marshals msg, overwrites the four bytes after the message type
// with ch.remoteId, and sends the result through writePacket. When debugMux
// is enabled the message is logged first.
func (ch *channel) sendMessage(msg any) error {
	if debugMux {
		log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
	}

	packet := Marshal(msg)
	// Byte 0 holds the message type; the recipient channel ID follows it.
	binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
	return ch.writePacket(packet)
}
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
//
// An extendedCode of 0 sends normal channel data; any other value sends
// extended data tagged with that code. data is sent in pieces of at most
// ch.maxRemotePayload bytes, each after reserving that many bytes from
// remoteWin. It returns the number of payload bytes successfully handed to
// writePacket together with the first error encountered. Once sentEOF is set
// it returns io.EOF without sending anything.
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if ch.sentEOF {
return 0, io.EOF
}
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
opCode := byte(msgChannelData)
headerLength := uint32(9)
if extendedCode > 0 {
// Extended data carries an additional 4-byte data type code.
headerLength += 4
opCode = msgChannelExtendedData
}
ch.writeMu.Lock()
packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be
// flagged as errors by the race detector.
ch.writeMu.Unlock()
for len(data) > 0 {
space := min(ch.maxRemotePayload, len(data))
// reserve reassigns space, so the amount actually granted may differ
// from the amount requested.
if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err
}
// Reuse the pooled buffer when it is large enough.
if want := headerLength + space; uint32(cap(packet)) < want {
packet = make([]byte, want)
} else {
packet = packet[:want]
}
todo := data[:space]
packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
}
// The payload length occupies the last four header bytes.
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo)
if err = ch.writePacket(packet); err != nil {
return n, err
}
n += len(todo)
data = data[len(todo):]
}
// Store the (possibly grown) buffer for the next call on this stream.
ch.writeMu.Lock()
ch.packetPool[extendedCode] = packet
ch.writeMu.Unlock()
return n, err
}
// handleData processes an incoming channel data or extended data packet.
// The packet must carry a complete header (9 bytes, or 13 for extended
// data), declare a payload length that matches the bytes present, stay
// within maxIncomingPayload and fit in the local window. Normal data is
// queued on pending, extended data of type 1 on extPending, and other
// extended types are dropped after being charged against the window.
// Zero-length payloads are ignored.
func (ch *channel) handleData(packet []byte) error {
	extendedPacket := packet[0] == msgChannelExtendedData
	headerLen := 9
	if extendedPacket {
		headerLen = 13
	}
	if len(packet) < headerLen {
		// malformed data packet
		return parseError(packet[0])
	}

	var stream uint32
	if extendedPacket {
		stream = binary.BigEndian.Uint32(packet[5:])
	}

	payloadLen := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
	switch {
	case payloadLen == 0:
		return nil
	case payloadLen > ch.maxIncomingPayload:
		// TODO(hanwen): should send Disconnect?
		return errors.New("ssh: incoming packet exceeds maximum payload size")
	}

	payload := packet[headerLen:]
	if uint32(len(payload)) != payloadLen {
		return errors.New("ssh: wrong packet length")
	}

	ch.windowMu.Lock()
	fits := ch.myWindow >= payloadLen
	if fits {
		ch.myWindow -= payloadLen
	}
	ch.windowMu.Unlock()
	if !fits {
		// TODO(hanwen): should send Disconnect with reason?
		return errors.New("ssh: remote side wrote too much")
	}

	switch stream {
	case 0:
		ch.pending.write(payload)
	case 1:
		ch.extPending.write(payload)
	default:
		// discard other extended data.
	}
	return nil
}
// adjustWindow records that adj more bytes were consumed locally. When the
// local window has shrunk far enough (below half of channelWindowSize, or
// more than three maximum payloads below it) all consumed bytes are returned
// to the window and announced to the peer in a window adjust message.
func (c *channel) adjustWindow(adj uint32) error {
	c.windowMu.Lock()
	// Since myConsumed and myWindow are managed on our side, and can never
	// exceed the initial window setting, we don't worry about overflow.
	c.myConsumed += adj
	var grant uint32
	if c.myWindow < channelWindowSize/2 ||
		channelWindowSize-c.myWindow > 3*c.maxIncomingPayload {
		grant = c.myConsumed
		c.myConsumed = 0
		c.myWindow += grant
	}
	c.windowMu.Unlock()

	if grant != 0 {
		return c.sendMessage(windowAdjustMsg{AdditionalBytes: grant})
	}
	return nil
}
// ReadExtended reads from the normal data stream (extended == 0) or from the
// extended stream of type 1; any other code yields an error. After a read
// that returned data, the consumed bytes are reported through adjustWindow
// and the error of that call replaces the read error.
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
	var src *buffer
	switch extended {
	case 0:
		src = c.pending
	case 1:
		src = c.extPending
	default:
		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
	}

	n, err = src.Read(data)
	if n <= 0 {
		return n, err
	}

	// adjustWindow can return io.EOF if the remote
	// peer has closed the connection, however we want to
	// defer forwarding io.EOF to the caller of Read until
	// the buffer has been drained.
	if err = c.adjustWindow(uint32(n)); err == io.EOF {
		err = nil
	}
	return n, err
}
// close tears down the local state of the channel: both receive buffers are
// marked EOF, the msg and incomingRequests Go channels are closed, further
// packet writes are disabled, and the remote window is closed.
func (c *channel) close() {
c.pending.eof()
c.extPending.eof()
close(c.msg)
close(c.incomingRequests)
c.writeMu.Lock()
// This is not necessary for a normal channel teardown, but if
// there was another error, it is.
c.sentClose = true
c.writeMu.Unlock()
// Unblock writers.
c.remoteWin.close()
}
// responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the
// given channel. A response is only acceptable on a channel that is not
// inbound and has not been decided yet; accepting it marks the channel as
// decided.
func (ch *channel) responseMessageReceived() error {
	switch {
	case ch.direction == channelInbound:
		return errors.New("ssh: channel response message received on inbound channel")
	case ch.decided:
		return errors.New("ssh: duplicate response received for channel")
	}
	ch.decided = true
	return nil
}
// handlePacket dispatches one incoming packet addressed to this channel.
// Data, close and EOF packets are recognized by their first byte; all other
// packets are decoded and handled by message type. A returned error reports
// a malformed or unexpected packet.
func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] {
case msgChannelData, msgChannelExtendedData:
return ch.handleData(packet)
case msgChannelClose:
// Answer with our own close (the send error is ignored), then drop the
// channel from the mux and release local state.
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
ch.mux.chanList.remove(ch.localId)
ch.close()
return nil
case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
ch.extPending.eof()
ch.pending.eof()
return nil
}
decoded, err := decode(packet)
if err != nil {
return err
}
switch msg := decoded.(type) {
case *channelOpenFailureMsg:
if err := ch.responseMessageReceived(); err != nil {
return err
}
ch.mux.chanList.remove(msg.PeersID)
ch.msg <- msg
case *channelOpenConfirmMsg:
if err := ch.responseMessageReceived(); err != nil {
return err
}
// Reject packet sizes below minPacketLength or above 1<<31.
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
}
// Adopt the peer's channel ID, packet size limit and initial window.
ch.remoteId = msg.MyID
ch.maxRemotePayload = msg.MaxPacketSize
ch.remoteWin.add(msg.MyWindow)
ch.msg <- msg
case *windowAdjustMsg:
if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
}
case *channelRequestMsg:
req := Request{
Type: msg.Request,
WantReply: msg.WantReply,
Payload: msg.RequestSpecificData,
ch: ch,
}
ch.incomingRequests <- &req
default:
// Everything else (for example request success/failure) is forwarded
// to whoever is reading from ch.msg.
ch.msg <- msg
}
return nil
}
// newChannel creates a channel of the given type and direction, initializes
// its buffers, queues and receive window (channelWindowSize), and registers
// it with the mux's channel list, which supplies the local channel ID.
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
	created := &channel{
		// Identity.
		mux:       m,
		chanType:  chanType,
		extraData: extraData,
		direction: direction,

		// Flow control and receive buffers.
		remoteWin:  window{Cond: newCond()},
		myWindow:   channelWindowSize,
		pending:    newBuffer(),
		extPending: newBuffer(),

		// Queues and write-side scratch space.
		incomingRequests: make(chan *Request, chanSize),
		msg:              make(chan any, chanSize),
		packetPool:       make(map[uint32][]byte),
	}
	created.localId = m.chanList.add(created)
	return created
}
// Errors reported when channel methods are used in the wrong phase of the
// Accept/Reject handshake.
var (
	// errUndecided is returned by methods that need an accepted or
	// rejected channel.
	errUndecided = errors.New("ssh: must Accept or Reject channel")
	// errDecidedAlready is returned by a second Accept or Reject.
	errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
)
// extChannel is an io.ReadWriter view of one extended data stream of a
// channel.
type extChannel struct {
	code uint32   // extended data type code passed to the channel
	ch   *channel // underlying channel
}

// Write sends data on the extended stream identified by e.code.
func (e *extChannel) Write(data []byte) (n int, err error) {
	n, err = e.ch.WriteExtended(data, e.code)
	return n, err
}

// Read receives data from the extended stream identified by e.code.
func (e *extChannel) Read(data []byte) (n int, err error) {
	n, err = e.ch.ReadExtended(data, e.code)
	return n, err
}
// Accept confirms the channel open request. It fixes the incoming payload
// limit at channelMaxPacket, marks the channel as decided and sends the
// confirmation to the peer. The channel stays decided even if sending fails.
// A second Accept or Reject returns errDecidedAlready.
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
	if ch.decided {
		return nil, nil, errDecidedAlready
	}

	ch.maxIncomingPayload = channelMaxPacket
	ch.decided = true
	err := ch.sendMessage(channelOpenConfirmMsg{
		PeersID:       ch.remoteId,
		MyID:          ch.localId,
		MyWindow:      ch.myWindow,
		MaxPacketSize: ch.maxIncomingPayload,
	})
	if err != nil {
		return nil, nil, err
	}
	return ch, ch.incomingRequests, nil
}
// Reject declines the channel open request, reporting reason and message to
// the peer with language tag "en". The channel is marked as decided before
// the message is sent. A second Accept or Reject returns errDecidedAlready.
func (ch *channel) Reject(reason RejectionReason, message string) error {
	if ch.decided {
		return errDecidedAlready
	}

	ch.decided = true
	return ch.sendMessage(channelOpenFailureMsg{
		PeersID:  ch.remoteId,
		Reason:   reason,
		Message:  message,
		Language: "en",
	})
}
// Read reads from the normal data stream. It returns errUndecided until the
// channel has been accepted or rejected.
func (ch *channel) Read(data []byte) (int, error) {
	if ch.decided {
		return ch.ReadExtended(data, 0)
	}
	return 0, errUndecided
}
// Write writes to the normal data stream. It returns errUndecided until the
// channel has been accepted or rejected.
func (ch *channel) Write(data []byte) (int, error) {
	if ch.decided {
		return ch.WriteExtended(data, 0)
	}
	return 0, errUndecided
}
// CloseWrite marks the sending side as finished (subsequent WriteExtended
// calls return io.EOF) and sends a channel EOF message to the peer. It
// returns errUndecided until the channel has been accepted or rejected.
func (ch *channel) CloseWrite() error {
	if ch.decided {
		ch.sentEOF = true
		return ch.sendMessage(channelEOFMsg{PeersID: ch.remoteId})
	}
	return errUndecided
}
// Close sends a channel close message to the peer. It returns errUndecided
// until the channel has been accepted or rejected.
func (ch *channel) Close() error {
	if ch.decided {
		return ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
	}
	return errUndecided
}
// Extended returns an io.ReadWriter that sends and receives data on the given
// SSH extended stream. Such streams are used, for example, for stderr.
// It returns nil until the channel has been accepted or rejected.
func (ch *channel) Extended(code uint32) io.ReadWriter {
	if ch.decided {
		return &extChannel{code: code, ch: ch}
	}
	return nil
}
// Stderr returns the extended stream with data type code 1.
func (ch *channel) Stderr() io.ReadWriter {
	const stderrCode = 1
	return ch.Extended(stderrCode)
}
// SendRequest sends a channel request named name with the given payload.
// With wantReply set it holds sentRequestMu for the whole exchange, waits for
// the next message on ch.msg and reports whether the peer answered with
// success. io.EOF is returned if ch.msg is closed before a reply arrives.
// Without wantReply it returns (false, nil) as soon as the request is sent.
// It returns errUndecided until the channel has been accepted or rejected.
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	if !ch.decided {
		return false, errUndecided
	}

	if wantReply {
		ch.sentRequestMu.Lock()
		defer ch.sentRequestMu.Unlock()
	}

	err := ch.sendMessage(channelRequestMsg{
		PeersID:             ch.remoteId,
		Request:             name,
		WantReply:           wantReply,
		RequestSpecificData: payload,
	})
	if err != nil {
		return false, err
	}
	if !wantReply {
		return false, nil
	}

	reply, open := <-ch.msg
	if !open {
		return false, io.EOF
	}
	switch reply.(type) {
	case *channelRequestSuccessMsg:
		return true, nil
	case *channelRequestFailureMsg:
		return false, nil
	default:
		return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", reply)
	}
}
// ackRequest answers a channel request with a success message when ok is
// true and a failure message otherwise. It returns errUndecided until the
// channel has been accepted or rejected.
func (ch *channel) ackRequest(ok bool) error {
	if !ch.decided {
		return errUndecided
	}
	if ok {
		return ch.sendMessage(channelRequestSuccessMsg{PeersID: ch.remoteId})
	}
	return ch.sendMessage(channelRequestFailureMsg{PeersID: ch.remoteId})
}
// ChannelType returns the channel type the channel was created with.
func (ch *channel) ChannelType() string {
return ch.chanType
}
// ExtraData returns the extra data the channel was created with. The
// internal slice is returned without copying.
func (ch *channel) ExtraData() []byte {
return ch.extraData
}
================================================
FILE: gossh/crypto/ssh/cipher.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/rc4"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"gossh/crypto/chacha20"
"gossh/crypto/internal/poly1305"
"hash"
"io"
)
const (
// packetSizeMultiple is the block size that packets are padded to by the
// packet ciphers in this file.
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
// indicates implementations SHOULD be able to handle larger packet sizes, but
// does not commit to a concrete reasonable limit.
//
// OpenSSH caps their maxPacket at 256kB so we choose to do
// the same. maxPacket is also used to ensure that uint32
// length fields do not overflow, so it should remain well
// below 4G.
maxPacket = 256 * 1024
)
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}

// XORKeyStream copies src to dst unchanged. Like the built-in copy it
// transfers only as many bytes as fit in the shorter of the two slices.
func (noneCipher) XORKeyStream(dst, src []byte) {
	_ = copy(dst, src)
}
// newAESCTR returns an AES stream cipher in CTR mode for the given key and
// initial counter block iv. It fails when key is not a valid AES key length.
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	stream := cipher.NewCTR(block, iv)
	return stream, nil
}
// newRC4 returns an RC4 stream cipher for key. The iv argument is unused; it
// exists so that newRC4 has the constructor signature streamCipherMode
// expects.
//
// On failure a literal nil stream is returned. Returning the result of
// rc4.NewCipher directly would wrap a nil *rc4.Cipher in a non-nil
// cipher.Stream interface value.
func newRC4(key, iv []byte) (cipher.Stream, error) {
	c, err := rc4.NewCipher(key)
	if err != nil {
		return nil, err
	}
	return c, nil
}
// cipherMode describes one supported cipher: the sizes of the key and IV it
// needs and a constructor that builds a packetCipher from key material.
type cipherMode struct {
// keySize is the key length in bytes.
keySize int
// ivSize is the IV length in bytes; 0 for ciphers that take no IV.
ivSize int
// create builds the packet cipher from the key, IV and MAC key for the
// negotiated algorithms of one direction.
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
}
// streamCipherMode adapts a stream cipher constructor to the create function
// of a cipherMode. The returned function builds the stream with createFunc,
// discards the first skip bytes of its key stream, and pairs it with the MAC
// selected by algs.MAC in a streamPacketCipher.
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
	return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
		stream, err := createFunc(key, iv)
		if err != nil {
			return nil, err
		}

		if skip > 0 {
			// Run the cipher over a scratch buffer in chunks until
			// skip bytes of key stream have been consumed.
			scratch := make([]byte, 512)
			for left := skip; left > 0; {
				chunk := left
				if chunk > len(scratch) {
					chunk = len(scratch)
				}
				stream.XORKeyStream(scratch[:chunk], scratch[:chunk])
				left -= chunk
			}
		}

		mac := macModes[algs.MAC].new(macKey)
		pc := &streamPacketCipher{
			cipher:    stream,
			mac:       mac,
			etm:       macModes[algs.MAC].etm,
			macResult: make([]byte, mac.Size()),
		}
		return pc, nil
	}
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
//
// Each entry lists, in cipherMode field order, the key size, the IV size and
// the constructor for the packet cipher.
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
// The first 1536 bytes of key stream are discarded for these.
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC 4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AEAD ciphers
gcm128CipherID: {16, 12, newGCMCipher},
gcm256CipherID: {32, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config.
// (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if
// you do.
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is not included in the default
// config.
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
}
// prefixLen is the length of the packet prefix that contains the packet length
// (4 bytes, big endian) and the number of padding bytes (1 byte).
const prefixLen = 5
// streamPacketCipher is a packetCipher using a stream cipher.
type streamPacketCipher struct {
// mac authenticates packets; the packet methods skip all MAC handling
// when it is nil.
mac hash.Hash
cipher cipher.Stream
// etm selects encrypt-then-MAC handling: the packet length stays
// unencrypted and the MAC is computed over the ciphertext.
etm bool
// The following members are to avoid per-packet allocations.
prefix [prefixLen]byte
seqNumBytes [4]byte
padding [2 * packetSizeMultiple]byte
packetData []byte
macResult []byte
}
// readCipherPacket reads and decrypts a single packet from the reader argument.
// seqNum is included in the MAC computation. The returned payload (without
// padding) aliases the internal packetData buffer, which is reused by the
// next call.
func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
return nil, err
}
var encryptedPaddingLength [1]byte
if s.mac != nil && s.etm {
// EtM: the length is unencrypted; only the padding length byte is
// decrypted. The encrypted byte is kept for the MAC computation.
copy(encryptedPaddingLength[:], s.prefix[4:5])
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
} else {
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
length := binary.BigEndian.Uint32(s.prefix[0:4])
paddingLength := uint32(s.prefix[4])
var macSize uint32
if s.mac != nil {
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
if s.etm {
// EtM MACs cover the data as it appeared on the wire.
s.mac.Write(s.prefix[:4])
s.mac.Write(encryptedPaddingLength[:])
} else {
s.mac.Write(s.prefix[:])
}
macSize = uint32(s.mac.Size())
}
if length <= paddingLength+1 {
return nil, errors.New("ssh: invalid packet length, packet too small")
}
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
// the maxPacket check above ensures that length-1+macSize
// does not overflow.
if uint32(cap(s.packetData)) < length-1+macSize {
s.packetData = make([]byte, length-1+macSize)
} else {
s.packetData = s.packetData[:length-1+macSize]
}
if _, err := io.ReadFull(r, s.packetData); err != nil {
return nil, err
}
// length counts the padding length byte, which was already read as
// part of the prefix, hence the -1.
mac := s.packetData[length-1:]
data := s.packetData[:length-1]
if s.mac != nil && s.etm {
// EtM: authenticate the ciphertext before decrypting it.
s.mac.Write(data)
}
s.cipher.XORKeyStream(data, data)
if s.mac != nil {
if !s.etm {
// Non-EtM: authenticate the decrypted data.
s.mac.Write(data)
}
s.macResult = s.mac.Sum(s.macResult[:0])
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
return nil, errors.New("ssh: MAC failure")
}
}
return s.packetData[:length-paddingLength-1], nil
}
// writeCipherPacket encrypts and sends a packet of data to the writer argument.
// The padding is read from rand and seqNum is included in the MAC
// computation. packet is encrypted in place, so the caller's slice is
// overwritten. Packets larger than maxPacket are rejected.
func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
if len(packet) > maxPacket {
return errors.New("ssh: packet too large")
}
aadlen := 0
if s.mac != nil && s.etm {
// packet length is not encrypted for EtM modes
aadlen = 4
}
// Pad to a multiple of packetSizeMultiple with at least 4 bytes of
// padding.
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
if paddingLength < 4 {
paddingLength += packetSizeMultiple
}
length := len(packet) + 1 + paddingLength
binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
s.prefix[4] = byte(paddingLength)
padding := s.padding[:paddingLength]
if _, err := io.ReadFull(rand, padding); err != nil {
return err
}
if s.mac != nil {
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
if s.etm {
// For EtM algorithms, the packet length must stay unencrypted,
// but the following data (padding length) must be encrypted
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
}
s.mac.Write(s.prefix[:])
if !s.etm {
// For non-EtM algorithms, the algorithm is applied on unencrypted data
s.mac.Write(packet)
s.mac.Write(padding)
}
}
if !(s.mac != nil && s.etm) {
// For EtM algorithms, the padding length has already been encrypted
// and the packet length must remain unencrypted
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if s.mac != nil && s.etm {
// For EtM algorithms, packet and padding must be encrypted
s.mac.Write(packet)
s.mac.Write(padding)
}
// Wire order: prefix, payload, padding, MAC.
if _, err := w.Write(s.prefix[:]); err != nil {
return err
}
if _, err := w.Write(packet); err != nil {
return err
}
if _, err := w.Write(padding); err != nil {
return err
}
if s.mac != nil {
s.macResult = s.mac.Sum(s.macResult[:0])
if _, err := w.Write(s.macResult); err != nil {
return err
}
}
return nil
}
// gcmCipher holds the per-direction state of the AES-GCM packet cipher
// created by newGCMCipher.
type gcmCipher struct {
// aead seals outgoing and opens incoming packet bodies.
aead cipher.AEAD
// prefix holds the 4-byte big-endian packet length. It travels unencrypted
// and is passed to the AEAD as additional authenticated data.
prefix [4]byte
// iv is the nonce handed to the AEAD; incIV advances it after every packet.
iv []byte
// buf is scratch space that is reused between packets and only reallocated
// when its capacity is too small.
buf []byte
}
// newGCMCipher returns a packet cipher that protects packets with AES in
// Galois/Counter mode, keyed with key and starting from nonce iv. The MAC key
// and the negotiated algorithms are accepted for signature compatibility with
// the other cipher constructors but are not used.
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	pc := new(gcmCipher)
	pc.aead = gcm
	pc.iv = iv
	return pc, nil
}
// gcmTagSize is the number of bytes that follow the packet body on the wire
// for the GCM cipher; gcmCipher.readCipherPacket reads length+gcmTagSize bytes
// after the length prefix.
const gcmTagSize = 16
// writeCipherPacket seals packet with the AEAD and writes it to w.
//
// The 4-byte big-endian length is written first, unencrypted, and is also
// passed to Seal as additional data. The sealed body consists of the
// padding-length byte, the payload and padding bytes read from rand. After a
// successful write the IV is advanced with incIV. seqNum is not used.
//
// An error is returned if a write to w fails or rand cannot supply the padding.
func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
// Pad out to multiple of 16 bytes. This is different from the
// stream cipher because that encrypts the length too.
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
if padding < 4 {
padding += packetSizeMultiple
}
// length covers the padding-length byte, the payload and the padding.
length := uint32(len(packet) + int(padding) + 1)
binary.BigEndian.PutUint32(c.prefix[:], length)
if _, err := w.Write(c.prefix[:]); err != nil {
return err
}
// Reuse c.buf when it is large enough, otherwise allocate a new one.
if cap(c.buf) < int(length) {
c.buf = make([]byte, length)
} else {
c.buf = c.buf[:length]
}
c.buf[0] = padding
copy(c.buf[1:], packet)
// Fill the remainder of the buffer (the padding) from rand.
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
return err
}
// Seal in place; the result (ciphertext plus tag) replaces c.buf.
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
c.incIV()
return nil
}
// incIV adds one to the big-endian counter held in bytes 4 through 11 of
// c.iv, propagating the carry towards byte 4 and wrapping silently when every
// counter byte overflows. Bytes 0 through 3 are never modified.
func (c *gcmCipher) incIV() {
	for pos := 11; pos > 3; pos-- {
		c.iv[pos]++
		if c.iv[pos] != 0 {
			// No carry out of this byte, so the increment is complete.
			return
		}
	}
}
// readCipherPacket reads one AES-GCM protected packet from r, authenticates
// and decrypts it, and returns the payload with the padding-length byte and
// the padding removed.
//
// The 4-byte length prefix is read unencrypted and passed to the AEAD as
// additional data. After a successful Open the IV is advanced with incIV.
// seqNum is not used. The returned slice aliases c.buf and is only valid
// until the next call.
//
// An error is returned on I/O failure, when the announced length exceeds
// maxPacket, when authentication fails, or when the padding is shorter than 4
// bytes or leaves no room for a payload.
func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
	if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
		return nil, err
	}
	length := binary.BigEndian.Uint32(c.prefix[:])
	if length > maxPacket {
		return nil, errors.New("ssh: max packet length exceeded")
	}

	// Reuse c.buf when it is large enough for the body plus the tag.
	if cap(c.buf) < int(length+gcmTagSize) {
		c.buf = make([]byte, length+gcmTagSize)
	} else {
		c.buf = c.buf[:length+gcmTagSize]
	}

	if _, err := io.ReadFull(r, c.buf); err != nil {
		return nil, err
	}

	plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
	if err != nil {
		return nil, err
	}
	c.incIV()

	if len(plain) == 0 {
		return nil, errors.New("ssh: empty packet")
	}

	padding := plain[0]
	if padding < 4 {
		// padding is a byte, so it automatically satisfies
		// the maximum size, which is 255.
		return nil, fmt.Errorf("ssh: illegal padding %d", padding)
	}

	// Widen to int before adding one: computed in byte arithmetic, a padding
	// of 255 would wrap to 0 and slip past this check, after which the slice
	// expression below could panic or return a bogus payload.
	if int(padding)+1 >= len(plain) {
		return nil, fmt.Errorf("ssh: padding %d too large", padding)
	}

	// Drop the leading padding-length byte and the trailing padding.
	return plain[1 : len(plain)-int(padding)], nil
}
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
//
// newCBCCipher builds it around any cipher.Block, so the same type also backs
// the 3DES constructor newTripleDESCBCCipher.
type cbcCipher struct {
// mac authenticates each packet; it is nil-checked before every use.
mac hash.Hash
// macSize is mac.Size() when mac is set, zero otherwise.
macSize uint32
// decrypter and encrypter are the CBC block modes for the two operations.
decrypter cipher.BlockMode
encrypter cipher.BlockMode
// The following members are to avoid per-packet allocations.
seqNumBytes [4]byte
packetData []byte
macResult []byte
// Amount of data we should still read to hide which
// verification error triggered.
oracleCamouflage uint32
}
// newCBCCipher wraps block cipher c in a CBC packet cipher. iv initialises
// both the encrypting and the decrypting block mode, and macKey keys the MAC
// selected by algs.MAC. The packet buffer starts at 1024 bytes. key is unused
// here because c is already keyed. The returned error is always nil.
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
	result := &cbcCipher{
		mac:        macModes[algs.MAC].new(macKey),
		decrypter:  cipher.NewCBCDecrypter(c, iv),
		encrypter:  cipher.NewCBCEncrypter(c, iv),
		packetData: make([]byte, 1024),
	}
	if result.mac == nil {
		// No MAC configured: macSize keeps its zero value.
		return result, nil
	}
	result.macSize = uint32(result.mac.Size())
	return result, nil
}
// newAESCBCCipher returns a CBC packet cipher backed by AES keyed with key.
// An invalid key is reported as an error; the remaining arguments are passed
// through to newCBCCipher.
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	pc, err := newCBCCipher(block, key, iv, macKey, algs)
	if err != nil {
		return nil, err
	}
	return pc, nil
}
// newTripleDESCBCCipher returns a CBC packet cipher backed by triple DES
// keyed with key. An invalid key is reported as an error; the remaining
// arguments are passed through to newCBCCipher.
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
	block, err := des.NewTripleDESCipher(key)
	if err != nil {
		return nil, err
	}

	pc, err := newCBCCipher(block, key, iv, macKey, algs)
	if err != nil {
		return nil, err
	}
	return pc, nil
}
// maxUInt32 returns the larger of a and b converted to uint32. The comparison
// is done on the signed ints; only the winner is converted.
func maxUInt32(a, b int) uint32 {
	larger := b
	if a > b {
		larger = a
	}
	return uint32(larger)
}
// Limits applied to CBC packets.
const (
// cbcMinPacketSizeMultiple is the smallest multiple the packet length
// (including the length field) must satisfy; a larger cipher block size
// takes precedence.
cbcMinPacketSizeMultiple = 8
// cbcMinPacketSize is the smallest accepted packet size (including the
// length field); a larger cipher block size takes precedence.
cbcMinPacketSize = 16
// cbcMinPaddingSize is the smallest accepted padding length.
cbcMinPaddingSize = 4
)
// cbcError is the error type used for CBC verification failures whose exact
// cause could leak information to the peer. Callers detect it with a type
// assertion.
type cbcError string

// Error returns the message carried by e, satisfying the error interface.
func (e cbcError) Error() string {
	msg := string(e)
	return msg
}
// readCipherPacket reads and verifies one CBC packet from r by delegating to
// readCipherPacketLeaky. When that call fails with a cbcError, a further
// c.oracleCamouflage bytes are drained from r before returning, so that a
// failed MAC and a failed length check are harder to tell apart. The packet
// and error from the inner call are returned unchanged.
func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
	payload, err := c.readCipherPacketLeaky(seqNum, r)
	if _, verificationFailed := err.(cbcError); verificationFailed {
		// Best effort: the outcome of the drain is deliberately ignored.
		io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
	}
	return payload, err
}
// readCipherPacketLeaky reads, decrypts and verifies a single CBC packet from r.
//
// Length, padding and MAC failures are returned as cbcError values, which lets
// readCipherPacket recognise them and drain c.oracleCamouflage further bytes.
// Errors from reading r are returned unchanged. On success the returned slice
// is the payload and aliases c.packetData.
func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
blockSize := c.decrypter.BlockSize()
// Read the header, which will include some of the subsequent data in the
// case of block ciphers - this is copied back to the payload later.
// How many bytes of payload/padding will be read with this first read.
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
firstBlock := c.packetData[:firstBlockLength]
if _, err := io.ReadFull(r, firstBlock); err != nil {
return nil, err
}
// Start from the largest possible packet (plus length field and MAC) minus
// what has been consumed so far; reduced again after the second read below.
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
c.decrypter.CryptBlocks(firstBlock, firstBlock)
length := binary.BigEndian.Uint32(firstBlock[:4])
if length > maxPacket {
return nil, cbcError("ssh: packet too large")
}
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
// The minimum size of a packet is 16 (or the cipher block size, whichever
// is larger) bytes.
return nil, cbcError("ssh: packet too small")
}
// The length of the packet (including the length field but not the MAC) must
// be a multiple of the block size or 8, whichever is larger.
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
return nil, cbcError("ssh: invalid packet length multiple")
}
paddingLength := uint32(firstBlock[4])
// Reject padding below the minimum and padding that leaves no payload.
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
return nil, cbcError("ssh: invalid packet length")
}
// Positions within the c.packetData buffer:
macStart := 4 + length
paddingStart := macStart - paddingLength
// Entire packet size, starting before length, ending at end of mac.
entirePacketSize := macStart + c.macSize
// Ensure c.packetData is large enough for the entire packet data.
if uint32(cap(c.packetData)) < entirePacketSize {
// Still need to upsize and copy, but this should be rare at runtime, only
// on upsizing the packetData buffer.
c.packetData = make([]byte, entirePacketSize)
copy(c.packetData, firstBlock)
} else {
c.packetData = c.packetData[:entirePacketSize]
}
// Read everything after the first block: remaining ciphertext plus the MAC.
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err
}
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
mac := c.packetData[macStart:]
if c.mac != nil {
// The MAC covers the sequence number followed by the decrypted packet
// (length field, padding length, payload and padding).
c.mac.Reset()
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
c.mac.Write(c.seqNumBytes[:])
c.mac.Write(c.packetData[:macStart])
c.macResult = c.mac.Sum(c.macResult[:0])
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
return nil, cbcError("ssh: MAC failure")
}
}
// Strip the prefix in front of the payload and the padding behind it.
return c.packetData[prefixLen:paddingStart], nil
}
// writeCipherPacket builds a CBC packet around packet and writes it to w.
//
// The packet is assembled in c.packetData as length, padding length, payload
// and padding read from rand. When a MAC is configured it is computed over
// seqNum and the plaintext packet and appended, after which the packet
// (without the MAC) is encrypted in place and the whole buffer is written
// with a single Write.
//
// An error is returned if rand cannot supply the padding or the write fails.
func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize())
// Length of encrypted portion of the packet (header, payload, padding).
// Enforce minimum padding and packet size.
// NOTE(review): the first operand already includes cbcMinPaddingSize, so the
// second operand can only win if prefixLen+len(packet) were negative; confirm
// whether cbcMinPacketSize was intended here.
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
// Enforce block size.
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
// The length field does not count its own 4 bytes.
length := encLength - 4
paddingLength := int(length) - (1 + len(packet))
// Overall buffer contains: header, payload, padding, mac.
// Space for the MAC is reserved in the capacity but not the slice length.
bufferSize := encLength + c.macSize
if uint32(cap(c.packetData)) < bufferSize {
c.packetData = make([]byte, encLength, bufferSize)
} else {
c.packetData = c.packetData[:encLength]
}
// p is a cursor that advances through c.packetData as fields are filled in.
p := c.packetData
// Packet header.
binary.BigEndian.PutUint32(p, length)
p = p[4:]
p[0] = byte(paddingLength)
// Payload.
p = p[1:]
copy(p, packet)
// Padding.
p = p[len(packet):]
if _, err := io.ReadFull(rand, p); err != nil {
return err
}
if c.mac != nil {
c.mac.Reset()
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
c.mac.Write(c.seqNumBytes[:])
c.mac.Write(c.packetData)
// The MAC is now appended into the capacity reserved for it earlier.
c.packetData = c.mac.Sum(c.packetData)
}
// Encrypt everything except the appended MAC.
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength])
if _, err := w.Write(c.packetData); err != nil {
return err
}
return nil
}
// chacha20Poly1305ID is the algorithm name of the ChaCha20-Poly1305 cipher.
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here:
//
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
//
// the methods here also implement padding, which RFC 4253 Section 6
// also requires of stream ciphers.
type chacha20Poly1305Cipher struct {
// lengthKey keys the ChaCha20 instance that encrypts the 4-byte length.
lengthKey [32]byte
// contentKey keys the ChaCha20 instance that produces the Poly1305 key and
// encrypts the packet body.
contentKey [32]byte
// buf is scratch space reused between packets.
buf []byte
}
// newChaCha20Cipher returns the chacha20-poly1305@openssh.com packet cipher.
//
// key must be exactly 64 bytes: the first 32 become the content key and the
// last 32 the length key. Any other length is reported as an error rather
// than a panic, so a mis-sized key cannot take the process down. The IV, MAC
// key and negotiated algorithms are accepted for signature compatibility with
// the other cipher constructors but are not used.
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
	if len(key) != 64 {
		return nil, fmt.Errorf("ssh: invalid key length %d for %s, expected 64", len(key), chacha20Poly1305ID)
	}

	c := &chacha20Poly1305Cipher{
		buf: make([]byte, 256),
	}

	copy(c.contentKey[:], key[:32])
	copy(c.lengthKey[:], key[32:])
	return c, nil
}
// readCipherPacket reads one chacha20-poly1305 packet from r, verifies its
// Poly1305 tag, decrypts it and returns the payload without the
// padding-length byte and padding.
//
// seqNum is placed in the last four bytes of the 12-byte nonce used for both
// the length cipher and the content cipher. The returned slice aliases c.buf.
//
// An error is returned on I/O failure, when the decrypted length exceeds
// maxPacket, when the tag does not verify, or when the padding is invalid.
func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
nonce := make([]byte, 12)
binary.BigEndian.PutUint32(nonce[8:], seqNum)
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
if err != nil {
return nil, err
}
// The first 32 keystream bytes of the content cipher become the Poly1305 key.
var polyKey, discardBuf [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
encryptedLength := c.buf[:4]
if _, err := io.ReadFull(r, encryptedLength); err != nil {
return nil, err
}
// Decrypt the length into a separate array so the ciphertext stays in c.buf,
// where the tag check below expects it.
var lenBytes [4]byte
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
if err != nil {
return nil, err
}
ls.XORKeyStream(lenBytes[:], encryptedLength)
length := binary.BigEndian.Uint32(lenBytes[:])
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
contentEnd := 4 + length
packetEnd := contentEnd + poly1305.TagSize
if uint32(cap(c.buf)) < packetEnd {
// Grow the buffer and carry the encrypted length over to the new one.
c.buf = make([]byte, packetEnd)
copy(c.buf[:], encryptedLength)
} else {
c.buf = c.buf[:packetEnd]
}
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
return nil, err
}
// The tag covers the encrypted length and the encrypted body, and is checked
// before anything is decrypted.
var mac [poly1305.TagSize]byte
copy(mac[:], c.buf[contentEnd:packetEnd])
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
return nil, errors.New("ssh: MAC failure")
}
plain := c.buf[4:contentEnd]
s.XORKeyStream(plain, plain)
if len(plain) == 0 {
return nil, errors.New("ssh: empty packet")
}
padding := plain[0]
if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
// Widened to int before adding one, so a padding of 255 cannot wrap.
if int(padding)+1 >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding)
}
// Drop the leading padding-length byte and the trailing padding.
plain = plain[1 : len(plain)-int(padding)]
return plain, nil
}
// writeCipherPacket encrypts payload as one chacha20-poly1305 packet and
// writes it to w with a single Write.
//
// The packet is assembled in c.buf as encrypted length (4 bytes), encrypted
// body (padding-length byte, payload, padding read from rand) and the
// Poly1305 tag over everything before it. seqNum is placed in the last four
// bytes of the 12-byte nonce used for both ciphers.
//
// An error is returned if a cipher cannot be created, rand cannot supply the
// padding, or the write fails.
func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
nonce := make([]byte, 12)
binary.BigEndian.PutUint32(nonce[8:], seqNum)
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
if err != nil {
return err
}
// The first 32 keystream bytes of the content cipher become the Poly1305 key.
var polyKey, discardBuf [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
// There is no blocksize, so fall back to multiple of 8 byte
// padding, as described in RFC 4253, Sec 6.
const packetSizeMultiple = 8
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
if padding < 4 {
padding += packetSizeMultiple
}
// size (4 bytes), padding (1), payload, padding, tag.
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
if cap(c.buf) < totalLength {
c.buf = make([]byte, totalLength)
} else {
c.buf = c.buf[:totalLength]
}
// The length field counts the padding-length byte, the payload and the padding.
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
if err != nil {
return err
}
// Encrypt the length in place with the length cipher.
ls.XORKeyStream(c.buf, c.buf[:4])
c.buf[4] = byte(padding)
copy(c.buf[5:], payload)
packetEnd := 5 + len(payload) + padding
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
return err
}
// Encrypt the body in place with the content cipher.
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
// Authenticate the encrypted length and body, then append the tag.
var mac [poly1305.TagSize]byte
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
copy(c.buf[packetEnd:], mac[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
return nil
}
================================================
FILE: gossh/crypto/ssh/client.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
// Client implements a traditional SSH client that supports shells,
// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
//
// Use NewClient or Dial to create one.
type Client struct {
// Conn is the underlying SSH connection; its methods are promoted.
Conn
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
forwards forwardList // forwarded tcpip connections from the remote side
// mu guards channelHandlers.
mu sync.Mutex
// channelHandlers maps a channel type to the channel handed out by
// HandleChannelOpen. It is set to nil once handleChannelOpens has finished.
channelHandlers map[string]chan NewChannel
}
// HandleChannelOpen registers interest in incoming channels of the given type
// and returns the channel on which the matching NewChannel requests are
// delivered. It returns nil when the type already has a handler. Once the
// connection has gone away an already-closed channel is returned; otherwise
// the returned channel is closed when the connection is closed.
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.channelHandlers == nil {
		// The SSH channel has been closed.
		closed := make(chan NewChannel)
		close(closed)
		return closed
	}

	if existing := c.channelHandlers[channelType]; existing != nil {
		return nil
	}

	handler := make(chan NewChannel, chanSize)
	c.channelHandlers[channelType] = handler
	return handler
}
// NewClient creates a Client on top of the given connection. It starts
// goroutines that answer global requests arriving on reqs, dispatch channel
// opens arriving on chans, and close all forwards once the connection ends.
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
	client := &Client{
		Conn:            c,
		channelHandlers: make(map[string]chan NewChannel, 1),
	}

	go client.handleGlobalRequests(reqs)
	go client.handleChannelOpens(chans)
	go func() {
		// Wait returns once the connection has terminated.
		client.Wait()
		client.forwards.closeAll()
	}()

	return client
}
// NewClientConn establishes an authenticated SSH connection using c
// as the underlying transport. The Request and NewChannel channels
// must be serviced or the connection will hang.
//
// config is copied before defaults are applied, so the caller's value is not
// modified. c is closed if the configuration lacks a HostKeyCallback or the
// handshake fails.
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
	cfg := *config
	cfg.SetDefaults()
	if cfg.HostKeyCallback == nil {
		c.Close()
		return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
	}

	sc := &connection{
		sshConn: sshConn{
			conn: c,
			user: cfg.User,
		},
	}

	err := sc.clientHandshake(addr, &cfg)
	if err != nil {
		c.Close()
		return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
	}

	sc.mux = newMux(sc.transport)
	return sc, sc.mux.incomingChannels, sc.mux.incomingRequests, nil
}
// clientHandshake performs the client side key exchange. See RFC 4253 Section
// 7.
//
// It exchanges version strings, sets up the client transport, waits for the
// first key exchange to complete, records the session ID and then runs user
// authentication. The first error encountered is returned.
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
	version := []byte(packageVersion)
	if config.ClientVersion != "" {
		version = []byte(config.ClientVersion)
	}
	c.clientVersion = version

	var err error
	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
	if err != nil {
		return err
	}

	base := newTransport(c.sshConn.conn, config.Rand, true /* is client */)
	c.transport = newClientTransport(
		base,
		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())

	err = c.transport.waitSession()
	if err != nil {
		return err
	}

	c.sessionID = c.transport.getSessionID()
	return c.clientAuthenticate(config)
}
// verifyHostKeySignature verifies the host key obtained in the key exchange.
// algo is the negotiated algorithm, and may be a certificate type.
//
// The signature in result must parse without trailing bytes and its format
// must equal the algorithm underlying algo; the signature is then checked
// against result.H with hostKey.
func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error {
	sig, trailing, parsed := parseSignatureBody(result.Signature)
	if !parsed || len(trailing) > 0 {
		return errors.New("ssh: signature parse error")
	}

	expected := underlyingAlgo(algo)
	if sig.Format != expected {
		return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, expected)
	}

	return hostKey.Verify(result.H, sig)
}
// NewSession opens a new Session for this client. (A session is a remote
// execution of a program.) It opens a channel of type "session" and wraps it;
// a failure to open the channel is returned as is.
func (c *Client) NewSession() (*Session, error) {
	channel, requests, err := c.OpenChannel("session", nil)
	if err == nil {
		return newSession(channel, requests)
	}
	return nil, err
}
// handleGlobalRequests answers every global request received on incoming with
// a failure reply until the channel is closed.
func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
	for {
		req, open := <-incoming
		if !open {
			return
		}
		// This handles keepalive messages and matches
		// the behaviour of OpenSSH.
		req.Reply(false, nil)
	}
}
// handleChannelOpens dispatches channel open messages from the remote side.
//
// Each incoming channel is forwarded to the handler registered for its type,
// or rejected with UnknownChannelType when there is none. When in is closed,
// every registered handler channel is closed and the handler map is cleared.
func (c *Client) handleChannelOpens(in <-chan NewChannel) {
	for newCh := range in {
		c.mu.Lock()
		handler := c.channelHandlers[newCh.ChannelType()]
		c.mu.Unlock()

		if handler == nil {
			newCh.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", newCh.ChannelType()))
			continue
		}
		handler <- newCh
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	for _, handler := range c.channelHandlers {
		close(handler)
	}
	c.channelHandlers = nil
}
// Dial starts a client connection to the given SSH server. It is a
// convenience function that connects to the given network address,
// initiates the SSH handshake, and then sets up a Client. For access
// to incoming channels and requests, use net.Dial with NewClientConn
// instead.
//
// config.Timeout bounds only the establishment of the network connection.
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
	netConn, err := net.DialTimeout(network, addr, config.Timeout)
	if err != nil {
		return nil, err
	}

	cc, chans, reqs, err := NewClientConn(netConn, addr, config)
	if err != nil {
		return nil, err
	}

	client := NewClient(cc, chans, reqs)
	return client, nil
}
// HostKeyCallback is the function type used for verifying server
// keys. A HostKeyCallback must return nil if the host key is OK, or
// an error to reject it. It receives the hostname as passed to Dial
// or NewClientConn. The remote address is the RemoteAddr of the
// net.Conn underlying the SSH connection.
//
// InsecureIgnoreHostKey and FixedHostKey return ready-made callbacks.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used to handle the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
// BannerDisplayStderr returns a ready-made callback.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
//
// NewClientConn works on a copy of the value and applies SetDefaults to that
// copy.
type ClientConfig struct {
// Config contains configuration that is shared between clients and
// servers.
Config
// User contains the username to authenticate as.
User string
// Auth contains possible authentication methods to use with the
// server. Only the first instance of a particular RFC 4252 method will
// be used during authentication.
Auth []AuthMethod
// HostKeyCallback is called during the cryptographic
// handshake to validate the server's host key. The client
// configuration must supply this callback for the connection
// to succeed. The functions InsecureIgnoreHostKey or
// FixedHostKey can be used for simplistic host key checks.
//
// NewClientConn closes the connection and returns an error when this
// field is nil.
HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
ClientVersion string
// HostKeyAlgorithms lists the public key algorithms that the client will
// accept from the server for host key authentication, in order of
// preference. If empty, a reasonable default is used. Any
// string returned from a PublicKey.Type method may be used, or
// any of the CertAlgo and KeyAlgo constants.
HostKeyAlgorithms []string
// Timeout is the maximum amount of time for the TCP connection to establish.
// Dial passes it to net.DialTimeout.
//
// A Timeout of zero means no timeout.
Timeout time.Duration
}
// InsecureIgnoreHostKey returns a function that can be used for
// ClientConfig.HostKeyCallback to accept any host key. It should
// not be used for production code.
func InsecureIgnoreHostKey() HostKeyCallback {
	acceptAny := func(string, net.Addr, PublicKey) error {
		// No verification at all: every key is accepted.
		return nil
	}
	return acceptAny
}
// fixedHostKey pins a single public key that a server must present.
type fixedHostKey struct {
	key PublicKey
}

// check accepts key only when its marshaled form is byte-for-byte equal to
// the pinned key. It returns an error when no key was pinned or when the keys
// differ. hostname and remote are ignored.
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
	switch {
	case f.key == nil:
		return fmt.Errorf("ssh: required host key was nil")
	case !bytes.Equal(key.Marshal(), f.key.Marshal()):
		return fmt.Errorf("ssh: host key mismatch")
	}
	return nil
}
// FixedHostKey returns a function for use in
// ClientConfig.HostKeyCallback to accept only a specific host key.
// A nil key yields a callback that rejects every host key.
func FixedHostKey(key PublicKey) HostKeyCallback {
	pinned := &fixedHostKey{key: key}
	return pinned.check
}
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
// The callback reports any error from writing to os.Stderr.
func BannerDisplayStderr() BannerCallback {
	return func(banner string) error {
		if _, err := os.Stderr.WriteString(banner); err != nil {
			return err
		}
		return nil
	}
}
================================================
FILE: gossh/crypto/ssh/client_auth.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"strings"
)
// authResult is the outcome of a single authentication attempt, as returned
// by AuthMethod.auth.
type authResult int
const (
// authFailure: the attempt failed; clientAuthenticate records the method
// as tried.
authFailure authResult = iota
// authPartialSuccess: neither a failure nor a full success;
// clientAuthenticate keeps going without recording the method as tried.
authPartialSuccess
// authSuccess: authentication is complete.
authSuccess
)
// clientAuthenticate authenticates with the remote server. See RFC 4252.
//
// It requests the user-auth service, optionally consumes a leading
// SSH_MSG_EXT_INFO packet, and then tries authentication methods one after
// another: first "none", then each method from config.Auth that the server
// still offers and that has not failed before. It returns nil on success. An
// error from the last attempted method is returned when nothing is left to
// try; otherwise a summary error lists the attempted methods.
func (c *connection) clientAuthenticate(config *ClientConfig) error {
	// initiate user auth session
	request := Marshal(&serviceRequestMsg{serviceUserAuth})
	if err := c.transport.writePacket(request); err != nil {
		return err
	}
	packet, err := c.transport.readPacket()
	if err != nil {
		return err
	}

	// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
	// advertised willingness to receive one, which we always do) or not. See
	// RFC 8308, Section 2.4.
	extensions := make(map[string][]byte)
	if len(packet) > 0 && packet[0] == msgExtInfo {
		var extInfo extInfoMsg
		if err := Unmarshal(packet, &extInfo); err != nil {
			return err
		}

		// Each extension is a name string followed by a value string.
		remaining := extInfo.Payload
		for left := extInfo.NumExtensions; left > 0; left-- {
			name, afterName, ok := parseString(remaining)
			if !ok {
				return parseError(msgExtInfo)
			}
			value, afterValue, ok := parseString(afterName)
			if !ok {
				return parseError(msgExtInfo)
			}
			extensions[string(name)] = value
			remaining = afterValue
		}

		// The service accept message follows the extension info.
		if packet, err = c.transport.readPacket(); err != nil {
			return err
		}
	}

	var serviceAccept serviceAcceptMsg
	if err := Unmarshal(packet, &serviceAccept); err != nil {
		return err
	}

	// during the authentication phase the client first attempts the "none" method
	// then any untried methods suggested by the server.
	var (
		tried       []string
		lastMethods []string
	)
	sessionID := c.transport.getSessionID()

	var auth AuthMethod = new(noneAuth)
	for auth != nil {
		result, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
		if err != nil {
			// We return the error later if there is no other method left to
			// try.
			result = authFailure
		}

		switch result {
		case authSuccess:
			return nil
		case authFailure:
			if m := auth.method(); !contains(tried, m) {
				tried = append(tried, m)
			}
		}

		// A nil list means "unchanged": fall back to the previous one.
		if methods == nil {
			methods = lastMethods
		}
		lastMethods = methods

		// Pick the first configured method that has not failed yet and that
		// the server still offers.
		auth = nil
		for _, candidate := range config.Auth {
			name := candidate.method()
			if contains(tried, name) {
				continue
			}
			if contains(methods, name) {
				auth = candidate
				break
			}
		}

		if auth == nil && err != nil {
			// We have an error and there are no other authentication methods to
			// try, so we return it.
			return err
		}
	}
	return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
}
// contains reports whether e occurs in list. A nil or empty list never
// contains anything.
func contains(list []string, e string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == e {
			return true
		}
	}
	return false
}
// An AuthMethod represents an instance of an RFC 4252 authentication method.
type AuthMethod interface {
// auth authenticates user over transport t.
// It returns an authResult describing the outcome (authSuccess when
// authentication is successful).
// If authentication is not successful, a []string of alternative
// method names is returned. If the slice is nil, it will be ignored
// and the previous set of possible methods will be reused.
auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
// method returns the RFC 4252 method name.
method() string
}
// noneAuth implements the "none" authentication method, RFC 4252 section 5.2.
type noneAuth int

// auth sends a "none" authentication request for user and returns the
// server's verdict as interpreted by handleAuthResponse. A failure to send
// the request is reported as authFailure together with the write error.
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
	request := userAuthRequestMsg{
		User:    user,
		Service: serviceSSH,
		Method:  "none",
	}
	if err := c.writePacket(Marshal(&request)); err != nil {
		return authFailure, nil, err
	}

	return handleAuthResponse(c)
}

// method returns the RFC 4252 method name, "none".
func (n *noneAuth) method() string {
	const name = "none"
	return name
}
// passwordCallback is an AuthMethod that fetches the password through
// a function call, e.g. by prompting the user.
type passwordCallback func() (password string, err error)

// auth obtains the password from the callback, sends it in a password
// authentication request for user and returns the server's verdict as
// interpreted by handleAuthResponse. An error from the callback or from
// sending the request is reported as authFailure together with that error.
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
	// Wire layout of the password authentication request.
	type passwordAuthMsg struct {
		User     string `sshtype:"50"`
		Service  string
		Method   string
		Reply    bool
		Password string
	}

	// REVIEW NOTE: is there a need to support skipping a password attempt?
	// The program may only find out that the user doesn't have a password
	// when prompting.
	secret, err := cb()
	if err != nil {
		return authFailure, nil, err
	}

	request := passwordAuthMsg{
		User:     user,
		Service:  serviceSSH,
		Method:   cb.method(),
		Reply:    false,
		Password: secret,
	}
	if err := c.writePacket(Marshal(&request)); err != nil {
		return authFailure, nil, err
	}

	return handleAuthResponse(c)
}

// method returns the RFC 4252 method name, "password".
func (cb passwordCallback) method() string {
	const name = "password"
	return name
}
// Password returns an AuthMethod using the given password.
func Password(secret string) AuthMethod {
	fixed := func() (string, error) {
		return secret, nil
	}
	return passwordCallback(fixed)
}
// PasswordCallback returns an AuthMethod that uses a callback for
// fetching a password. prompt is invoked each time the method is attempted.
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
	var method passwordCallback = prompt
	return method
}
// publickeyAuthMsg is the "publickey" user authentication request. It is sent
// without a signature by validateKey to ask whether a key is acceptable, and
// with a signature by publicKeyCallback.auth to authenticate.
type publickeyAuthMsg struct {
User string `sshtype:"50"`
Service string
Method string
// HasSig indicates to the receiver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
// Algoname is the public key algorithm name the request is made for.
Algoname string
// PubKey is the marshaled public key.
PubKey []byte
// Sig is tagged with "rest" so Marshal will exclude it during
// validateKey
Sig []byte `ssh:"rest"`
}
// publicKeyCallback is an AuthMethod that uses a set of key
// pairs for authentication. The function is invoked on every attempt to
// obtain the signers to try.
type publicKeyCallback func() ([]Signer, error)

// method returns the RFC 4252 method name, "publickey".
func (cb publicKeyCallback) method() string {
	const name = "publickey"
	return name
}
// pickSignatureAlgorithm chooses the public key algorithm to use with signer.
//
// It returns signer viewed as a MultiAlgorithmSigner together with the chosen
// algorithm name. When the server advertised "server-sig-algs" in extensions,
// the first algorithm for the key format that both the signer and the server
// support is used. Without that extension, or without any overlap, the key
// format itself is used, provided the signer supports its underlying
// algorithm; otherwise an error is returned along with the signer.
func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
	keyFormat := signer.PublicKey().Type()

	// If the signer implements MultiAlgorithmSigner we use the algorithms it
	// support, if it implements AlgorithmSigner we assume it supports all
	// algorithms, otherwise only the key format one.
	var as MultiAlgorithmSigner
	switch s := signer.(type) {
	case MultiAlgorithmSigner:
		as = s
	case AlgorithmSigner:
		as = &multiAlgorithmSigner{
			AlgorithmSigner:     s,
			supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
		}
	default:
		as = &multiAlgorithmSigner{
			AlgorithmSigner:     algorithmSignerWrapper{signer},
			supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
		}
	}

	// fallback is used when there is no "server-sig-algs" extension or no
	// common algorithm can be found. It yields the public key format if the
	// signer supports it and an error otherwise.
	fallback := func() (string, error) {
		if contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
			return keyFormat, nil
		}
		return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
			underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
	}

	extPayload, advertised := extensions["server-sig-algs"]
	if !advertised {
		algo, err := fallback()
		return as, algo, err
	}

	// The server-sig-algs extension only carries underlying signature
	// algorithm, but we are trying to select a protocol-level public key
	// algorithm, which might be a certificate type. Extend the list of server
	// supported algorithms to include the corresponding certificate algorithms.
	listed := strings.Split(string(extPayload), ",")
	serverAlgos := listed
	for _, name := range listed {
		certAlgo, ok := certificateAlgo(name)
		if !ok {
			continue
		}
		serverAlgos = append(serverAlgos, certAlgo)
	}

	// Filter algorithms based on those supported by MultiAlgorithmSigner.
	var keyAlgos []string
	for _, candidate := range algorithmsForKeyFormat(keyFormat) {
		if !contains(as.Algorithms(), underlyingAlgo(candidate)) {
			continue
		}
		keyAlgos = append(keyAlgos, candidate)
	}

	if common, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos); err == nil {
		return as, common, nil
	}

	// If there is no overlap, return the fallback algorithm to support
	// servers that fail to list all supported algorithms.
	algo, err := fallback()
	return as, algo, err
}
// auth performs "publickey" authentication for user over c.
//
// Authentication is performed by sending an enquiry to test if a key is
// acceptable to the remote. If the key is acceptable, the client will
// attempt to authenticate with the valid key. If not the client will repeat
// the process with the remaining keys.
//
// Signers for which no signature algorithm can be negotiated are skipped; the
// first such error is returned if no signer succeeds. Errors from obtaining
// the signers, from the transport and from signing abort the attempt
// immediately with authFailure. On a server response, the result and the
// methods that can continue are returned as reported by handleAuthResponse.
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
	signers, err := cb()
	if err != nil {
		return authFailure, nil, err
	}
	var methods []string
	var errSigAlgo error

	// Signers appended below for the RSA certificate compatibility retry sit
	// past this index and must not trigger another retry themselves.
	origSignersLen := len(signers)
	for idx := 0; idx < len(signers); idx++ {
		signer := signers[idx]
		pub := signer.PublicKey()
		as, algo, err := pickSignatureAlgorithm(signer, extensions)
		if err != nil {
			// If we cannot negotiate a signature algorithm store the first
			// error so we can return it to provide a more meaningful message if
			// no other signers work. Every such signer is skipped, not just
			// the first one: without an algorithm there is nothing valid to
			// offer the server for this key.
			if errSigAlgo == nil {
				errSigAlgo = err
			}
			continue
		}

		ok, err := validateKey(pub, algo, user, c)
		if err != nil {
			return authFailure, nil, err
		}
		// OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512
		// in the "server-sig-algs" extension but doesn't support these
		// algorithms for certificate authentication, so if the server rejects
		// the key try to use the obtained algorithm as if "server-sig-algs" had
		// not been implemented if supported from the algorithm signer.
		if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
			if contains(as.Algorithms(), KeyAlgoRSA) {
				// We retry using the compat algorithm after all signers have
				// been tried normally.
				signers = append(signers, &multiAlgorithmSigner{
					AlgorithmSigner:     as,
					supportedAlgorithms: []string{KeyAlgoRSA},
				})
			}
		}
		if !ok {
			continue
		}

		pubKey := pub.Marshal()
		data := buildDataSignedForAuth(session, userAuthRequestMsg{
			User:    user,
			Service: serviceSSH,
			Method:  cb.method(),
		}, algo, pubKey)
		sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
		if err != nil {
			return authFailure, nil, err
		}

		// manually wrap the serialized signature in a string
		s := Marshal(sign)
		sig := make([]byte, stringLength(len(s)))
		marshalString(sig, s)
		msg := publickeyAuthMsg{
			User:     user,
			Service:  serviceSSH,
			Method:   cb.method(),
			HasSig:   true,
			Algoname: algo,
			PubKey:   pubKey,
			Sig:      sig,
		}
		p := Marshal(&msg)
		if err := c.writePacket(p); err != nil {
			return authFailure, nil, err
		}
		var success authResult
		success, methods, err = handleAuthResponse(c)
		if err != nil {
			return authFailure, nil, err
		}

		// If authentication succeeds or the list of available methods does not
		// contain the "publickey" method, do not attempt to authenticate with any
		// other keys. According to RFC 4252 Section 7, the latter can occur when
		// additional authentication methods are required.
		if success == authSuccess || !contains(methods, cb.method()) {
			return success, methods, err
		}
	}

	return authFailure, methods, errSigAlgo
}
// validateKey asks the server whether it would accept key, advertised under
// algo, for the given user. It sends a "publickey" request that carries no
// signature and then reports the answer read by confirmKeyAck.
func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
	query := Marshal(&publickeyAuthMsg{
		User:     user,
		Service:  serviceSSH,
		Method:   "publickey",
		HasSig:   false,
		Algoname: algo,
		PubKey:   key.Marshal(),
	})
	err := c.writePacket(query)
	if err != nil {
		return false, err
	}
	return confirmKeyAck(key, algo, c)
}
// confirmKeyAck reads packets from c until the server answers a public key
// query. It returns true only for a msgUserAuthPubKeyOk reply that echoes both
// algo and the marshaled key. A msgUserAuthFailure, or an OK for a different
// key or algorithm, yields false with a nil error. Banner messages are passed
// to handleBannerResponse and reading continues; any other message type is
// reported as an unexpected-message error.
func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
	want := key.Marshal()

	for {
		packet, err := c.readPacket()
		if err != nil {
			return false, err
		}

		msgType := packet[0]
		if msgType == msgUserAuthBanner {
			if err := handleBannerResponse(c, packet); err != nil {
				return false, err
			}
			continue
		}
		if msgType == msgUserAuthFailure {
			return false, nil
		}
		if msgType != msgUserAuthPubKeyOk {
			return false, unexpectedMessageError(msgUserAuthPubKeyOk, msgType)
		}

		var reply userAuthPubKeyOkMsg
		if err := Unmarshal(packet, &reply); err != nil {
			return false, err
		}
		accepted := reply.Algo == algo && bytes.Equal(reply.PubKey, want)
		return accepted, nil
	}
}
// PublicKeys returns an AuthMethod that authenticates with the given key
// pairs. The same signers are offered every time the method is tried.
func PublicKeys(signers ...Signer) AuthMethod {
	fixed := func() ([]Signer, error) {
		return signers, nil
	}
	return publicKeyCallback(fixed)
}
// PublicKeysCallback returns an AuthMethod that calls getSigners to obtain
// the list of key pairs to authenticate with.
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
	method := publicKeyCallback(getSigners)
	return method
}
// handleAuthResponse reads the server's verdict on the preceding
// authentication request. It returns the result, the list of methods that may
// be tried next (only populated on failure or partial success), and an error
// if reading fails or an unexpected message arrives.
//
// Banner messages are handed to handleBannerResponse and a single msgExtInfo
// is tolerated; both cause another packet to be read.
func handleAuthResponse(c packetConn) (authResult, []string, error) {
	seenExtInfo := false
	for {
		packet, err := c.readPacket()
		if err != nil {
			return authFailure, nil, err
		}

		switch msgType := packet[0]; msgType {
		case msgUserAuthSuccess:
			return authSuccess, nil, nil

		case msgUserAuthFailure:
			var failure userAuthFailureMsg
			if err := Unmarshal(packet, &failure); err != nil {
				return authFailure, nil, err
			}
			if failure.PartialSuccess {
				return authPartialSuccess, failure.Methods, nil
			}
			return authFailure, failure.Methods, nil

		case msgUserAuthBanner:
			if err := handleBannerResponse(c, packet); err != nil {
				return authFailure, nil, err
			}

		case msgExtInfo:
			// Ignore post-authentication RFC 8308 extensions, once.
			if seenExtInfo {
				return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, msgType)
			}
			seenExtInfo = true

		default:
			return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, msgType)
		}
	}
}
// handleBannerResponse decodes a banner packet and, when c is a
// *handshakeTransport with a bannerCallback configured, passes the banner text
// to that callback and returns its result. For any other connection type, or
// when no callback is set, the banner is dropped and nil is returned.
func handleBannerResponse(c packetConn, packet []byte) error {
	var banner userAuthBannerMsg
	if err := Unmarshal(packet, &banner); err != nil {
		return err
	}

	if transport, ok := c.(*handshakeTransport); ok && transport.bannerCallback != nil {
		return transport.bannerCallback(banner.Message)
	}
	return nil
}
// KeyboardInteractiveChallenge is the callback used to answer the server's
// prompts. It should print the questions, disabling echo where the matching
// entry of echos is false (e.g. for passwords), and return one answer per
// question. It may be called several times in a single session. After
// successful authentication the server may send a challenge with no
// questions, in which case the name and instruction should still be printed.
// RFC 4256 section 3.3 details how the UI should behave for both CLI and GUI
// environments.
type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)

// KeyboardInteractive returns an AuthMethod that drives the server-controlled
// prompt/response sequence through challenge.
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
	var method AuthMethod = challenge
	return method
}

// method returns the SSH name of this authentication method.
func (cb KeyboardInteractiveChallenge) method() string {
	const name = "keyboard-interactive"
	return name
}
// auth runs the keyboard-interactive exchange: it sends the initial request,
// then answers every msgUserAuthInfoRequest by invoking cb with the decoded
// prompts and sending the answers back, until the server reports success,
// failure or partial success. Banners are forwarded to handleBannerResponse
// and a single msgExtInfo is tolerated.
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
	type initiateMsg struct {
		User       string `sshtype:"50"`
		Service    string
		Method     string
		Language   string
		Submethods string
	}

	request := Marshal(&initiateMsg{
		User:    user,
		Service: serviceSSH,
		Method:  "keyboard-interactive",
	})
	if err := c.writePacket(request); err != nil {
		return authFailure, nil, err
	}

	seenExtInfo := false
	for {
		packet, err := c.readPacket()
		if err != nil {
			return authFailure, nil, err
		}

		// like handleAuthResponse, but with less options.
		switch packet[0] {
		case msgUserAuthInfoRequest:
			// A challenge: handled below the switch.
		case msgUserAuthBanner:
			if err := handleBannerResponse(c, packet); err != nil {
				return authFailure, nil, err
			}
			continue
		case msgExtInfo:
			// Ignore post-authentication RFC 8308 extensions, once.
			if seenExtInfo {
				return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
			}
			seenExtInfo = true
			continue
		case msgUserAuthSuccess:
			return authSuccess, nil, nil
		case msgUserAuthFailure:
			var failure userAuthFailureMsg
			if err := Unmarshal(packet, &failure); err != nil {
				return authFailure, nil, err
			}
			if failure.PartialSuccess {
				return authPartialSuccess, failure.Methods, nil
			}
			return authFailure, failure.Methods, nil
		default:
			return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
		}

		var challenge userAuthInfoRequestMsg
		if err := Unmarshal(packet, &challenge); err != nil {
			return authFailure, nil, err
		}

		// Manually unpack the prompt/echo pairs: each pair is a string
		// followed by a single echo byte.
		var (
			prompts []string
			echos   []bool
		)
		remaining := challenge.Prompts
		for left := int(challenge.NumPrompts); left > 0; left-- {
			prompt, after, ok := parseString(remaining)
			if !ok || len(after) == 0 {
				return authFailure, nil, errors.New("ssh: prompt format error")
			}
			prompts = append(prompts, string(prompt))
			echos = append(echos, after[0] != 0)
			remaining = after[1:]
		}
		if len(remaining) != 0 {
			return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
		}

		answers, err := cb(challenge.Name, challenge.Instruction, prompts, echos)
		if err != nil {
			return authFailure, nil, err
		}
		if len(answers) != len(prompts) {
			return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts))
		}

		// Response layout: message type byte, uint32 answer count, then one
		// string per answer.
		size := 1 + 4
		for _, answer := range answers {
			size += stringLength(len(answer))
		}
		response := make([]byte, size)
		response[0] = msgUserAuthInfoResponse
		tail := marshalUint32(response[1:], uint32(len(answers)))
		for _, answer := range answers {
			tail = marshalString(tail, []byte(answer))
		}

		if err := c.writePacket(response); err != nil {
			return authFailure, nil, err
		}
	}
}
// retryableAuthMethod wraps an AuthMethod so that a plain failure is retried
// instead of being final.
type retryableAuthMethod struct {
	authMethod AuthMethod
	maxTries   int
}

// auth invokes the wrapped method repeatedly. It stops as soon as an attempt
// returns an error or anything other than authFailure, or once maxTries
// attempts have been made; a maxTries of zero or less means no limit.
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
	for attempt := 0; ; attempt++ {
		if r.maxTries > 0 && attempt >= r.maxTries {
			break
		}
		ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
		if err != nil || ok != authFailure {
			// either success, partial success or error terminate
			return ok, methods, err
		}
	}
	return ok, methods, err
}

// method reports the name of the wrapped authentication method.
func (r *retryableAuthMethod) method() string {
	wrapped := r.authMethod
	return wrapped.method()
}

// RetryableAuthMethod decorates auth so that it may be retried up to maxTries
// times before the AuthMethod itself is considered failed. If maxTries is
// <= 0, it is retried indefinitely.
//
// This is useful for interactive clients using challenge/response type
// authentication (e.g. Keyboard-Interactive, Password, etc) where the user
// could mistype their response, resulting in the server issuing a
// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4
// [keyboard-interactive]). Without this decorator the non-retryable
// AuthMethod would be removed from future consideration and never tried
// again, so the user could not retry their entry.
func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
	wrapper := &retryableAuthMethod{
		authMethod: auth,
		maxTries:   maxTries,
	}
	return wrapper
}
// GSSAPIWithMICAuthMethod returns an AuthMethod implementing
// "gssapi-with-mic" authentication, see RFC 4462 section 3.
//
// gssAPIClient is the GSSAPIClient implementation to use (see that interface
// for details) and target is the server host you want to log in to.
// It panics if gssAPIClient is nil.
func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod {
	if gssAPIClient != nil {
		return &gssAPIWithMICCallback{
			gssAPIClient: gssAPIClient,
			target:       target,
		}
	}
	panic("gss-api client must be not nil with enable gssapi-with-mic")
}

// gssAPIWithMICCallback carries the GSS-API client and the target host used
// by the "gssapi-with-mic" authentication method.
type gssAPIWithMICCallback struct {
	gssAPIClient GSSAPIClient
	target       string
}
// auth performs "gssapi-with-mic" user authentication (RFC 4462 section 3).
//
// It offers the Kerberos V5 mechanism, exchanges GSS-API tokens with the
// server until the client reports that no further round is needed, sends the
// MIC that binds the exchange to the session, and finally returns the
// server's verdict as read by handleAuthResponse.
//
// A SSH_MSG_USERAUTH_FAILURE received at any point, including as the reply to
// the initial request, is reported as an ordinary authentication failure
// (with the methods that can continue) rather than as an error, so the caller
// can fall back to other authentication methods.
func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
	m := &userAuthRequestMsg{
		User:    user,
		Service: serviceSSH,
		Method:  g.method(),
	}
	// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST.
	// See RFC 4462 section 3.2.
	// The payload is the number of offered mechanisms (one) followed by its OID.
	m.Payload = appendU32(m.Payload, 1)
	m.Payload = appendString(m.Payload, string(krb5OID))
	if err := c.writePacket(Marshal(m)); err != nil {
		return authFailure, nil, err
	}
	// The server responds to the SSH_MSG_USERAUTH_REQUEST with either an
	// SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or
	// with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE.
	// See RFC 4462 section 3.3.
	packet, err := c.readPacket()
	if err != nil {
		return authFailure, nil, err
	}
	// A failure here means the server rejected the offered mechanism. Report
	// it as a failed attempt instead of letting Unmarshal below turn it into
	// an "unexpected message" error.
	if len(packet) > 0 && packet[0] == msgUserAuthFailure {
		var msg userAuthFailureMsg
		if err := Unmarshal(packet, &msg); err != nil {
			return authFailure, nil, err
		}
		if msg.PartialSuccess {
			return authPartialSuccess, msg.Methods, nil
		}
		return authFailure, msg.Methods, nil
	}
	// OpenSSH supports only the Kerberos V5 mechanism for GSS-API
	// authentication, so the mechanism selected by the server is not checked.
	userAuthGSSAPIResp := &userAuthGSSAPIResponse{}
	if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil {
		return authFailure, nil, err
	}
	// Start the loop into the exchange token.
	// See RFC 4462 section 3.4.
	var token []byte
	defer g.gssAPIClient.DeleteSecContext()
	for {
		// Initiates the establishment of a security context between the application and a remote peer.
		nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false)
		if err != nil {
			return authFailure, nil, err
		}
		if len(nextToken) > 0 {
			if err := c.writePacket(Marshal(&userAuthGSSAPIToken{
				Token: nextToken,
			})); err != nil {
				return authFailure, nil, err
			}
		}
		if !needContinue {
			break
		}
		packet, err = c.readPacket()
		if err != nil {
			return authFailure, nil, err
		}
		// NOTE(review): message types not listed below fall through and the
		// loop runs again with the previous token - confirm this is intended.
		switch packet[0] {
		case msgUserAuthFailure:
			var msg userAuthFailureMsg
			if err := Unmarshal(packet, &msg); err != nil {
				return authFailure, nil, err
			}
			if msg.PartialSuccess {
				return authPartialSuccess, msg.Methods, nil
			}
			return authFailure, msg.Methods, nil
		case msgUserAuthGSSAPIError:
			userAuthGSSAPIErrorResp := &userAuthGSSAPIError{}
			if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil {
				return authFailure, nil, err
			}
			return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+
				"Major Status: %d\n"+
				"Minor Status: %d\n"+
				"Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus,
				userAuthGSSAPIErrorResp.Message)
		case msgUserAuthGSSAPIToken:
			userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
			if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
				return authFailure, nil, err
			}
			token = userAuthGSSAPITokenReq.Token
		}
	}
	// Binding Encryption Keys.
	// See RFC 4462 section 3.5.
	micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic")
	micToken, err := g.gssAPIClient.GetMIC(micField)
	if err != nil {
		return authFailure, nil, err
	}
	if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{
		MIC: micToken,
	})); err != nil {
		return authFailure, nil, err
	}
	return handleAuthResponse(c)
}
// method returns the SSH name of this authentication method.
func (g *gssAPIWithMICCallback) method() string {
	const name = "gssapi-with-mic"
	return name
}
================================================
FILE: gossh/crypto/ssh/common.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto"
"crypto/rand"
"fmt"
"io"
"math"
"sync"
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
// These are string constants in the SSH protocol.
const (
// compressionNone is the only entry of supportedCompressions.
compressionNone = "none"
// serviceUserAuth is the "ssh-userauth" service name.
serviceUserAuth = "ssh-userauth"
// serviceSSH is the "ssh-connection" service name; user authentication
// requests built in this package name it as the service to start.
serviceSSH = "ssh-connection"
)
// supportedCiphers lists ciphers we support but might not recommend.
// It contains every entry of preferredCiphers plus the arcfour variants,
// aes128cbcID and tripledescbcID.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
tripledescbcID,
}
// preferredCiphers specifies the default preference for ciphers.
// Config.SetDefaults installs this list when Config.Ciphers is nil.
var preferredCiphers = []string{
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
// preference order. Compared with preferredKexAlgos it additionally lists
// kexAlgoDH16SHA512 and kexAlgoDH1SHA1.
var supportedKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
kexAlgoDH1SHA1,
}
// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
// for the server half. It is used as a set: only the keys matter.
var serverForbiddenKexAlgos = map[string]struct{}{
kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests
kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
}
// preferredKexAlgos specifies the default preference for key-exchange
// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
// is disabled by default because it is a bit slower than the others.
// Config.SetDefaults installs this list when Config.KeyExchanges is nil.
var preferredKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH14SHA1,
}
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order. Certificate algorithms are
// listed ahead of the plain key algorithms.
var supportedHostKeyAlgos = []string{
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
KeyAlgoRSA, KeyAlgoDSA,
KeyAlgoED25519,
}
// supportedMACs specifies a default set of MAC algorithms in preference order.
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
// Config.SetDefaults installs this list when Config.MACs is nil.
var supportedMACs = []string{
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
}
// supportedCompressions lists the supported compression algorithms; only
// compressionNone is present.
var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported signature algorithms to their
// respective hashes needed for signing and verification. Algorithms without an
// entry (such as KeyAlgoED25519) yield the zero crypto.Hash on lookup.
var hashFuncs = map[string]crypto.Hash{
KeyAlgoRSA: crypto.SHA1,
KeyAlgoRSASHA256: crypto.SHA256,
KeyAlgoRSASHA512: crypto.SHA512,
KeyAlgoDSA: crypto.SHA1,
KeyAlgoECDSA256: crypto.SHA256,
KeyAlgoECDSA384: crypto.SHA384,
KeyAlgoECDSA521: crypto.SHA512,
// KeyAlgoED25519 doesn't pre-hash.
KeyAlgoSKECDSA256: crypto.SHA256,
KeyAlgoSKED25519: crypto.SHA256,
}
// algorithmsForKeyFormat returns the signature algorithms usable with a
// public key of the given format (PublicKey.Type), most preferred first. RSA
// keys and RSA certificates map to their SHA-2 variants followed by the
// legacy one; every other format maps only to itself. See RFC 8332,
// Section 2. See also the note in sendKexInit on backwards compatibility.
func algorithmsForKeyFormat(keyFormat string) []string {
	if keyFormat == KeyAlgoRSA {
		return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}
	}
	if keyFormat == CertAlgoRSAv01 {
		return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01}
	}
	return []string{keyFormat}
}
// isRSA returns whether algo is a supported RSA algorithm, including certificate
// algorithms.
func isRSA(algo string) bool {
algos := algorithmsForKeyFormat(KeyAlgoRSA)
return contains(algos, underlyingAlgo(algo))
}
// isRSACert reports whether algo is a certificate algorithm (a key of
// certKeyAlgoNames) that is also an RSA algorithm according to isRSA.
func isRSACert(algo string) bool {
	if _, isCert := certKeyAlgoNames[algo]; isCert {
		return isRSA(algo)
	}
	return false
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if
// it supports the server-sig-algs extension. Order is irrelevant.
var supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
KeyAlgoDSA,
}
// unexpectedMessageError builds the error reported when a received SSH
// message has type got although type expected was wanted.
func unexpectedMessageError(expected, got uint8) error {
	const layout = "ssh: unexpected message type %d (expected %d)"
	return fmt.Errorf(layout, got, expected)
}
// parseError builds the error reported for a malformed SSH message whose
// message type is tag.
func parseError(tag uint8) error {
	const layout = "ssh: parse error in message type %d"
	return fmt.Errorf(layout, tag)
}
// findCommon returns the first entry of client that also appears in server,
// so the client's order decides the preference. If the two lists share no
// entry it returns an error naming what was being negotiated and both lists.
func findCommon(what string, client []string, server []string) (common string, err error) {
	offered := make(map[string]struct{}, len(server))
	for _, name := range server {
		offered[name] = struct{}{}
	}

	for _, candidate := range client {
		if _, ok := offered[candidate]; ok {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}
// directionAlgorithms records the algorithms chosen for one direction of the
// connection (either read or write).
type directionAlgorithms struct {
	Cipher      string
	MAC         string
	Compression string
}

// rekeyBytes returns the rekeying interval in bytes for the chosen cipher.
func (a *directionAlgorithms) rekeyBytes() int64 {
	const (
		// According to RFC 4344 block ciphers should rekey after
		// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
		// 128, i.e. 2^32 blocks of 16 bytes each.
		aesRekeyBytes = 16 * (1 << 32)
		// For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data.
		defaultRekeyBytes = 1 << 30
	)

	switch a.Cipher {
	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
		return aesRekeyBytes
	default:
		return defaultRekeyBytes
	}
}
// aeadCiphers is the set of cipher IDs for which findAgreedAlgorithms skips
// MAC negotiation.
var aeadCiphers = map[string]bool{
gcm128CipherID: true,
gcm256CipherID: true,
chacha20Poly1305ID: true,
}
// algorithms holds the outcome of algorithm negotiation as filled in by
// findAgreedAlgorithms.
type algorithms struct {
// kex is the agreed key exchange algorithm.
kex string
// hostKey is the agreed host key algorithm.
hostKey string
// w holds the choices for the write direction: client-to-server when we are
// the client, server-to-client otherwise.
w directionAlgorithms
// r holds the choices for the read direction, the opposite of w.
r directionAlgorithms
}
// findAgreedAlgorithms negotiates every algorithm class from the two kexInit
// messages, picking the client's first choice that the server also offers.
// isClient decides which direction is stored as write (w) and which as read
// (r). A MAC is only negotiated for a direction whose cipher is not listed in
// aeadCiphers. The first class with no common algorithm aborts the
// negotiation and yields a nil result together with the error.
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
	agreed := &algorithms{}

	if agreed.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos); err != nil {
		return nil, err
	}
	if agreed.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos); err != nil {
		return nil, err
	}

	// From the server's point of view client-to-server traffic is read and
	// server-to-client traffic is written; for the client it is the reverse.
	ctos, stoc := &agreed.r, &agreed.w
	if isClient {
		ctos, stoc = &agreed.w, &agreed.r
	}

	if ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer); err != nil {
		return nil, err
	}
	if stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient); err != nil {
		return nil, err
	}

	if !aeadCiphers[ctos.Cipher] {
		if ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer); err != nil {
			return nil, err
		}
	}
	if !aeadCiphers[stoc.Cipher] {
		if stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient); err != nil {
			return nil, err
		}
	}

	if ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer); err != nil {
		return nil, err
	}
	if stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient); err != nil {
		return nil, err
	}

	return agreed, nil
}
// minRekeyThreshold is the smallest rekey threshold accepted: if the
// threshold is too small, we can't make any progress sending stuff.
// Config.SetDefaults raises smaller non-zero values to this minimum.
const minRekeyThreshold uint64 = 256
// Config contains configuration data common to both ServerConfig and
// ClientConfig.
type Config struct {
// Rand provides the source of entropy for cryptographic
// primitives. If Rand is nil, the cryptographic random reader
// in package crypto/rand will be used.
Rand io.Reader
// The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If
// unspecified, a size suitable for the chosen cipher is used.
// SetDefaults raises smaller non-zero values to 256 and caps the value at
// math.MaxInt64.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a default set
// of algorithms is used. Unsupported values are silently ignored.
KeyExchanges []string
// The allowed cipher algorithms. If unspecified then a sensible default is
// used. Unsupported values are silently ignored.
Ciphers []string
// The allowed MAC algorithms. If unspecified then a sensible default is
// used. Unsupported values are silently ignored.
MACs []string
}
// SetDefaults fills in sensible values for unset fields of the config and
// drops algorithm names this package has no implementation for. It is
// exported for testing: Configs passed to SSH functions are copied and have
// default values set automatically.
func (c *Config) SetDefaults() {
	if c.Rand == nil {
		c.Rand = rand.Reader
	}

	if c.Ciphers == nil {
		c.Ciphers = preferredCiphers
	}
	var knownCiphers []string
	for _, name := range c.Ciphers {
		if cipherModes[name] == nil {
			// Ignore the cipher if we have no cipherModes definition.
			continue
		}
		knownCiphers = append(knownCiphers, name)
	}
	c.Ciphers = knownCiphers

	if c.KeyExchanges == nil {
		c.KeyExchanges = preferredKexAlgos
	}
	var knownKexs []string
	for _, name := range c.KeyExchanges {
		if kexAlgoMap[name] == nil {
			// Ignore the KEX if we have no kexAlgoMap definition.
			continue
		}
		knownKexs = append(knownKexs, name)
	}
	c.KeyExchanges = knownKexs

	if c.MACs == nil {
		c.MACs = supportedMACs
	}
	var knownMACs []string
	for _, name := range c.MACs {
		if macModes[name] == nil {
			// Ignore the MAC if we have no macModes definition.
			continue
		}
		knownMACs = append(knownMACs, name)
	}
	c.MACs = knownMACs

	switch {
	case c.RekeyThreshold == 0:
		// cipher specific default
	case c.RekeyThreshold < minRekeyThreshold:
		c.RekeyThreshold = minRekeyThreshold
	case c.RekeyThreshold >= math.MaxInt64:
		// Avoid weirdness if somebody uses -1 as a threshold.
		c.RekeyThreshold = math.MaxInt64
	}
}
// buildDataSignedForAuth returns the blob that is signed to prove possession
// of a private key, see RFC 4252, section 7. It marshals the session ID, the
// user-auth request fields, a "signature present" flag, the advertised
// algorithm (which may be a certificate type) and the public key, in that
// order.
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte {
	return Marshal(struct {
		Session []byte
		Type    byte
		User    string
		Service string
		Method  string
		Sign    bool
		Algo    string
		PubKey  []byte
	}{
		Session: sessionID,
		Type:    msgUserAuthRequest,
		User:    req.User,
		Service: req.Service,
		Method:  req.Method,
		Sign:    true,
		Algo:    algo,
		PubKey:  pubKey,
	})
}
// appendU16 appends n to buf in big-endian byte order.
func appendU16(buf []byte, n uint16) []byte {
	encoded := [2]byte{byte(n >> 8), byte(n)}
	return append(buf, encoded[:]...)
}

// appendU32 appends n to buf in big-endian byte order.
func appendU32(buf []byte, n uint32) []byte {
	encoded := [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
	return append(buf, encoded[:]...)
}

// appendU64 appends n to buf in big-endian byte order.
func appendU64(buf []byte, n uint64) []byte {
	var encoded [8]byte
	for i := range encoded {
		shift := uint(8 * (len(encoded) - 1 - i))
		encoded[i] = byte(n >> shift)
	}
	return append(buf, encoded[:]...)
}

// appendInt appends n to buf as a big-endian uint32; n is truncated to 32
// bits by the conversion.
func appendInt(buf []byte, n int) []byte {
	return appendU32(buf, uint32(n))
}

// appendString appends s to buf as a uint32 length prefix followed by the
// bytes of s.
func appendString(buf []byte, s string) []byte {
	return append(appendU32(buf, uint32(len(s))), s...)
}

// appendBool appends b to buf as a single byte: 1 for true, 0 for false.
func appendBool(buf []byte, b bool) []byte {
	var flag byte
	if b {
		flag = 1
	}
	return append(buf, flag)
}
// newCond returns a sync.Cond that owns a fresh mutex. It hides the fact that
// there is no usable zero value for sync.Cond.
func newCond() *sync.Cond {
	return sync.NewCond(&sync.Mutex{})
}

// window tracks the buffer space available to clients wishing to write to a
// channel. The embedded Cond's lock guards all other fields.
type window struct {
	*sync.Cond
	win          uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
	writeWaiters int
	closed       bool
}

// add makes win more bytes of window available to consumers and wakes all
// waiters. It returns false, leaving the window unchanged, if the addition
// would overflow a uint32.
func (w *window) add(win uint32) bool {
	// a zero sized window adjust is a noop.
	if win == 0 {
		return true
	}

	w.L.Lock()
	defer w.L.Unlock()

	grown := w.win + win
	if grown < win {
		// The sum wrapped around.
		return false
	}
	w.win = grown
	// It is unusual that multiple goroutines would be attempting to reserve
	// window space, but not guaranteed. Use broadcast to notify all waiters
	// that additional window is available.
	w.Broadcast()
	return true
}

// close marks the window as closed, so all reservations fail immediately,
// and wakes every waiter.
func (w *window) close() {
	w.L.Lock()
	defer w.L.Unlock()

	w.closed = true
	w.Broadcast()
}

// reserve takes up to win bytes from the available window and returns the
// amount actually taken, which may be less than requested. It blocks while
// the window is empty and not closed. Once the window is closed it returns
// io.EOF along with whatever could still be taken.
func (w *window) reserve(win uint32) (uint32, error) {
	w.L.Lock()
	defer w.L.Unlock()

	// Announce the waiter so waitWriterBlocked can observe it.
	w.writeWaiters++
	w.Broadcast()
	for !w.closed && w.win == 0 {
		w.Wait()
	}
	w.writeWaiters--

	granted := win
	if granted > w.win {
		granted = w.win
	}
	w.win -= granted

	if w.closed {
		return granted, io.EOF
	}
	return granted, nil
}

// waitWriterBlocked waits until some goroutine is blocked for further
// writes. It is used in tests only.
func (w *window) waitWriterBlocked() {
	w.Cond.L.Lock()
	defer w.Cond.L.Unlock()

	for w.writeWaiters == 0 {
		w.Cond.Wait()
	}
}
================================================
FILE: gossh/crypto/ssh/connection.go
================================================
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"fmt"
"net"
)
// OpenChannelError is the error returned when the other side rejects an
// OpenChannel request. It carries the rejection reason and the accompanying
// message.
type OpenChannelError struct {
	Reason  RejectionReason
	Message string
}

// Error formats the rejection reason and message as a single line.
func (e *OpenChannelError) Error() string {
	const layout = "ssh: rejected: %s (%s)"
	return fmt.Sprintf(layout, e.Reason, e.Message)
}
// ConnMetadata holds metadata for the connection. It exposes read-only
// accessors for the user, session identifier, version strings and network
// addresses.
type ConnMetadata interface {
// User returns the user ID for this connection.
User() string
// SessionID returns the session hash, also denoted by H.
SessionID() []byte
// ClientVersion returns the client's version string as hashed
// into the session ID.
ClientVersion() []byte
// ServerVersion returns the server's version string as hashed
// into the session ID.
ServerVersion() []byte
// RemoteAddr returns the remote address for this connection.
RemoteAddr() net.Addr
// LocalAddr returns the local address for this connection.
LocalAddr() net.Addr
}
// Conn represents an SSH connection for both server and client roles.
// Conn is the basis for implementing an application layer, such
// as ClientConn, which implements the traditional shell access for
// clients.
type Conn interface {
ConnMetadata
// SendRequest sends a global request, and returns the
// reply. If wantReply is true, it returns the response status
// and payload. See also RFC 4254, section 4.
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
// OpenChannel tries to open a channel. If the request is
// rejected, it returns *OpenChannelError. On success it returns
// the SSH Channel and a Go channel for incoming, out-of-band
// requests. The Go channel must be serviced, or the
// connection will hang.
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
// Close closes the underlying network connection.
Close() error
// Wait blocks until the connection has shut down, and returns the
// error causing the shutdown.
Wait() error
// TODO(hanwen): consider exposing:
// RequestKeyChange
// Disconnect
}
// DiscardRequests drains the passed-in channel until it is closed, rejecting
// every request that asks for a reply and silently dropping the others.
func DiscardRequests(in <-chan *Request) {
	for req := range in {
		if !req.WantReply {
			continue
		}
		req.Reply(false, nil)
	}
}
// A connection represents an incoming connection. It combines the handshake
// transport, the connection metadata and the channel multiplexer.
type connection struct {
	transport *handshakeTransport
	sshConn

	// The connection protocol.
	*mux
}

// Close closes the underlying network connection.
func (c *connection) Close() error {
	underlying := c.sshConn.conn
	return underlying.Close()
}
// sshConn exposes metadata about a net.Conn and the SSH session running on
// it, but disallows direct reads and writes.
type sshConn struct {
	conn net.Conn

	user          string
	sessionID     []byte
	clientVersion []byte
	serverVersion []byte
}

// dup returns a freshly allocated, non-nil copy of src.
func dup(src []byte) []byte {
	return append(make([]byte, 0, len(src)), src...)
}

// User returns the user name recorded for this connection.
func (c *sshConn) User() string {
	return c.user
}

// RemoteAddr returns the remote address of the underlying connection.
func (c *sshConn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}

// Close closes the underlying network connection.
func (c *sshConn) Close() error {
	return c.conn.Close()
}

// LocalAddr returns the local address of the underlying connection.
func (c *sshConn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}

// SessionID returns a copy of the session ID, so callers cannot modify the
// stored value.
func (c *sshConn) SessionID() []byte {
	return dup(c.sessionID)
}

// ClientVersion returns a copy of the client's version string.
func (c *sshConn) ClientVersion() []byte {
	return dup(c.clientVersion)
}

// ServerVersion returns a copy of the server's version string.
func (c *sshConn) ServerVersion() []byte {
	return dup(c.serverVersion)
}
================================================
FILE: gossh/crypto/ssh/doc.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package ssh implements an SSH client and server.
SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level
protocol is a remote shell and this is specifically implemented. However,
the multiplexed nature of SSH is exposed to users that wish to support
others.
References:
[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise.
*/
package ssh // import "golang.org/x/crypto/ssh"
================================================
FILE: gossh/crypto/ssh/handshake.go
================================================
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
)
// debugHandshake, if set, prints messages sent and received. Key
// exchange messages are printed as if DH were used, so the debug
// messages are wrong when using ECDH. It is a compile-time constant.
const debugHandshake = false
// chanSize sets the amount of buffering SSH connections. This is
// primarily for testing: setting chanSize=0 uncovers deadlocks more
// quickly. newHandshakeTransport uses it as the capacity of the incoming
// packet channel.
const chanSize = 16
// keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions.
type keyingTransport interface {
// packetConn supplies the packet read and write operations.
packetConn
// prepareKeyChange sets up a key change. The key change for a
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
// setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys.
// If the sequence number is already > 1 when setStrictMode
// is called, an error is returned.
setStrictMode() error
// setInitialKEXDone indicates to the transport that the initial key exchange
// was completed
setInitialKEXDone()
}
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
type handshakeTransport struct {
// conn is the underlying transport that performs the key changes.
conn keyingTransport
config *Config
// serverVersion and clientVersion are the version strings passed to
// newHandshakeTransport.
serverVersion []byte
clientVersion []byte
// hostKeys is non-empty if we are the server. In that case,
// it contains all host keys that can be used to sign the
// connection.
hostKeys []Signer
// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
// it contains the supported client public key authentication algorithms.
publicKeyAuthAlgorithms []string
// hostKeyAlgorithms is non-empty if we are the client. In that case,
// we accept these key types from the server as host key.
hostKeyAlgorithms []string
// On read error, incoming is closed, and readError is set.
// incoming is created with a buffer of chanSize packets.
incoming chan []byte
readError error
// NOTE(review): mu presumably guards the write-side fields that follow -
// confirm against the methods that use them.
mu sync.Mutex
writeError error
sentInitPacket []byte
sentInitMsg *kexInitMsg
pendingPackets [][]byte // Used when a key exchange is in progress.
writePacketsLeft uint32
writeBytesLeft int64
// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
// message. The channel has capacity 1 and is primed by
// newHandshakeTransport so that the first key exchange always happens.
requestKex chan struct{}
// If the other side requests or confirms a kex, its kexInit
// packet is sent here for the write loop to find it.
// The channel is unbuffered.
startKex chan *pendingKex
kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
// data for host key checking
hostKeyCallback HostKeyCallback
dialAddress string
remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange.
algorithms *algorithms
// Counters exclusively owned by readLoop.
readPacketsLeft uint32
readBytesLeft int64
// The session ID or nil if first kex did not complete yet.
sessionID []byte
// strictMode indicates if the other side of the handshake indicated
// that we should be following the strict KEX protocol restrictions.
strictMode bool
}
// pendingKex is the hand-off from the read loop to kexLoop when the peer's
// kexInit packet arrives.
type pendingKex struct {
	// otherInit is the raw kexInit packet received from the peer.
	otherInit []byte
	// done receives the outcome of the key exchange; the reader blocks
	// on it until kexLoop reports a result.
	done chan error
}
// newHandshakeTransport builds the state shared by client and server
// transports. It does not start the read or kex goroutines; the role
// specific constructors do that.
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
	// Every connection begins with a mandatory key exchange, so the
	// single buffered slot is filled right away.
	kexRequests := make(chan struct{}, 1)
	kexRequests <- struct{}{}

	transport := &handshakeTransport{
		conn:          conn,
		config:        config,
		clientVersion: clientVersion,
		serverVersion: serverVersion,
		incoming:      make(chan []byte, chanSize),
		requestKex:    kexRequests,
		startKex:      make(chan *pendingKex),
		kexLoopDone:   make(chan struct{}),
	}
	transport.resetReadThresholds()
	transport.resetWriteThresholds()
	return transport
}
// newClientTransport creates a client-side handshakeTransport and starts
// its read and key exchange goroutines.
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
	transport := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
	transport.dialAddress = dialAddr
	transport.remoteAddr = addr
	transport.hostKeyCallback = config.HostKeyCallback
	transport.bannerCallback = config.BannerCallback

	// Fall back to the package defaults unless the caller supplied a list.
	transport.hostKeyAlgorithms = supportedHostKeyAlgos
	if config.HostKeyAlgorithms != nil {
		transport.hostKeyAlgorithms = config.HostKeyAlgorithms
	}

	go transport.readLoop()
	go transport.kexLoop()
	return transport
}
// newServerTransport creates a server-side handshakeTransport, copying the
// host keys and accepted public key algorithms from config, and starts its
// read and key exchange goroutines.
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
	transport := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
	transport.hostKeys = config.hostKeys
	transport.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms

	go transport.readLoop()
	go transport.kexLoop()
	return transport
}
// getSessionID returns the session ID, which is nil until the first key
// exchange has completed. The field is read without taking t.mu.
func (t *handshakeTransport) getSessionID() []byte {
	return t.sessionID
}
// waitSession blocks until the initial key exchange has finished, which the
// read loop signals by delivering a msgNewKeys packet. Call it first after
// creating a handshakeTransport.
func (t *handshakeTransport) waitSession() error {
	first, err := t.readPacket()
	switch {
	case err != nil:
		return err
	case first[0] != msgNewKeys:
		return fmt.Errorf("ssh: first packet should be msgNewKeys")
	}
	return nil
}
// id names our role for debug output: "server" when host keys are
// configured, "client" otherwise.
func (t *handshakeTransport) id() string {
	role := "client"
	if len(t.hostKeys) > 0 {
		role = "server"
	}
	return role
}
// printPacket logs a packet for handshake debugging. Channel data packets
// are summarised by size only; everything else is decoded and printed.
// write selects the "sent" label instead of "got".
func (t *handshakeTransport) printPacket(p []byte, write bool) {
	direction := "got"
	if write {
		direction = "sent"
	}

	switch p[0] {
	case msgChannelData, msgChannelExtendedData:
		log.Printf("%s %s data (packet %d bytes)", t.id(), direction, len(p))
	default:
		msg, err := decode(p)
		log.Printf("%s %s %T %v (%v)", t.id(), direction, msg, msg, err)
	}
}
// readPacket returns the next packet queued by the read loop. Once the
// incoming channel has been closed it returns the recorded read error.
func (t *handshakeTransport) readPacket() ([]byte, error) {
	if packet, open := <-t.incoming; open {
		return packet, nil
	}
	return nil, t.readError
}
// readLoop pulls packets off the connection and forwards them to
// t.incoming until a read fails. On exit it records the error for writers
// and closes t.startKex so kexLoop can shut down.
func (t *handshakeTransport) readLoop() {
	for first := true; ; first = false {
		packet, err := t.readOnePacket(first)
		if err != nil {
			t.readError = err
			close(t.incoming)
			break
		}

		// During the first kex with strict KEX mode enabled nothing is
		// dropped, because ignorable messages could be used to
		// manipulate the packet sequence numbers.
		strictFirstKex := t.sessionID == nil && t.strictMode
		if !strictFirstKex && (packet[0] == msgIgnore || packet[0] == msgDebug) {
			continue
		}
		t.incoming <- packet
	}

	// Stop writers too.
	t.recordWriteError(t.readError)

	// Unblock the writer should it wait for this.
	close(t.startKex)

	// t.requestKex stays open: writePacket also sends on it.
}
// pushPacket writes p straight to the underlying connection, logging it
// first when debugHandshake is set. It does not take t.mu and does not
// touch the write counters.
func (t *handshakeTransport) pushPacket(p []byte) error {
	if debugHandshake {
		t.printPacket(p, true)
	}
	return t.conn.writePacket(p)
}
// getWriteError returns the current write error, reading it under t.mu.
func (t *handshakeTransport) getWriteError() error {
	t.mu.Lock()
	current := t.writeError
	t.mu.Unlock()
	return current
}
// recordWriteError stores err as the write error unless err is nil or an
// earlier error has already been recorded.
func (t *handshakeTransport) recordWriteError(err error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	if err == nil || t.writeError != nil {
		return
	}
	t.writeError = err
}
// requestKeyExchange asks kexLoop to start a key exchange. The send is
// non-blocking: if the channel cannot accept the request immediately,
// the call is a no-op.
func (t *handshakeTransport) requestKeyExchange() {
	select {
	case t.requestKex <- struct{}{}:
	default:
		// something already requested a kex, so do nothing.
	}
}
// resetWriteThresholds refills the write-side packet and byte budgets.
// The byte budget comes from the configured RekeyThreshold when set, then
// from the negotiated write cipher, and otherwise defaults to 1 GiB.
func (t *handshakeTransport) resetWriteThresholds() {
	t.writePacketsLeft = packetRekeyThreshold

	switch {
	case t.config.RekeyThreshold > 0:
		t.writeBytesLeft = int64(t.config.RekeyThreshold)
	case t.algorithms != nil:
		t.writeBytesLeft = t.algorithms.w.rekeyBytes()
	default:
		t.writeBytesLeft = 1 << 30
	}
}
// kexLoop runs key exchanges for the lifetime of the transport. Each
// iteration waits until our kexInit has been sent and the peer's kexInit
// has arrived via t.startKex, performs the exchange, reports the outcome to
// the blocked reader and flushes packets queued by writePacket meanwhile.
// The loop ends on the first write error or when t.startKex is closed; it
// then closes the connection, drains t.startKex and closes t.kexLoopDone.
func (t *handshakeTransport) kexLoop() {

write:
	for t.getWriteError() == nil {
		var request *pendingKex
		var sent bool

		// Wait until both halves are present: the peer's kexInit
		// (request) and our own kexInit having been sent.
		for request == nil || !sent {
			var ok bool
			select {
			case request, ok = <-t.startKex:
				if !ok {
					// startKex was closed by readLoop.
					break write
				}
			case <-t.requestKex:
				// This break only leaves the select.
				break
			}

			if !sent {
				if err := t.sendKexInit(); err != nil {
					t.recordWriteError(err)
					break
				}
				sent = true
			}
		}

		if err := t.getWriteError(); err != nil {
			if request != nil {
				request.done <- err
			}
			break
		}

		// We're not servicing t.requestKex, but that is OK:
		// we never block on sending to t.requestKex.

		// We're not servicing t.startKex, but the remote end
		// has just sent us a kexInitMsg, so it can't send
		// another key change request, until we close the done
		// channel on the pendingKex request.

		err := t.enterKeyExchange(request.otherInit)

		t.mu.Lock()
		t.writeError = err
		t.sentInitPacket = nil
		t.sentInitMsg = nil

		t.resetWriteThresholds()

		// we have completed the key exchange. Since the
		// reader is still blocked, it is safe to clear out
		// the requestKex channel. This avoids the situation
		// where: 1) we consumed our own request for the
		// initial kex, and 2) the kex from the remote side
		// caused another send on the requestKex channel,
		// leaving a stale request behind that would make the
		// next iteration send another kexInit.
	clear:
		for {
			select {
			case <-t.requestKex:
				//
			default:
				break clear
			}
		}

		// Wake the reader blocked in readOnePacket.
		request.done <- t.writeError

		// kex finished. Push packets that we received while
		// the kex was in progress. Don't look at t.startKex
		// and don't touch the write counters (pushPacket is
		// called directly instead of writePacket): if we trigger
		// another kex while we are still busy with the last
		// one, things will become very confusing.
		for _, p := range t.pendingPackets {
			t.writeError = t.pushPacket(p)
			if t.writeError != nil {
				break
			}
		}
		t.pendingPackets = t.pendingPackets[:0]
		t.mu.Unlock()
	}

	// Unblock reader.
	t.conn.Close()

	// drain startKex channel. We don't service t.requestKex
	// because nobody does blocking sends there.
	for request := range t.startKex {
		request.done <- t.getWriteError()
	}

	// Mark that the loop is done so that Close can return.
	close(t.kexLoopDone)
}
// packetRekeyThreshold is the number of packets after which a key exchange
// is requested. Packet counters in the protocol are uint32 and must never
// reach 1<<32. More packets than this may still be read and written: the
// peer can keep sending, and once the write limit is hit a few further
// packets are needed for the key exchange itself.
const packetRekeyThreshold = 1 << 31
// resetReadThresholds refills the read-side packet and byte budgets.
// The byte budget comes from the configured RekeyThreshold when set, then
// from the negotiated read cipher, and otherwise defaults to 1 GiB.
func (t *handshakeTransport) resetReadThresholds() {
	t.readPacketsLeft = packetRekeyThreshold

	switch {
	case t.config.RekeyThreshold > 0:
		t.readBytesLeft = int64(t.config.RekeyThreshold)
	case t.algorithms != nil:
		t.readBytesLeft = t.algorithms.r.rekeyBytes()
	default:
		t.readBytesLeft = 1 << 30
	}
}
// readOnePacket reads a single packet from the connection and charges it
// against the read budgets, requesting a rekey when one is exhausted.
// Ordinary packets are returned unchanged. A kexInit packet is handed to
// kexLoop and the call blocks until that exchange finishes; the caller then
// receives msgNewKeys for the first exchange and msgIgnore for later ones.
// When first is true the packet must be a kexInit.
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
	packet, err := t.conn.readPacket()
	if err != nil {
		return nil, err
	}

	if t.readPacketsLeft == 0 {
		t.requestKeyExchange()
	} else {
		t.readPacketsLeft--
	}

	if t.readBytesLeft <= 0 {
		t.requestKeyExchange()
	} else {
		t.readBytesLeft -= int64(len(packet))
	}

	if debugHandshake {
		t.printPacket(packet, false)
	}

	if packet[0] != msgKexInit {
		if first {
			return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
		}
		return packet, nil
	}

	initialKex := t.sessionID == nil

	pending := &pendingKex{
		otherInit: packet,
		done:      make(chan error, 1),
	}
	t.startKex <- pending
	err = <-pending.done

	if debugHandshake {
		log.Printf("%s exited key exchange (first %v), err %v", t.id(), initialKex, err)
	}
	if err != nil {
		return nil, err
	}

	t.resetReadThresholds()

	if initialKex {
		// The first kex surfaces as msgNewKeys so that waitSession can
		// tell authentication will run over an encrypted transport.
		return []byte{msgNewKeys}, nil
	}
	// Later key exchanges are hidden from higher layers by turning
	// them into msgIgnore.
	return []byte{msgIgnore}, nil
}
const (
	// kexStrictClient is the pseudo-algorithm a client appends to its
	// KexAlgos during the first key exchange to opt into strict KEX mode.
	kexStrictClient = "kex-strict-c-v00@openssh.com"
	// kexStrictServer is the pseudo-algorithm a server appends to its
	// KexAlgos during the first key exchange to opt into strict KEX mode.
	kexStrictServer = "kex-strict-s-v00@openssh.com"
)
// sendKexInit builds and sends our kexInit message, unless one is already
// outstanding, in which case it does nothing. On success the message and
// its marshaled form are remembered in t.sentInitMsg and t.sentInitPacket
// for the exchange hash. It takes t.mu for the whole call.
//
// It returns an error if the random cookie cannot be generated or if the
// packet cannot be written.
func (t *handshakeTransport) sendKexInit() error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.sentInitMsg != nil {
		// kexInits may be sent either in response to the other side,
		// or because our side wants to initiate a key change, so we
		// may have already sent a kexInit. In that case, don't send a
		// second kexInit.
		return nil
	}

	msg := &kexInitMsg{
		CiphersClientServer:     t.config.Ciphers,
		CiphersServerClient:     t.config.Ciphers,
		MACsClientServer:        t.config.MACs,
		MACsServerClient:        t.config.MACs,
		CompressionClientServer: supportedCompressions,
		CompressionServerClient: supportedCompressions,
	}
	// A failed read would otherwise leave the cookie zeroed or only
	// partially filled without anyone noticing.
	if _, err := io.ReadFull(rand.Reader, msg.Cookie[:]); err != nil {
		return fmt.Errorf("ssh: generating kexInit cookie: %w", err)
	}

	// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
	// and possibly to add the ext-info extension algorithm. Since the slice may be the
	// user owned KeyExchanges, we create our own slice in order to avoid using user
	// owned memory by mistake.
	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)

	isServer := len(t.hostKeys) > 0
	if isServer {
		for _, k := range t.hostKeys {
			// If k is a MultiAlgorithmSigner, we restrict the signature
			// algorithms. If k is a AlgorithmSigner, presume it supports all
			// signature algorithms associated with the key format. If k is not
			// an AlgorithmSigner, we can only assume it only supports the
			// algorithms that matches the key format. (This means that Sign
			// can't pick a different default).
			keyFormat := k.PublicKey().Type()

			switch s := k.(type) {
			case MultiAlgorithmSigner:
				for _, algo := range algorithmsForKeyFormat(keyFormat) {
					if contains(s.Algorithms(), underlyingAlgo(algo)) {
						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
					}
				}
			case AlgorithmSigner:
				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
			default:
				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
			}
		}

		// Strict KEX is only offered during the first key exchange.
		if t.sessionID == nil {
			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
		}
	} else {
		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms

		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
		// algorithms the server supports for public key authentication. See RFC
		// 8308, Section 2.1.
		//
		// We also send the strict KEX mode extension algorithm, in order to opt
		// into the strict KEX mode.
		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
		}
	}

	packet := Marshal(msg)

	// writePacket destroys the contents, so save a copy.
	packetCopy := make([]byte, len(packet))
	copy(packetCopy, packet)

	if err := t.pushPacket(packetCopy); err != nil {
		return err
	}

	t.sentInitMsg = msg
	t.sentInitPacket = packet

	return nil
}
// writePacket sends p to the peer and is safe for concurrent use. kexInit
// and newKeys packets are rejected because only the transport itself may
// send them. While a key exchange is in progress p is copied and queued;
// kexLoop flushes the queue afterwards. Otherwise the write budgets are
// charged, a rekey is requested when one is exhausted, and the packet is
// written.
//
// A previously recorded write error is returned immediately. A failure of
// the current write is only recorded; the call itself still returns nil.
func (t *handshakeTransport) writePacket(p []byte) error {
	if msgType := p[0]; msgType == msgKexInit {
		return errors.New("ssh: only handshakeTransport can send kexInit")
	} else if msgType == msgNewKeys {
		return errors.New("ssh: only handshakeTransport can send newKeys")
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if prior := t.writeError; prior != nil {
		return prior
	}

	if kexInProgress := t.sentInitMsg != nil; kexInProgress {
		// The caller may reuse its buffer, so queue a private copy.
		queued := make([]byte, len(p))
		copy(queued, p)
		t.pendingPackets = append(t.pendingPackets, queued)
		return nil
	}

	if t.writeBytesLeft <= 0 {
		t.requestKeyExchange()
	} else {
		t.writeBytesLeft -= int64(len(p))
	}

	if t.writePacketsLeft == 0 {
		t.requestKeyExchange()
	} else {
		t.writePacketsLeft--
	}

	if pushErr := t.pushPacket(p); pushErr != nil {
		t.writeError = pushErr
	}
	return nil
}
// Close closes the underlying connection and blocks until kexLoop has
// exited. It returns the error from closing the connection.
func (t *handshakeTransport) Close() error {
	// Close the connection. This should cause the readLoop goroutine to wake up
	// and close t.startKex, which will shut down kexLoop if running.
	err := t.conn.Close()

	// Wait for the kexLoop goroutine to complete.
	// At that point we know that the readLoop goroutine is complete too,
	// because kexLoop itself waits for readLoop to close the startKex channel.
	<-t.kexLoopDone

	return err
}
// enterKeyExchange performs one complete key exchange, given the peer's raw
// kexInit packet. Our own kexInit must already have been sent, since
// t.sentInitMsg and t.sentInitPacket are read here. The steps are:
// negotiate algorithms, enable strict KEX mode when both sides offered it
// during the first exchange, run the client or server half of the chosen
// kex algorithm, install the new keys, send msgNewKeys (plus an ext-info
// message when we are the server and the client asked for it), and finally
// wait for the peer's msgNewKeys. The first successful exchange also fixes
// t.sessionID.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
	if debugHandshake {
		log.Printf("%s entered key exchange", t.id())
	}

	otherInit := &kexInitMsg{}
	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
		return err
	}

	// Start out assuming we are the server: the peer's init is the
	// client's and ours is the server's.
	magics := handshakeMagics{
		clientVersion: t.clientVersion,
		serverVersion: t.serverVersion,
		clientKexInit: otherInitPacket,
		serverKexInit: t.sentInitPacket,
	}

	clientInit := otherInit
	serverInit := t.sentInitMsg
	isClient := len(t.hostKeys) == 0
	if isClient {
		// We are the client, so swap the roles.
		clientInit, serverInit = serverInit, clientInit

		magics.clientKexInit = t.sentInitPacket
		magics.serverKexInit = otherInitPacket
	}

	var err error
	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
	if err != nil {
		return err
	}

	// Strict KEX is only negotiated during the first key exchange, and
	// only when the peer advertised its side of the extension.
	if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
		t.strictMode = true
		if err := t.conn.setStrictMode(); err != nil {
			return err
		}
	}

	// We don't send FirstKexFollows, but we handle receiving it.
	//
	// RFC 4253 section 7 defines the kex and the agreement method for
	// first_kex_packet_follows. It states that the guessed packet
	// should be ignored if the "kex algorithm and/or the host
	// key algorithm is guessed wrong (server and client have
	// different preferred algorithm), or if any of the other
	// algorithms cannot be agreed upon". The other algorithms have
	// already been checked above so the kex algorithm and host key
	// algorithm are checked here.
	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
		// other side sent a kex message for the wrong algorithm,
		// which we have to ignore.
		if _, err := t.conn.readPacket(); err != nil {
			return err
		}
	}

	kex, ok := kexAlgoMap[t.algorithms.kex]
	if !ok {
		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
	}

	var result *kexResult
	if len(t.hostKeys) > 0 {
		result, err = t.server(kex, &magics)
	} else {
		result, err = t.client(kex, &magics)
	}

	if err != nil {
		return err
	}

	// The exchange hash of the first key exchange becomes the session
	// ID and is reused for every later exchange.
	firstKeyExchange := t.sessionID == nil
	if firstKeyExchange {
		t.sessionID = result.H
	}
	result.SessionID = t.sessionID

	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
		return err
	}
	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}

	// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
	// message with the server-sig-algs extension if the client supports it. See
	// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
		// The payload carries two name/value pairs, each string
		// preceded by its 4-byte length: "server-sig-algs" with the
		// algorithm list, and "ping@openssh.com" with the value "0".
		extInfo := &extInfoMsg{
			NumExtensions: 2,
			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
		}
		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
		extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
		extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
		extInfo.Payload = appendInt(extInfo.Payload, 1)
		extInfo.Payload = append(extInfo.Payload, "0"...)
		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
			return err
		}
	}

	// The peer must answer with its own msgNewKeys.
	if packet, err := t.conn.readPacket(); err != nil {
		return err
	} else if packet[0] != msgNewKeys {
		return unexpectedMessageError(msgNewKeys, packet[0])
	}

	if firstKeyExchange {
		// Indicates to the transport that the first key exchange is completed
		// after receiving SSH_MSG_NEWKEYS.
		t.conn.setInitialKEXDone()
	}

	return nil
}
// algorithmSignerWrapper is an AlgorithmSigner that only supports the default
// key format algorithm.
//
// This is technically a violation of the AlgorithmSigner interface, but it
// should be unreachable given where we use this. Anyway, at least it returns an
// error instead of panicking or producing an incorrect signature.
type algorithmSignerWrapper struct {
	// Signer is the wrapped signer; its methods are promoted unchanged.
	Signer
}
// SignWithAlgorithm signs data with the wrapped Signer, provided algorithm
// is the underlying algorithm of the key's own type. Any other algorithm
// yields an error.
func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	if defaultAlgo := underlyingAlgo(a.PublicKey().Type()); algorithm == defaultAlgo {
		return a.Sign(rand, data)
	}
	return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
}
// pickHostKey returns the first host key able to sign with algo, or nil if
// there is none. A MultiAlgorithmSigner is skipped unless it lists the
// underlying algorithm. A key whose type equals algo is wrapped so it only
// signs with its default algorithm; otherwise the key must be an
// AlgorithmSigner whose key format includes algo.
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
	for _, signer := range hostKeys {
		if multi, isMulti := signer.(MultiAlgorithmSigner); isMulti && !contains(multi.Algorithms(), underlyingAlgo(algo)) {
			continue
		}

		if signer.PublicKey().Type() == algo {
			return algorithmSignerWrapper{signer}
		}

		algoSigner, isAlgoSigner := signer.(AlgorithmSigner)
		if !isAlgoSigner {
			continue
		}
		if contains(algorithmsForKeyFormat(algoSigner.PublicKey().Type()), algo) {
			return algoSigner
		}
	}
	return nil
}
// server runs the server half of kex, signing with the host key that
// matches the negotiated host key algorithm.
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
	negotiated := t.algorithms.hostKey

	signer := pickHostKey(t.hostKeys, negotiated)
	if signer == nil {
		return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
	}

	return kex.Server(t.conn, t.config.Rand, magics, signer, negotiated)
}
// client runs the client half of kex, then parses the server's host key,
// verifies its signature over the exchange result and finally hands the
// key to the configured host key callback. Any failure aborts the exchange.
func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
	res, err := kex.Client(t.conn, t.config.Rand, magics)
	if err != nil {
		return nil, err
	}

	serverKey, err := ParsePublicKey(res.HostKey)
	if err != nil {
		return nil, err
	}

	if err := verifyHostKeySignature(serverKey, t.algorithms.hostKey, res); err != nil {
		return nil, err
	}

	if err := t.hostKeyCallback(t.dialAddress, t.remoteAddr, serverKey); err != nil {
		return nil, err
	}

	return res, nil
}
================================================
FILE: gossh/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bcrypt_pbkdf implements bcrypt_pbkdf(3) from OpenBSD.
//
// See https://flak.tedunangst.com/post/bcrypt-pbkdf and
// https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/lib/libutil/bcrypt_pbkdf.c.
package bcrypt_pbkdf
import (
"crypto/sha512"
"errors"
"gossh/crypto/blowfish"
)
const blockSize = 32
// Key derives a key from the password, salt and rounds count, returning a
// []byte of length keyLen that can be used as cryptographic key.
//
// It returns an error if rounds is less than 1, the password is empty, the
// salt is empty or longer than 1<<20 bytes, or keyLen is negative or larger
// than 1024.
func Key(password, salt []byte, rounds, keyLen int) ([]byte, error) {
	if rounds < 1 {
		return nil, errors.New("bcrypt_pbkdf: number of rounds is too small")
	}
	if len(password) == 0 {
		return nil, errors.New("bcrypt_pbkdf: empty password")
	}
	if len(salt) == 0 || len(salt) > 1<<20 {
		return nil, errors.New("bcrypt_pbkdf: bad salt length")
	}
	// A negative length would otherwise panic when allocating or
	// slicing the output below.
	if keyLen < 0 {
		return nil, errors.New("bcrypt_pbkdf: keyLen is negative")
	}
	if keyLen > 1024 {
		return nil, errors.New("bcrypt_pbkdf: keyLen is too large")
	}

	// The output is produced in whole blocks and truncated at the end.
	numBlocks := (keyLen + blockSize - 1) / blockSize
	key := make([]byte, numBlocks*blockSize)

	h := sha512.New()
	h.Write(password)
	shapass := h.Sum(nil)

	shasalt := make([]byte, 0, sha512.Size)
	cnt, tmp := make([]byte, 4), make([]byte, blockSize)
	for block := 1; block <= numBlocks; block++ {
		// First round: hash the salt followed by the big-endian
		// block counter.
		h.Reset()
		h.Write(salt)
		cnt[0] = byte(block >> 24)
		cnt[1] = byte(block >> 16)
		cnt[2] = byte(block >> 8)
		cnt[3] = byte(block)
		h.Write(cnt)
		bcryptHash(tmp, shapass, h.Sum(shasalt))

		out := make([]byte, blockSize)
		copy(out, tmp)
		// Each further round hashes the previous round's output and
		// XORs the result into the block.
		for i := 2; i <= rounds; i++ {
			h.Reset()
			h.Write(tmp)
			bcryptHash(tmp, shapass, h.Sum(shasalt))
			for j := range out {
				out[j] ^= tmp[j]
			}
		}

		// Blocks are interleaved: byte i of this block lands at
		// offset i*numBlocks + (block-1) of the output.
		for i, v := range out {
			key[i*numBlocks+(block-1)] = v
		}
	}
	return key[:keyLen], nil
}
// magic is the fixed 32-byte plaintext that bcryptHash copies into its
// output buffer and then encrypts repeatedly.
var magic = []byte("OxychromaticBlowfishSwatDynamite")
// bcryptHash writes a 32-byte hash of shapass and shasalt into out. It sets
// up a salted Blowfish cipher, expands the key 64 times with salt and
// password in turn, encrypts the magic constant 64 times per 8-byte block
// and finally reverses the bytes of every 4-byte word. It panics if the
// cipher cannot be created.
func bcryptHash(out, shapass, shasalt []byte) {
	bf, err := blowfish.NewSaltedCipher(shapass, shasalt)
	if err != nil {
		panic(err)
	}

	for round := 0; round < 64; round++ {
		blowfish.ExpandKey(shasalt, bf)
		blowfish.ExpandKey(shapass, bf)
	}

	copy(out, magic)
	for off := 0; off < 32; off += 8 {
		chunk := out[off : off+8]
		for pass := 0; pass < 64; pass++ {
			bf.Encrypt(chunk, chunk)
		}
	}

	// Swap bytes due to different endianness.
	for off := 0; off < 32; off += 4 {
		out[off], out[off+3] = out[off+3], out[off]
		out[off+1], out[off+2] = out[off+2], out[off+1]
	}
}
================================================
FILE: gossh/crypto/ssh/kex.go
================================================
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"gossh/crypto/curve25519"
"io"
"math/big"
)
// Wire names of the supported key exchange algorithms.
const (
	// Finite-field Diffie-Hellman over fixed groups.
	kexAlgoDH1SHA1    = "diffie-hellman-group1-sha1"
	kexAlgoDH14SHA1   = "diffie-hellman-group14-sha1"
	kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
	kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
	// Elliptic curve Diffie-Hellman over the NIST curves.
	kexAlgoECDH256 = "ecdh-sha2-nistp256"
	kexAlgoECDH384 = "ecdh-sha2-nistp384"
	kexAlgoECDH521 = "ecdh-sha2-nistp521"
	// Curve25519, under both its libssh.org name and its plain name.
	kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
	kexAlgoCurve25519SHA256       = "curve25519-sha256"

	// For the following kex only the client half contains a production
	// ready implementation. The server half only consists of a minimal
	// implementation to satisfy the automated tests.
	kexAlgoDHGEXSHA1   = "diffie-hellman-group-exchange-sha1"
	kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
)
// kexResult captures the outcome of a key exchange. The kex algorithms
// fill in every field except SessionID, which the handshake transport sets
// afterwards.
type kexResult struct {
	// Session hash. See also RFC 4253, section 8.
	H []byte

	// Shared secret. See also RFC 4253, section 8.
	K []byte

	// Host key as hashed into H.
	HostKey []byte

	// Signature of H.
	Signature []byte

	// A cryptographic hash function that matches the security
	// level of the key exchange algorithm. It is used for
	// calculating H, and for deriving keys from H and K.
	Hash crypto.Hash

	// The session ID, which is the first H computed. This is used
	// to derive key material inside the transport.
	SessionID []byte
}
// handshakeMagics contains data that is always included in the
// session hash.
type handshakeMagics struct {
	// Version strings of the client and the server.
	clientVersion, serverVersion []byte
	// Raw kexInit packets of the client and the server.
	clientKexInit, serverKexInit []byte
}
// write feeds the four magic values into w via writeString, in the fixed
// order client version, server version, client kexInit, server kexInit.
func (m *handshakeMagics) write(w io.Writer) {
	ordered := [...][]byte{
		m.clientVersion,
		m.serverVersion,
		m.clientKexInit,
		m.serverKexInit,
	}
	for _, field := range ordered {
		writeString(w, field)
	}
}
// kexAlgorithm abstracts different key exchange algorithms. Both methods
// exchange their messages over p, draw randomness from rand and mix magics
// into the exchange hash.
type kexAlgorithm interface {
	// Server runs server-side key agreement, signing the result
	// with a hostkey. algo is the negotiated algorithm, and may
	// be a certificate type.
	Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error)

	// Client runs the client-side key agreement. Caller is
	// responsible for verifying the host key signature.
	Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
}
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
type dhGroup struct {
	// g is the generator, p the prime modulus and pMinus1 the cached
	// value p-1 used for range checks and private key generation.
	g, p, pMinus1 *big.Int
	// hashFunc is the hash used to compute the exchange hash H.
	hashFunc crypto.Hash
}
// diffieHellman computes the shared secret theirPublic^myPrivate mod p.
// The peer's public value must lie strictly between 1 and p-1; anything
// else is rejected with an error.
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
	if theirPublic.Cmp(bigOne) > 0 && theirPublic.Cmp(group.pMinus1) < 0 {
		return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
	}
	return nil, errors.New("ssh: DH parameter out of bounds")
}
// Client runs the client side of a fixed-group Diffie-Hellman exchange: it
// sends g^x mod p, reads the server's reply, derives the shared secret and
// computes the exchange hash. The host key signature is returned but not
// verified here.
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
	// Draw a private exponent, retrying while the value is zero.
	var priv *big.Int
	for priv == nil || priv.Sign() <= 0 {
		var err error
		if priv, err = rand.Int(randSource, group.pMinus1); err != nil {
			return nil, err
		}
	}

	pub := new(big.Int).Exp(group.g, priv, group.p)
	if err := c.writePacket(Marshal(&kexDHInitMsg{X: pub})); err != nil {
		return nil, err
	}

	raw, err := c.readPacket()
	if err != nil {
		return nil, err
	}

	var reply kexDHReplyMsg
	if err := Unmarshal(raw, &reply); err != nil {
		return nil, err
	}

	shared, err := group.diffieHellman(reply.Y, priv)
	if err != nil {
		return nil, err
	}

	digest := group.hashFunc.New()
	magics.write(digest)
	writeString(digest, reply.HostKey)
	writeInt(digest, pub)
	writeInt(digest, reply.Y)

	K := make([]byte, intLength(shared))
	marshalInt(K, shared)
	digest.Write(K)

	return &kexResult{
		H:         digest.Sum(nil),
		K:         K,
		HostKey:   reply.HostKey,
		Signature: reply.Signature,
		Hash:      group.hashFunc,
	}, nil
}
// Server runs the server side of a fixed-group Diffie-Hellman exchange: it
// reads the client's public value, picks its own private exponent, derives
// the shared secret, computes and signs the exchange hash with priv using
// algo, and sends the reply.
//
// On any failure, including a failed write of the reply, the returned
// result is nil.
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
	packet, err := c.readPacket()
	if err != nil {
		return nil, err
	}
	var kexDHInit kexDHInitMsg
	if err = Unmarshal(packet, &kexDHInit); err != nil {
		return nil, err
	}

	// Draw a private exponent, retrying while the value is zero.
	var y *big.Int
	for {
		if y, err = rand.Int(randSource, group.pMinus1); err != nil {
			return nil, err
		}
		if y.Sign() > 0 {
			break
		}
	}

	Y := new(big.Int).Exp(group.g, y, group.p)
	ki, err := group.diffieHellman(kexDHInit.X, y)
	if err != nil {
		return nil, err
	}

	hostKeyBytes := priv.PublicKey().Marshal()

	h := group.hashFunc.New()
	magics.write(h)
	writeString(h, hostKeyBytes)
	writeInt(h, kexDHInit.X)
	writeInt(h, Y)

	K := make([]byte, intLength(ki))
	marshalInt(K, ki)
	h.Write(K)

	H := h.Sum(nil)

	// H is already a hash, but the hostkey signing will apply its
	// own key-specific hash algorithm.
	sig, err := signAndMarshal(priv, randSource, H, algo)
	if err != nil {
		return nil, err
	}

	kexDHReply := kexDHReplyMsg{
		HostKey:   hostKeyBytes,
		Y:         Y,
		Signature: sig,
	}
	// Do not hand back a result for an exchange whose reply never
	// reached the client.
	if err = c.writePacket(Marshal(&kexDHReply)); err != nil {
		return nil, err
	}

	return &kexResult{
		H:         H,
		K:         K,
		HostKey:   hostKeyBytes,
		Signature: sig,
		Hash:      group.hashFunc,
	}, nil
}
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
// described in RFC 5656, section 4.
type ecdh struct {
	// curve is the elliptic curve used for the ephemeral keys; it also
	// determines the exchange hash via ecHash.
	curve elliptic.Curve
}
// Client runs the client side of the ECDH exchange: it sends a fresh
// ephemeral public key, validates the server's ephemeral key, derives the
// shared secret and computes the exchange hash. The host key signature is
// returned but not verified here.
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
	ephemeral, err := ecdsa.GenerateKey(kex.curve, rand)
	if err != nil {
		return nil, err
	}

	initMsg := kexECDHInitMsg{
		ClientPubKey: elliptic.Marshal(kex.curve, ephemeral.PublicKey.X, ephemeral.PublicKey.Y),
	}
	if err := c.writePacket(Marshal(&initMsg)); err != nil {
		return nil, err
	}

	raw, err := c.readPacket()
	if err != nil {
		return nil, err
	}

	var reply kexECDHReplyMsg
	if err := Unmarshal(raw, &reply); err != nil {
		return nil, err
	}

	serverX, serverY, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
	if err != nil {
		return nil, err
	}

	// The x coordinate of the product is the shared secret.
	shared, _ := kex.curve.ScalarMult(serverX, serverY, ephemeral.D.Bytes())

	hashFunc := ecHash(kex.curve)
	digest := hashFunc.New()
	magics.write(digest)
	writeString(digest, reply.HostKey)
	writeString(digest, initMsg.ClientPubKey)
	writeString(digest, reply.EphemeralPubKey)

	K := make([]byte, intLength(shared))
	marshalInt(K, shared)
	digest.Write(K)

	return &kexResult{
		H:         digest.Sum(nil),
		K:         K,
		HostKey:   reply.HostKey,
		Signature: reply.Signature,
		Hash:      hashFunc,
	}, nil
}
// unmarshalECKey decodes an encoded curve point and verifies that it is an
// acceptable public key for curve. It returns an error when decoding fails
// or validation rejects the point.
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
	if x, y = elliptic.Unmarshal(curve, pubkey); x == nil {
		return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
	}
	if ok := validateECPublicKey(curve, x, y); !ok {
		return nil, nil, errors.New("ssh: public key not on curve")
	}
	return x, y, nil
}
// validateECPublicKey reports whether (x, y) is a valid public key for the
// given curve, see [SEC1], 3.2.2. The point is rejected when both
// coordinates are zero, when either coordinate is not below the field
// prime, or when it does not lie on the curve.
//
// N * PubKey == 0 is deliberately not checked:
//
//   - the NIST curves have cofactor = 1, so this is implicit (no
//     implementation supporting non-NIST curves is foreseen);
//
//   - ephemeral keys need no protection against small subgroup attacks.
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
	prime := curve.Params().P

	switch {
	case x.Sign() == 0 && y.Sign() == 0:
		return false
	case x.Cmp(prime) >= 0:
		return false
	case y.Cmp(prime) >= 0:
		return false
	}

	return curve.IsOnCurve(x, y)
}
// Server runs the server side of the ECDH exchange: it validates the
// client's public key, generates a fresh ephemeral key, derives the shared
// secret, computes and signs the exchange hash with priv using algo, and
// sends the reply.
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
	raw, err := c.readPacket()
	if err != nil {
		return nil, err
	}

	var initMsg kexECDHInitMsg
	if err := Unmarshal(raw, &initMsg); err != nil {
		return nil, err
	}

	peerX, peerY, err := unmarshalECKey(kex.curve, initMsg.ClientPubKey)
	if err != nil {
		return nil, err
	}

	// A new key is generated for every exchange. Caching it across
	// users or connection attempts would be possible, but the benefit
	// is small; OpenSSH also generates one per incoming connection.
	ephemeral, err := ecdsa.GenerateKey(kex.curve, rand)
	if err != nil {
		return nil, err
	}

	hostKeyBytes := priv.PublicKey().Marshal()
	ephemeralPub := elliptic.Marshal(kex.curve, ephemeral.PublicKey.X, ephemeral.PublicKey.Y)

	// The x coordinate of the product is the shared secret.
	shared, _ := kex.curve.ScalarMult(peerX, peerY, ephemeral.D.Bytes())

	hashFunc := ecHash(kex.curve)
	digest := hashFunc.New()
	magics.write(digest)
	writeString(digest, hostKeyBytes)
	writeString(digest, initMsg.ClientPubKey)
	writeString(digest, ephemeralPub)

	K := make([]byte, intLength(shared))
	marshalInt(K, shared)
	digest.Write(K)

	H := digest.Sum(nil)

	// H is a hash already; signing applies the host key's own
	// key-specific hash algorithm on top.
	sig, err := signAndMarshal(priv, rand, H, algo)
	if err != nil {
		return nil, err
	}

	reply := kexECDHReplyMsg{
		EphemeralPubKey: ephemeralPub,
		HostKey:         hostKeyBytes,
		Signature:       sig,
	}
	if err := c.writePacket(Marshal(&reply)); err != nil {
		return nil, err
	}

	return &kexResult{
		H:         H,
		K:         K,
		HostKey:   reply.HostKey,
		Signature: sig,
		Hash:      hashFunc,
	}, nil
}
// ecHash selects the hash that goes with an elliptic curve according to
// RFC 5656, section 6.2.1: SHA-256 for curves of up to 256 bits, SHA-384
// for up to 384 bits and SHA-512 for anything larger.
func ecHash(curve elliptic.Curve) crypto.Hash {
	bits := curve.Params().BitSize
	if bits <= 256 {
		return crypto.SHA256
	}
	if bits <= 384 {
		return crypto.SHA384
	}
	return crypto.SHA512
}
// kexAlgoMap maps key exchange algorithm names to their implementations.
// It starts out empty and is filled in by init.
var kexAlgoMap = map[string]kexAlgorithm{}
// init registers the built-in key exchange implementations in kexAlgoMap: the
// fixed Diffie-Hellman groups (1, 14 and 16), the ECDH variants over the NIST
// P-256, P-384 and P-521 curves, curve25519-sha256 under both of its names,
// and the SHA-1 and SHA-256 Diffie-Hellman group exchange variants.
func init() {
// This is the group called diffie-hellman-group1-sha1 in
// RFC 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA1,
}
// These are the groups called diffie-hellman-group14-sha1 and
// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
// and Oakley Group 14 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
// group14 is only a template: the SHA-1 and SHA-256 entries below share its
// g, p and pMinus1 values and differ in hashFunc alone.
group14 := &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
}
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA1,
}
kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA256,
}
// This is the group called diffie-hellman-group16-sha512 in RFC
// 8268 and Oakley Group 16 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA512,
}
// Elliptic-curve Diffie-Hellman over the NIST curves.
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
// Both names of the curve25519 exchange map to the same implementation.
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
// Diffie-Hellman group exchange, differing only in the hash function.
kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
}
// curve25519sha256 implements the curve25519-sha256 (formerly known as
// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731.
// The type has no fields; all per-exchange state lives in the locals of its
// Client and Server methods.
type curve25519sha256 struct{}
// curve25519KeyPair holds a Curve25519 private scalar together with the public
// value derived from it. Both fields are filled in by generate.
type curve25519KeyPair struct {
// priv is the 32-byte private scalar read from the random source.
priv [32]byte
// pub is the public value computed from priv.
pub [32]byte
}
// generate fills kp.priv with 32 bytes read from rand and stores the matching
// public value in kp.pub. It returns the read error if rand cannot supply
// enough bytes, in which case kp.pub is left untouched.
func (kp *curve25519KeyPair) generate(rand io.Reader) error {
	_, readErr := io.ReadFull(rand, kp.priv[:])
	if readErr != nil {
		return readErr
	}
	curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
	return nil
}
// curve25519Zeros is an array of 32 zero bytes. The computed shared secret is
// compared against it so that peer curve25519 points with the wrong order,
// which produce an all-zero secret, can be rejected.
var curve25519Zeros [32]byte
// Client runs the client side of the curve25519-sha256 key exchange over c.
//
// It generates a key pair from rand, sends the public value in a
// kexECDHInitMsg, reads the server's kexECDHReplyMsg and derives the shared
// secret. An error is returned if a packet cannot be written, read or
// unmarshaled, if the server's public value is not 32 bytes long, or if the
// derived shared secret is all zeros.
//
// On success the result holds the SHA-256 exchange hash H, the shared secret K
// in the encoding produced by marshalInt, and the host key and signature taken
// from the server's reply. The signature is not verified in this function.
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
var kp curve25519KeyPair
if err := kp.generate(rand); err != nil {
return nil, err
}
// Send our public value to the server.
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil {
return nil, err
}
// Read the server's reply: its public value, host key and signature.
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var reply kexECDHReplyMsg
if err = Unmarshal(packet, &reply); err != nil {
return nil, err
}
if len(reply.EphemeralPubKey) != 32 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
}
var servPub, secret [32]byte
copy(servPub[:], reply.EphemeralPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &servPub)
// An all-zero secret is rejected; the comparison is constant-time.
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
}
// Exchange hash input, in order: handshake magics, server host key, client
// public value, server public value, shared secret.
h := crypto.SHA256.New()
magics.write(h)
writeString(h, reply.HostKey)
writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey)
// The secret bytes are interpreted as a big-endian integer and re-encoded
// with marshalInt before being hashed and returned.
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
H: h.Sum(nil),
K: K,
HostKey: reply.HostKey,
Signature: reply.Signature,
Hash: crypto.SHA256,
}, nil
}
// Server runs the server side of the curve25519-sha256 key exchange over c.
//
// It reads the client's kexECDHInitMsg, generates its own key pair from rand,
// derives the shared secret, signs the SHA-256 exchange hash with priv using
// algo, and sends a kexECDHReplyMsg carrying its public value, the host key
// and the signature. An error is returned if a packet cannot be read,
// unmarshaled or written, if the client's public value is not 32 bytes long,
// if the derived shared secret is all zeros, or if signing fails.
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
packet, err := c.readPacket()
if err != nil {
// Naked return: result is nil and err carries the read error.
return
}
var kexInit kexECDHInitMsg
if err = Unmarshal(packet, &kexInit); err != nil {
return
}
if len(kexInit.ClientPubKey) != 32 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
}
var kp curve25519KeyPair
if err := kp.generate(rand); err != nil {
return nil, err
}
var clientPub, secret [32]byte
copy(clientPub[:], kexInit.ClientPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &clientPub)
// An all-zero secret is rejected; the comparison is constant-time.
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
}
hostKeyBytes := priv.PublicKey().Marshal()
// Exchange hash input, in order: handshake magics, server host key, client
// public value, server public value, shared secret.
h := crypto.SHA256.New()
magics.write(h)
writeString(h, hostKeyBytes)
writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:])
// The secret bytes are interpreted as a big-endian integer and re-encoded
// with marshalInt before being hashed and returned.
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)
sig, err := signAndMarshal(priv, rand, H, algo)
if err != nil {
return nil, err
}
reply := kexECDHReplyMsg{
EphemeralPubKey: kp.pub[:],
HostKey: hostKeyBytes,
Signature: sig,
}
if err := c.writePacket(Marshal(&reply)); err != nil {
return nil, err
}
return &kexResult{
H: H,
K: K,
HostKey: hostKeyBytes,
Signature: sig,
Hash: crypto.SHA256,
}, nil
}
// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and
// diffie-hellman-group-exchange-sha256 key agreement protocols,
// as described in RFC 4419.
type dhGEXSHA struct {
// hashFunc is the hash used for the exchange hash; it is the only thing
// that distinguishes the SHA-1 variant from the SHA-256 one.
hashFunc crypto.Hash
}
// Group sizes, in bits, that the client sends in its group exchange request.
// The minimum and maximum also bound the size of the prime the client accepts
// from the server.
const (
dhGroupExchangeMinimumBits = 2048
dhGroupExchangePreferredBits = 2048
dhGroupExchangeMaximumBits = 8192
)
// Client runs the client side of the Diffie-Hellman group exchange over c.
//
// It sends a request with the package's minimum, preferred and maximum group
// sizes, receives the group (p, g) chosen by the server, validates it, sends
// its public value X = g^x mod p and derives the shared secret from the
// server's public value Y. An error is returned if a packet cannot be written,
// read or unmarshaled, if p's size is outside the requested range, if g, Y or
// the derived secret is not strictly between 1 and p-1, or if randSource
// fails.
//
// On success the result holds the exchange hash H computed with gex.hashFunc,
// the shared secret K, and the host key and signature from the server's reply.
// The signature is not verified in this function.
func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
// Send GexRequest
kexDHGexRequest := kexDHGexRequestMsg{
MinBits: dhGroupExchangeMinimumBits,
PreferedBits: dhGroupExchangePreferredBits,
MaxBits: dhGroupExchangeMaximumBits,
}
if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
return nil, err
}
// Receive GexGroup
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var msg kexDHGexGroupMsg
if err = Unmarshal(packet, &msg); err != nil {
return nil, err
}
// reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits
if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits {
return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen())
}
// Check if g is safe by verifying that 1 < g < p-1
pMinusOne := new(big.Int).Sub(msg.P, bigOne)
if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 {
return nil, fmt.Errorf("ssh: server provided gex g is not safe")
}
// Send GexInit
// The secret exponent x is drawn uniformly from [0, p/2).
pHalf := new(big.Int).Rsh(msg.P, 1)
x, err := rand.Int(randSource, pHalf)
if err != nil {
return nil, err
}
X := new(big.Int).Exp(msg.G, x, msg.P)
kexDHGexInit := kexDHGexInitMsg{
X: X,
}
if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil {
return nil, err
}
// Receive GexReply
packet, err = c.readPacket()
if err != nil {
return nil, err
}
var kexDHGexReply kexDHGexReplyMsg
if err = Unmarshal(packet, &kexDHGexReply); err != nil {
return nil, err
}
// The server's public value must satisfy 1 < Y < p-1.
if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 {
return nil, errors.New("ssh: DH parameter out of bounds")
}
kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P)
// Check if k is safe by verifying that k > 1 and k < p - 1
if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 {
return nil, fmt.Errorf("ssh: derived k is not safe")
}
// Exchange hash input, in order: handshake magics, host key, the three
// requested sizes exactly as sent in the request above, p, g, X, Y and the
// shared secret.
h := gex.hashFunc.New()
magics.write(h)
writeString(h, kexDHGexReply.HostKey)
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
writeInt(h, msg.P)
writeInt(h, msg.G)
writeInt(h, X)
writeInt(h, kexDHGexReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
h.Write(K)
return &kexResult{
H: h.Sum(nil),
K: K,
HostKey: kexDHGexReply.HostKey,
Signature: kexDHGexReply.Signature,
Hash: gex.hashFunc,
}, nil
}
// Server implements the server half of the Diffie-Hellman group exchange key
// agreement (RFC 4419) with SHA-1 or SHA-256, as selected by gex.hashFunc.
//
// This is a minimal implementation to satisfy the automated tests: whatever
// size the client prefers, the server always offers the 2048-bit Oakley
// Group 14 with generator 2.
//
// c carries the exchange packets, randSource supplies the randomness for the
// server's secret exponent and for signing, magics holds the handshake data
// mixed into the exchange hash, priv is the host key and algo is the signature
// algorithm handed to signAndMarshal.
//
// An error is returned when a packet cannot be read, unmarshaled or written,
// when the sizes in the client's request are inconsistent or its maximum is
// below dhGroupExchangeMinimumBits, when the client's public value X is not
// strictly between 1 and p-1, or when signing fails. The result is nil
// whenever an error is returned.
func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
	// Receive GexRequest
	packet, err := c.readPacket()
	if err != nil {
		return nil, err
	}
	var kexDHGexRequest kexDHGexRequestMsg
	if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
		return nil, err
	}

	// Keep the sizes exactly as the client sent them: they are part of the
	// exchange hash, and the client hashes its own values, not ours.
	minBits := uint32(kexDHGexRequest.MinBits)
	preferredBits := uint32(kexDHGexRequest.PreferedBits)
	maxBits := uint32(kexDHGexRequest.MaxBits)

	// The request must satisfy min <= preferred <= max, and the client must be
	// willing to accept a group at least as large as our supported minimum.
	if maxBits < minBits || preferredBits < minBits || maxBits < preferredBits || maxBits < dhGroupExchangeMinimumBits {
		return nil, fmt.Errorf("ssh: DH GEX request out of range, min: %d, max: %d, preferred: %d", minBits, maxBits, preferredBits)
	}

	// Send GexGroup
	// This is the group called diffie-hellman-group14-sha1 in RFC
	// 4253 and Oakley Group 14 in RFC 3526.
	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
	g := big.NewInt(2)
	msg := &kexDHGexGroupMsg{
		P: p,
		G: g,
	}
	if err := c.writePacket(Marshal(msg)); err != nil {
		return nil, err
	}

	// Receive GexInit
	packet, err = c.readPacket()
	if err != nil {
		return nil, err
	}
	var kexDHGexInit kexDHGexInitMsg
	if err = Unmarshal(packet, &kexDHGexInit); err != nil {
		return nil, err
	}

	// The secret exponent y is drawn uniformly from [0, p/2).
	pHalf := new(big.Int).Rsh(p, 1)
	y, err := rand.Int(randSource, pHalf)
	if err != nil {
		return nil, err
	}
	Y := new(big.Int).Exp(g, y, p)

	// The client's public value must satisfy 1 < X < p-1.
	pMinusOne := new(big.Int).Sub(p, bigOne)
	if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 {
		return nil, errors.New("ssh: DH parameter out of bounds")
	}
	kInt := new(big.Int).Exp(kexDHGexInit.X, y, p)

	hostKeyBytes := priv.PublicKey().Marshal()

	// Exchange hash input, in order: handshake magics, host key, the client's
	// requested min/preferred/max sizes, p, g, X, Y and the shared secret.
	h := gex.hashFunc.New()
	magics.write(h)
	writeString(h, hostKeyBytes)
	binary.Write(h, binary.BigEndian, minBits)
	binary.Write(h, binary.BigEndian, preferredBits)
	binary.Write(h, binary.BigEndian, maxBits)
	writeInt(h, p)
	writeInt(h, g)
	writeInt(h, kexDHGexInit.X)
	writeInt(h, Y)
	K := make([]byte, intLength(kInt))
	marshalInt(K, kInt)
	h.Write(K)
	H := h.Sum(nil)

	// H is already a hash, but the hostkey signing will apply its
	// own key-specific hash algorithm.
	sig, err := signAndMarshal(priv, randSource, H, algo)
	if err != nil {
		return nil, err
	}

	kexDHGexReply := kexDHGexReplyMsg{
		HostKey:   hostKeyBytes,
		Y:         Y,
		Signature: sig,
	}
	if err := c.writePacket(Marshal(&kexDHGexReply)); err != nil {
		return nil, err
	}
	return &kexResult{
		H:         H,
		K:         K,
		HostKey:   hostKeyBytes,
		Signature: sig,
		Hash:      gex.hashFunc,
	}, nil
}
================================================
FILE: gossh/crypto/ssh/keys.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"strings"
"gossh/crypto/ssh/internal/bcrypt_pbkdf"
)
// Public key algorithm names. These values can appear in PublicKey.Type,
// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner
// arguments.
const (
KeyAlgoRSA = "ssh-rsa"
KeyAlgoDSA = "ssh-dss"
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
// KeyAlgoSKECDSA256 is the OpenSSH "sk-" variant of ECDSA over nistp256.
KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
KeyAlgoED25519 = "ssh-ed25519"
// KeyAlgoSKED25519 is the OpenSSH "sk-" variant of Ed25519.
KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com"
// KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not
// public key formats, so they can't appear as a PublicKey.Type. The
// corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2.
KeyAlgoRSASHA256 = "rsa-sha2-256"
KeyAlgoRSASHA512 = "rsa-sha2-512"
)
// Former names of the RSA signature algorithm constants, retained as aliases
// of the corresponding KeyAlgo constants.
const (
// Deprecated: use KeyAlgoRSA.
SigAlgoRSA = KeyAlgoRSA
// Deprecated: use KeyAlgoRSASHA256.
SigAlgoRSASHA2256 = KeyAlgoRSASHA256
// Deprecated: use KeyAlgoRSASHA512.
SigAlgoRSASHA2512 = KeyAlgoRSASHA512
)
// parsePubKey decodes the body of a public key whose algorithm name has
// already been stripped off and is passed separately as algo. Use
// ParsePublicKey for keys that still carry the algorithm prefix.
//
// It returns the key, the bytes left over after it, and an error when the
// algorithm is unknown or the selected parser fails. For certificate
// algorithms no remainder is reported.
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
	switch algo {
	case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
		certificate, certErr := parseCert(in, certKeyAlgoNames[algo])
		if certErr != nil {
			return nil, nil, certErr
		}
		return certificate, nil, nil
	case KeyAlgoSKED25519:
		return parseSKEd25519(in)
	case KeyAlgoED25519:
		return parseED25519(in)
	case KeyAlgoSKECDSA256:
		return parseSKECDSA(in)
	case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
		return parseECDSA(in)
	case KeyAlgoDSA:
		return parseDSA(in)
	case KeyAlgoRSA:
		return parseRSA(in)
	default:
		return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo)
	}
}
// parseAuthorizedKey parses the tail of an OpenSSH authorized_keys entry (see
// the sshd(8) manual page) after the options and key type fields have been
// removed: a base64-encoded key optionally followed by whitespace and a
// comment.
//
// It returns the decoded key and the trimmed comment, or an error when the
// base64 data or the key it encodes is invalid.
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
	line := bytes.TrimSpace(in)

	// The key runs up to the first blank; everything after it is the comment.
	split := bytes.IndexAny(line, " \t")
	if split < 0 {
		split = len(line)
	}
	encoded, trailer := line[:split], line[split:]

	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(encoded)))
	n, err := base64.StdEncoding.Decode(decoded, encoded)
	if err != nil {
		return nil, "", err
	}

	parsed, err := ParsePublicKey(decoded[:n])
	if err != nil {
		return nil, "", err
	}
	return parsed, string(bytes.TrimSpace(trailer)), nil
}
// ParseKnownHosts parses one entry in the format of the known_hosts file,
// which is documented in the sshd(8) manual page.
//
// Blank lines, comment lines starting with '#', and lines consisting of a
// single field are skipped. For the first remaining line, marker receives the
// optional marker without its leading '@' (i.e. "cert-authority" or
// "revoked") or is empty, hosts receives the comma-separated host patterns,
// pubKey the parsed public key, and comment any trailing comment. See the
// sshd(8) manual page for the forms a host string can take.
//
// rest is the unparsed input following the entry, so the function can be
// called repeatedly to walk through multiple entries.
//
// If no entry is found in the input, err is io.EOF. Any other non-nil err
// indicates a parse error.
func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) {
	for remaining := in; len(remaining) > 0; {
		// Cut the next line off the input; next stays nil on the last line.
		line, next := remaining, []byte(nil)
		if nl := bytes.IndexByte(remaining, '\n'); nl >= 0 {
			line, next = remaining[:nl], remaining[nl+1:]
		}
		remaining = next

		// Drop everything from the first carriage return onwards.
		if cr := bytes.IndexByte(line, '\r'); cr >= 0 {
			line = line[:cr]
		}
		line = bytes.TrimSpace(line)
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		if bytes.IndexAny(line, " \t") < 0 {
			continue
		}

		// An entry is an optional marker, the host list, the key type, the
		// base64 key and an optional comment.
		fields := bytes.Fields(line)
		if n := len(fields); n < 3 || n > 5 {
			return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data")
		}

		// A leading "@cert-authority" or "@revoked" field is the marker;
		// otherwise the first field is already the host list.
		entryMarker := ""
		if fields[0][0] == '@' {
			entryMarker = string(fields[0][1:])
			fields = fields[1:]
		}
		hostList := string(fields[0])

		// fields[1] is the key type (e.g. "ssh-rsa"). It is ignored because
		// the same information is repeated inside the base64-encoded key.
		key, keyComment, parseErr := parseAuthorizedKey(bytes.Join(fields[2:], []byte(" ")))
		if parseErr != nil {
			return "", nil, nil, "", nil, parseErr
		}
		return entryMarker, strings.Split(hostList, ","), key, keyComment, next, nil
	}
	return "", nil, nil, "", nil, io.EOF
}
// ParseAuthorizedKey parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
//
// It scans in line by line and returns the first key it can parse, together
// with the key's comment, any options that preceded the key type, and the
// input remaining after that line. Blank lines, comment lines starting with
// '#', and lines that cannot be parsed are skipped. If no line yields a key,
// the error "ssh: no key found" is returned.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
for len(in) > 0 {
// Split off the current line; rest holds whatever follows it.
end := bytes.IndexByte(in, '\n')
if end != -1 {
rest = in[end+1:]
in = in[:end]
} else {
rest = nil
}
// Drop everything from the first carriage return onwards.
end = bytes.IndexByte(in, '\r')
if end != -1 {
in = in[:end]
}
in = bytes.TrimSpace(in)
if len(in) == 0 || in[0] == '#' {
in = rest
continue
}
i := bytes.IndexAny(in, " \t")
if i == -1 {
in = rest
continue
}
// First attempt: treat the first field as the key type and parse what
// follows it as the key.
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
return out, comment, options, rest, nil
}
// No key type recognised. Maybe there's an options field at
// the beginning.
var b byte
inQuote := false
var candidateOptions []string
optionStart := 0
// Walk the options field: options are separated by commas, and the field
// ends at the first blank outside double quotes. The loop reuses the outer
// i, so after the loop i is the index where scanning stopped.
for i, b = range in {
isEnd := !inQuote && (b == ' ' || b == '\t')
if (b == ',' && !inQuote) || isEnd {
if i-optionStart > 0 {
candidateOptions = append(candidateOptions, string(in[optionStart:i]))
}
optionStart = i + 1
}
if isEnd {
break
}
// A double quote toggles quoting unless it is preceded by a backslash.
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) {
inQuote = !inQuote
}
}
// Skip the blanks between the options field and the key type.
for i < len(in) && (in[i] == ' ' || in[i] == '\t') {
i++
}
if i == len(in) {
// Invalid line: unmatched quote
in = rest
continue
}
// Second attempt: skip the key type that follows the options and parse
// the remainder as the key.
in = in[i:]
i = bytes.IndexAny(in, " \t")
if i == -1 {
in = rest
continue
}
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
options = candidateOptions
return out, comment, options, rest, nil
}
in = rest
continue
}
return nil, "", nil, nil, errors.New("ssh: no key found")
}
// ParsePublicKey decodes an SSH public key in the wire format of RFC 4253,
// section 6.6: a string naming the algorithm followed by the
// algorithm-specific key data.
//
// It returns errShortRead when the algorithm name cannot be read, an error
// when bytes are left over after the key, and otherwise whatever the
// algorithm-specific parser returns.
func ParsePublicKey(in []byte) (out PublicKey, err error) {
	algo, payload, ok := parseString(in)
	if !ok {
		return nil, errShortRead
	}

	key, trailing, parseErr := parsePubKey(payload, string(algo))
	if len(trailing) != 0 {
		return nil, errors.New("ssh: trailing junk in public key")
	}
	return key, parseErr
}
// MarshalAuthorizedKey renders key as a line for an OpenSSH authorized_keys
// file: the key type, a space, the standard base64 encoding of the key's wire
// format, and a terminating newline.
func MarshalAuthorizedKey(key PublicKey) []byte {
	var line bytes.Buffer
	line.WriteString(key.Type())
	line.WriteByte(' ')
	line.WriteString(base64.StdEncoding.EncodeToString(key.Marshal()))
	line.WriteByte('\n')
	return line.Bytes()
}
// MarshalPrivateKey returns a PEM block with the private key serialized in the
// OpenSSH format. The key is written unencrypted; comment is passed through to
// marshalOpenSSHPrivateKey along with the key.
func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) {
return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler)
}
// MarshalPrivateKeyWithPassphrase returns a PEM block holding the encrypted
// private key serialized in the OpenSSH format. The key is protected with
// passphrase; comment is passed through to marshalOpenSSHPrivateKey along with
// the key.
func MarshalPrivateKeyWithPassphrase(key crypto.PrivateKey, comment string, passphrase []byte) (*pem.Block, error) {
return marshalOpenSSHPrivateKey(key, comment, passphraseProtectedOpenSSHMarshaler(passphrase))
}
// PublicKey represents a public key using an unspecified algorithm.
//
// Some PublicKeys provided by this package also implement CryptoPublicKey.
type PublicKey interface {
// Type returns the key format name, e.g. "ssh-rsa".
Type() string
// Marshal returns the serialized key data in SSH wire format, with the name
// prefix. To unmarshal the returned data, use the ParsePublicKey function.
Marshal() []byte
// Verify that sig is a signature on the given data using this key. This
// method will hash the data appropriately first. sig.Format is allowed to
// be any signature algorithm compatible with the key type, the caller
// should check if it has more stringent requirements. A nil error means
// the signature is valid.
Verify(data []byte, sig *Signature) error
}
// CryptoPublicKey, if implemented by a PublicKey,
// returns the underlying crypto.PublicKey form of the key.
// Callers can type-assert a PublicKey to this interface to find out whether
// the standard-library form is available.
type CryptoPublicKey interface {
CryptoPublicKey() crypto.PublicKey
}
// A Signer can create signatures that verify against a public key.
//
// Some Signers provided by this package also implement MultiAlgorithmSigner.
type Signer interface {
// PublicKey returns the associated PublicKey.
PublicKey() PublicKey
// Sign returns a signature for the given data. This method will hash the
// data appropriately first. The signature algorithm is expected to match
// the key format returned by the PublicKey.Type method (and not to be any
// alternative algorithm supported by the key format). rand supplies any
// randomness the signature scheme needs.
Sign(rand io.Reader, data []byte) (*Signature, error)
}
// An AlgorithmSigner is a Signer that also supports specifying an algorithm to
// use for signing.
//
// An AlgorithmSigner can't advertise the algorithms it supports, unless it also
// implements MultiAlgorithmSigner, so it should be prepared to be invoked with
// every algorithm supported by the public key format.
type AlgorithmSigner interface {
Signer
// SignWithAlgorithm is like Signer.Sign, but allows specifying a desired
// signing algorithm. Callers may pass an empty string for the algorithm in
// which case the AlgorithmSigner will use a default algorithm. This default
// doesn't currently control any behavior in this package. Implementations
// are expected to return an error for an algorithm they do not support.
SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
}
// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms
// supported by that signer. Use NewSignerWithAlgorithms to restrict an
// existing AlgorithmSigner to a chosen list.
type MultiAlgorithmSigner interface {
AlgorithmSigner
// Algorithms returns the available algorithms in preference order. The list
// must not be empty, and it must not include certificate types.
Algorithms() []string
}
// NewSignerWithAlgorithms wraps signer so that it only signs with the given
// algorithms. The algorithms must be listed in preference order; the list must
// not be empty and must not include certificate types.
//
// An error is returned if the list is empty, if an algorithm is incompatible
// with the signer's public key type, or if signer is itself a restricted
// signer created by this function and does not allow the algorithm.
func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) {
	if len(algorithms) == 0 {
		return nil, errors.New("ssh: please specify at least one valid signing algorithm")
	}

	// keyAlgos is what the key format permits; allowed narrows that further
	// when the signer has already been restricted.
	keyAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type()))
	allowed := keyAlgos
	if restricted, isRestricted := signer.(*multiAlgorithmSigner); isRestricted {
		allowed = restricted.Algorithms()
	}

	for i := range algorithms {
		requested := algorithms[i]
		switch {
		case !contains(keyAlgos, requested):
			return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q",
				requested, signer.PublicKey().Type())
		case !contains(allowed, requested):
			return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", requested)
		}
	}

	restrictedSigner := &multiAlgorithmSigner{
		AlgorithmSigner:     signer,
		supportedAlgorithms: algorithms,
	}
	return restrictedSigner, nil
}
// multiAlgorithmSigner is the MultiAlgorithmSigner returned by
// NewSignerWithAlgorithms. It embeds the wrapped signer and records the
// algorithms it is allowed to sign with.
type multiAlgorithmSigner struct {
AlgorithmSigner
// supportedAlgorithms is the caller-supplied list, in preference order.
supportedAlgorithms []string
}
// Algorithms returns the algorithms this signer is restricted to. The returned
// slice is the signer's own list, not a copy.
func (s *multiAlgorithmSigner) Algorithms() []string {
return s.supportedAlgorithms
}
// isAlgorithmSupported reports whether algorithm is one of the algorithms this
// signer is restricted to. An empty algorithm stands for the underlying
// algorithm of the signer's public key type.
func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool {
	wanted := algorithm
	if wanted == "" {
		wanted = underlyingAlgo(s.PublicKey().Type())
	}
	for i := 0; i < len(s.supportedAlgorithms); i++ {
		if s.supportedAlgorithms[i] == wanted {
			return true
		}
	}
	return false
}
// SignWithAlgorithm signs data with the wrapped signer, provided algorithm is
// one this signer is restricted to; otherwise it returns an error listing the
// permitted algorithms. The algorithm is handed to the wrapped signer
// unchanged, including when it is empty.
func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	if s.isAlgorithmSupported(algorithm) {
		return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
	}
	return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms)
}
// rsaPublicKey is an rsa.PublicKey that implements the SSH PublicKey
// interface.
type rsaPublicKey rsa.PublicKey
// Type returns the key format name, "ssh-rsa".
func (r *rsaPublicKey) Type() string {
return "ssh-rsa"
}
// parseRSA decodes an RSA public key as laid out in RFC 4253, section 6.6:
// the public exponent followed by the modulus.
//
// It returns the key and the bytes following it. The exponent must fit in 24
// bits, be at least 3 and be odd; otherwise an error is returned.
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
	var wire struct {
		E    *big.Int
		N    *big.Int
		Rest []byte `ssh:"rest"`
	}
	if err := Unmarshal(in, &wire); err != nil {
		return nil, nil, err
	}

	if wire.E.BitLen() > 24 {
		return nil, nil, errors.New("ssh: exponent too large")
	}
	exponent := wire.E.Int64()
	if exponent < 3 || exponent%2 == 0 {
		return nil, nil, errors.New("ssh: incorrect exponent")
	}

	key := &rsa.PublicKey{N: wire.N, E: int(exponent)}
	return (*rsaPublicKey)(key), wire.Rest, nil
}
// Marshal returns the key in SSH wire format: the algorithm name KeyAlgoRSA
// followed by the public exponent and the modulus.
func (r *rsaPublicKey) Marshal() []byte {
e := new(big.Int).SetInt64(int64(r.E))
// RSA publickey struct layout should match the struct used by
// parseRSACert in the x/crypto/ssh/agent package.
wirekey := struct {
Name string
E *big.Int
N *big.Int
}{
KeyAlgoRSA,
e,
r.N,
}
return Marshal(&wirekey)
}
// Verify checks that sig is a valid PKCS #1 v1.5 signature over data made with
// this key. The data is hashed with the function registered in hashFuncs for
// sig.Format, which must be one of the algorithms allowed for the "ssh-rsa"
// key format; any other format is rejected with an error.
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
	if allowed := algorithmsForKeyFormat(r.Type()); !contains(allowed, sig.Format) {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
	}

	hashFunc := hashFuncs[sig.Format]
	hasher := hashFunc.New()
	hasher.Write(data)
	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hashFunc, hasher.Sum(nil), sig.Blob)
}
// CryptoPublicKey returns the key as a *rsa.PublicKey. The result points at
// the same key data as r rather than at a copy.
func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*rsa.PublicKey)(r)
}
// dsaPublicKey is a dsa.PublicKey that implements the SSH PublicKey
// interface.
type dsaPublicKey dsa.PublicKey
// Type returns the key format name, "ssh-dss".
func (k *dsaPublicKey) Type() string {
return "ssh-dss"
}
// checkDSAParams accepts only DSA parameters whose prime P is exactly 1024
// bits long and returns an error naming the actual size otherwise.
//
// SSH specifies FIPS 186-2, which only provided a single (1024-bit) DSA key
// size. FIPS 186-3 allows for larger key sizes, which would confuse SSH.
func checkDSAParams(param *dsa.Parameters) error {
	bits := param.P.BitLen()
	if bits == 1024 {
		return nil
	}
	return fmt.Errorf("ssh: unsupported DSA key size %d", bits)
}
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
var w struct {
P, Q, G, Y *big.Int
Rest []byte `ssh:"rest"`
}
if err := Unmarshal(in, &w); err != nil {
return nil, nil, err
}
param := dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
}
if err := checkDSAParams(¶m); err != nil {
return nil, nil, err
}
key := &dsaPublicKey{
Parameters: param,
Y: w.Y,
}
return key, w.Rest, nil
}
// Marshal returns the key in SSH wire format: the key type name followed by
// the parameters P, Q and G and the public value Y.
func (k *dsaPublicKey) Marshal() []byte {
// DSA publickey struct layout should match the struct used by
// parseDSACert in the x/crypto/ssh/agent package.
w := struct {
Name string
P, Q, G, Y *big.Int
}{
k.Type(),
k.P,
k.Q,
k.G,
k.Y,
}
return Marshal(&w)
}
// Verify checks that sig is a valid DSA signature over data made with this
// key. sig.Format must equal the key type, and sig.Blob must be exactly 40
// bytes long; the data is hashed with the function registered in hashFuncs for
// that format before verification.
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
	if sig.Format != k.Type() {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
	}

	hasher := hashFuncs[sig.Format].New()
	hasher.Write(data)
	digest := hasher.Sum(nil)

	// Per RFC 4253, section 6.6, the 'dss_signature_blob' is a string holding
	// r followed by s, each a 160-bit unsigned integer in network byte order
	// without length or padding, so the blob is exactly 2*20 bytes.
	const intLen = 20
	if len(sig.Blob) != 2*intLen {
		return errors.New("ssh: DSA signature parse error")
	}
	r := new(big.Int).SetBytes(sig.Blob[:intLen])
	s := new(big.Int).SetBytes(sig.Blob[intLen:])

	if !dsa.Verify((*dsa.PublicKey)(k), digest, r, s) {
		return errors.New("ssh: signature did not verify")
	}
	return nil
}
// CryptoPublicKey returns the key as a *dsa.PublicKey. The result points at
// the same key data as k rather than at a copy.
func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*dsa.PublicKey)(k)
}
// dsaPrivateKey wraps a *dsa.PrivateKey and provides the signer methods
// (PublicKey, Sign, SignWithAlgorithm, Algorithms) for it.
type dsaPrivateKey struct {
*dsa.PrivateKey
}
// PublicKey returns the public half of the key as an SSH PublicKey. The
// result points into the wrapped private key rather than at a copy.
func (k *dsaPrivateKey) PublicKey() PublicKey {
return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
}
// Sign signs data using the key's own type as the signature algorithm.
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
return k.SignWithAlgorithm(rand, data, k.PublicKey().Type())
}
// Algorithms returns the single algorithm this signer supports, which is the
// key's own type.
func (k *dsaPrivateKey) Algorithms() []string {
return []string{k.PublicKey().Type()}
}
// SignWithAlgorithm signs data with the DSA key. algorithm must be empty or
// equal to the key type; anything else is rejected with an error. The data is
// hashed with the function registered in hashFuncs for the key type, and the
// resulting signature blob is r followed by s, each left-padded with zeros to
// 20 bytes.
func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	keyType := k.PublicKey().Type()
	if algorithm != "" && algorithm != keyType {
		return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
	}

	hasher := hashFuncs[keyType].New()
	hasher.Write(data)

	r, s, err := dsa.Sign(rand, k.PrivateKey, hasher.Sum(nil))
	if err != nil {
		return nil, err
	}

	// Right-align each integer in its 20-byte half so that shorter values
	// keep their leading zero bytes.
	blob := make([]byte, 40)
	rBytes, sBytes := r.Bytes(), s.Bytes()
	copy(blob[20-len(rBytes):20], rBytes)
	copy(blob[40-len(sBytes):], sBytes)

	return &Signature{Format: keyType, Blob: blob}, nil
}
// ecdsaPublicKey is an ecdsa.PublicKey that implements the SSH PublicKey
// interface.
type ecdsaPublicKey ecdsa.PublicKey

// Type returns the key format name: "ecdsa-sha2-" followed by the curve
// identifier reported by nistID.
func (k *ecdsaPublicKey) Type() string {
	id := k.nistID()
	return "ecdsa-sha2-" + id
}

// nistID returns the SSH identifier of the key's curve, chosen by the curve's
// bit size: "nistp256", "nistp384" or "nistp521". It panics for any other
// size.
func (k *ecdsaPublicKey) nistID() string {
	bits := k.Params().BitSize
	if bits == 256 {
		return "nistp256"
	}
	if bits == 384 {
		return "nistp384"
	}
	if bits == 521 {
		return "nistp521"
	}
	panic("ssh: unsupported ecdsa key size")
}
// ed25519PublicKey is an ed25519.PublicKey that implements the SSH PublicKey
// interface.
type ed25519PublicKey ed25519.PublicKey
// Type returns the key format name, KeyAlgoED25519.
func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519
}
// parseED25519 decodes an Ed25519 public key: a single string holding the key
// bytes. It returns the key and the bytes following it, or an error when the
// input cannot be unmarshaled or the key is not ed25519.PublicKeySize bytes
// long.
func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
	var wire struct {
		KeyBytes []byte
		Rest     []byte `ssh:"rest"`
	}
	if err := Unmarshal(in, &wire); err != nil {
		return nil, nil, err
	}

	if size := len(wire.KeyBytes); size != ed25519.PublicKeySize {
		return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", size)
	}
	return ed25519PublicKey(wire.KeyBytes), wire.Rest, nil
}
// Marshal returns the key in SSH wire format: the algorithm name
// KeyAlgoED25519 followed by the raw key bytes.
func (k ed25519PublicKey) Marshal() []byte {
w := struct {
Name string
KeyBytes []byte
}{
KeyAlgoED25519,
[]byte(k),
}
return Marshal(&w)
}
// Verify checks that sig is a valid Ed25519 signature over b made with this
// key. It returns an error if sig.Format differs from the key type, if the key
// is not ed25519.PublicKeySize bytes long, or if the signature does not
// verify.
func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
	if sig.Format != k.Type() {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
	}

	// Check the length up front so that a malformed key yields an error.
	if size := len(k); size != ed25519.PublicKeySize {
		return fmt.Errorf("ssh: invalid size %d for Ed25519 public key", size)
	}

	if !ed25519.Verify(ed25519.PublicKey(k), b, sig.Blob) {
		return errors.New("ssh: signature did not verify")
	}
	return nil
}
// CryptoPublicKey returns the key as an ed25519.PublicKey. The result shares
// its bytes with k rather than being a copy.
func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey {
return ed25519.PublicKey(k)
}
// supportedEllipticCurve reports whether curve is one of the standard library
// P-256, P-384 or P-521 curve values.
func supportedEllipticCurve(curve elliptic.Curve) bool {
	switch curve {
	case elliptic.P256(), elliptic.P384(), elliptic.P521():
		return true
	}
	return false
}
// parseECDSA decodes an ECDSA public key as laid out in RFC 5656, section
// 3.1: the curve identifier followed by the encoded curve point.
//
// It returns the key and the bytes following it, or an error when the input
// cannot be unmarshaled, the curve is not nistp256, nistp384 or nistp521, or
// the point cannot be decoded for that curve.
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
	var wire struct {
		Curve    string
		KeyBytes []byte
		Rest     []byte `ssh:"rest"`
	}
	if err := Unmarshal(in, &wire); err != nil {
		return nil, nil, err
	}

	var curve elliptic.Curve
	switch wire.Curve {
	case "nistp521":
		curve = elliptic.P521()
	case "nistp384":
		curve = elliptic.P384()
	case "nistp256":
		curve = elliptic.P256()
	default:
		return nil, nil, errors.New("ssh: unsupported curve")
	}

	x, y := elliptic.Unmarshal(curve, wire.KeyBytes)
	if x == nil || y == nil {
		return nil, nil, errors.New("ssh: invalid curve point")
	}

	key := &ecdsa.PublicKey{Curve: curve, X: x, Y: y}
	return (*ecdsaPublicKey)(key), wire.Rest, nil
}
// Marshal returns the key in SSH wire format: the key type name, the curve
// identifier, and the curve point as encoded by elliptic.Marshal.
func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package.
w := struct {
Name string
ID string
Key []byte
}{
k.Type(),
k.nistID(),
keyBytes,
}
return Marshal(&w)
}
// Verify checks that sig is a valid ECDSA signature over data made with this
// key. sig.Format must equal the key type; the data is hashed with the
// function registered in hashFuncs for that format. An error is returned if
// the format differs, the signature blob cannot be unmarshaled, or the
// signature does not verify.
func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
	if sig.Format != k.Type() {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
	}

	hasher := hashFuncs[sig.Format].New()
	hasher.Write(data)

	// Per RFC 5656, section 3.1.2, the ecdsa_signature_blob is encoded as
	// mpint r followed by mpint s.
	var parsed struct {
		R *big.Int
		S *big.Int
	}
	if err := Unmarshal(sig.Blob, &parsed); err != nil {
		return err
	}

	if !ecdsa.Verify((*ecdsa.PublicKey)(k), hasher.Sum(nil), parsed.R, parsed.S) {
		return errors.New("ssh: signature did not verify")
	}
	return nil
}
// CryptoPublicKey returns the key as a *ecdsa.PublicKey. The result points at
// the same key data as k rather than at a copy.
func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*ecdsa.PublicKey)(k)
}
// skFields holds the additional fields present in U2F/FIDO2 signatures.
// See openssh/PROTOCOL.u2f 'SSH U2F Signatures' for details.
// The field order matters: values of this type are filled in by Unmarshal.
type skFields struct {
// Flags contains U2F/FIDO2 flags such as 'user present'
Flags byte
// Counter is a monotonic signature counter which can be
// used to detect concurrent use of a private key, should
// it be extracted from hardware.
Counter uint32
}
// skECDSAPublicKey is an ECDSA public key bound to a security-key
// application string. Its parser in this file only accepts the nistp256 curve.
type skECDSAPublicKey struct {
// application is a URL-like string, typically "ssh:" for SSH.
// see openssh/PROTOCOL.u2f for details.
application string
// The embedded standard-library key carries the curve and point.
ecdsa.PublicKey
}
// Type returns the key algorithm name for this key, KeyAlgoSKECDSA256.
func (k *skECDSAPublicKey) Type() string {
return KeyAlgoSKECDSA256
}
// nistID returns the curve identifier written into the marshaled key. It is
// always "nistp256", the only curve parseSKECDSA accepts.
func (k *skECDSAPublicKey) nistID() string {
return "nistp256"
}
// parseSKECDSA decodes a security-key ECDSA public key: curve identifier,
// encoded curve point and application string. Only nistp256 is accepted. It
// returns the key and whatever bytes follow it in the input.
func parseSKECDSA(in []byte) (out PublicKey, rest []byte, err error) {
	var wire struct {
		Curve       string
		KeyBytes    []byte
		Application string
		Rest        []byte `ssh:"rest"`
	}
	if err := Unmarshal(in, &wire); err != nil {
		return nil, nil, err
	}
	if wire.Curve != "nistp256" {
		return nil, nil, errors.New("ssh: unsupported curve")
	}

	curve := elliptic.P256()
	// elliptic.Unmarshal yields nil coordinates for a malformed point.
	x, y := elliptic.Unmarshal(curve, wire.KeyBytes)
	if x == nil || y == nil {
		return nil, nil, errors.New("ssh: invalid curve point")
	}

	key := &skECDSAPublicKey{application: wire.Application}
	key.Curve, key.X, key.Y = curve, x, y
	return key, wire.Rest, nil
}
// Marshal returns the SSH wire encoding of the key: key type name, curve
// identifier, encoded curve point (RFC 5656, section 3.1) and the
// application string.
func (k *skECDSAPublicKey) Marshal() []byte {
	point := elliptic.Marshal(k.Curve, k.X, k.Y)
	wire := struct {
		Name        string
		ID          string
		Key         []byte
		Application string
	}{
		Name:        k.Type(),
		ID:          k.nistID(),
		Key:         point,
		Application: k.application,
	}
	return Marshal(&wire)
}
// Verify checks a security-key ECDSA signature over data. The signed message
// is rebuilt as hash(application) || flags || counter || hash(data), hashed
// once more, and verified against the (r, s) pair held in sig.Blob. Flags and
// counter are read from sig.Rest.
func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
	if sig.Format != k.Type() {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
	}

	h := hashFuncs[sig.Format].New()
	// sum hashes b starting from a clean hash state.
	sum := func(b []byte) []byte {
		h.Reset()
		h.Write(b)
		return h.Sum(nil)
	}
	appDigest := sum([]byte(k.application))
	dataDigest := sum(data)

	var ecSig struct {
		R, S *big.Int
	}
	if err := Unmarshal(sig.Blob, &ecSig); err != nil {
		return err
	}

	var skf skFields
	if err := Unmarshal(sig.Rest, &skf); err != nil {
		return err
	}

	// The "rest" tags make both digests raw bytes without a length prefix.
	signed := Marshal(struct {
		ApplicationDigest []byte `ssh:"rest"`
		Flags             byte
		Counter           uint32
		MessageDigest     []byte `ssh:"rest"`
	}{
		appDigest,
		skf.Flags,
		skf.Counter,
		dataDigest,
	})

	if !ecdsa.Verify((*ecdsa.PublicKey)(&k.PublicKey), sum(signed), ecSig.R, ecSig.S) {
		return errors.New("ssh: signature did not verify")
	}
	return nil
}
// skEd25519PublicKey is an Ed25519 public key bound to a security-key
// application string.
type skEd25519PublicKey struct {
// application is a URL-like string, typically "ssh:" for SSH.
// see openssh/PROTOCOL.u2f for details.
application string
// The embedded standard-library key holds the raw public key bytes.
ed25519.PublicKey
}
// Type returns the key algorithm name for this key, KeyAlgoSKED25519.
func (k *skEd25519PublicKey) Type() string {
return KeyAlgoSKED25519
}
// parseSKEd25519 decodes a security-key Ed25519 public key: the raw public
// key bytes followed by the application string. It returns the key and
// whatever bytes follow it in the input.
func parseSKEd25519(in []byte) (out PublicKey, rest []byte, err error) {
	var wire struct {
		KeyBytes    []byte
		Application string
		Rest        []byte `ssh:"rest"`
	}
	if err := Unmarshal(in, &wire); err != nil {
		return nil, nil, err
	}

	if size := len(wire.KeyBytes); size != ed25519.PublicKeySize {
		return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", size)
	}

	key := &skEd25519PublicKey{
		application: wire.Application,
		PublicKey:   ed25519.PublicKey(wire.KeyBytes),
	}
	return key, wire.Rest, nil
}
// Marshal returns the SSH wire encoding of the key: key type name, raw public
// key bytes and the application string.
func (k *skEd25519PublicKey) Marshal() []byte {
	wire := struct {
		Name        string
		KeyBytes    []byte
		Application string
	}{
		Name:        KeyAlgoSKED25519,
		KeyBytes:    []byte(k.PublicKey),
		Application: k.application,
	}
	return Marshal(&wire)
}
// Verify checks a security-key Ed25519 signature over data. The signed
// message is rebuilt as hash(application) || flags || counter || hash(data)
// and verified against the raw signature bytes in sig.Blob. Flags and counter
// are read from sig.Rest.
func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
	if sig.Format != k.Type() {
		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
	}
	if size := len(k.PublicKey); size != ed25519.PublicKeySize {
		return fmt.Errorf("invalid size %d for Ed25519 public key", size)
	}

	h := hashFuncs[sig.Format].New()
	// sum hashes b starting from a clean hash state.
	sum := func(b []byte) []byte {
		h.Reset()
		h.Write(b)
		return h.Sum(nil)
	}
	appDigest := sum([]byte(k.application))
	dataDigest := sum(data)

	var edSig struct {
		Signature []byte `ssh:"rest"`
	}
	if err := Unmarshal(sig.Blob, &edSig); err != nil {
		return err
	}

	var skf skFields
	if err := Unmarshal(sig.Rest, &skf); err != nil {
		return err
	}

	// The "rest" tags make both digests raw bytes without a length prefix.
	signed := Marshal(struct {
		ApplicationDigest []byte `ssh:"rest"`
		Flags             byte
		Counter           uint32
		MessageDigest     []byte `ssh:"rest"`
	}{
		appDigest,
		skf.Flags,
		skf.Counter,
		dataDigest,
	})

	if ed25519.Verify(k.PublicKey, signed, edSig.Signature) {
		return nil
	}
	return errors.New("ssh: signature did not verify")
}
// NewSignerFromKey wraps a private key in a Signer. It accepts any
// crypto.Signer (for example *rsa.PrivateKey or *ecdsa.PrivateKey) as well as
// *dsa.PrivateKey. ECDSA keys must use P-256, P-384 or P-521; DSA keys must
// use parameter size L1024N160. Any other type yields an error.
func NewSignerFromKey(key any) (Signer, error) {
	if signer, ok := key.(crypto.Signer); ok {
		return NewSignerFromSigner(signer)
	}
	if dsaKey, ok := key.(*dsa.PrivateKey); ok {
		return newDSAPrivateKey(dsaKey)
	}
	return nil, fmt.Errorf("ssh: unsupported key type %T", key)
}
// newDSAPrivateKey wraps key in a Signer after checkDSAParams has accepted
// its parameters.
func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) {
	err := checkDSAParams(&key.PublicKey.Parameters)
	if err != nil {
		return nil, err
	}
	signer := &dsaPrivateKey{key}
	return signer, nil
}
// wrappedSigner adapts a crypto.Signer to the package's Signer interface.
type wrappedSigner struct {
// signer performs the actual signing operation.
signer crypto.Signer
// pubKey is the SSH form of signer's public key, computed once by
// NewSignerFromSigner.
pubKey PublicKey
}
// NewSignerFromSigner adapts any crypto.Signer implementation to the Signer
// interface, which is useful for keys held in hardware modules. It fails if
// the signer's public key cannot be converted by NewPublicKey.
func NewSignerFromSigner(signer crypto.Signer) (Signer, error) {
	pub, err := NewPublicKey(signer.Public())
	if err != nil {
		return nil, err
	}
	return &wrappedSigner{signer: signer, pubKey: pub}, nil
}
// PublicKey returns the SSH public key stored alongside the wrapped signer.
func (s *wrappedSigner) PublicKey() PublicKey {
return s.pubKey
}
// Sign signs data using the algorithm named after the public key's type.
func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
	defaultAlgorithm := s.pubKey.Type()
	return s.SignWithAlgorithm(rand, data, defaultAlgorithm)
}
// Algorithms lists the signature algorithms available for the key's format,
// as reported by algorithmsForKeyFormat.
func (s *wrappedSigner) Algorithms() []string {
	keyFormat := s.pubKey.Type()
	return algorithmsForKeyFormat(keyFormat)
}
// SignWithAlgorithm signs data with the wrapped crypto.Signer using the given
// signature algorithm. An empty algorithm selects the one named after the
// public key's type. It returns an error if the algorithm is not offered by
// Algorithms, if the underlying signer fails, or if an ECDSA/DSA signature
// cannot be re-encoded into the SSH format.
func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	if algorithm == "" {
		algorithm = s.pubKey.Type()
	}
	if !contains(s.Algorithms(), algorithm) {
		return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
	}

	// Algorithms without an entry in hashFuncs yield the zero hash, in which
	// case the signer is handed the unhashed message.
	hashFunc := hashFuncs[algorithm]
	digest := data
	if hashFunc != 0 {
		h := hashFunc.New()
		h.Write(data)
		digest = h.Sum(nil)
	}

	signature, err := s.signer.Sign(rand, digest, hashFunc)
	if err != nil {
		return nil, err
	}

	// crypto.Signer.Sign is expected to return an ASN.1-encoded signature
	// for ECDSA and DSA, but that's not the encoding expected by SSH, so
	// re-encode.
	switch s.pubKey.(type) {
	case *ecdsaPublicKey, *dsaPublicKey:
		type asn1Signature struct {
			R, S *big.Int
		}
		asn1Sig := new(asn1Signature)
		if _, err := asn1.Unmarshal(signature, asn1Sig); err != nil {
			return nil, err
		}

		switch s.pubKey.(type) {
		case *ecdsaPublicKey:
			signature = Marshal(asn1Sig)

		case *dsaPublicKey:
			// The SSH DSA blob is r followed by s, each left-padded to a
			// fixed 20 bytes. A signer backed by a key with a larger
			// subgroup would produce longer integers; reject those instead
			// of panicking on an out-of-range slice expression.
			const intLen = 20
			rBytes := asn1Sig.R.Bytes()
			sBytes := asn1Sig.S.Bytes()
			if len(rBytes) > intLen || len(sBytes) > intLen {
				return nil, errors.New("ssh: DSA signature integers exceed 160 bits")
			}
			signature = make([]byte, 2*intLen)
			copy(signature[intLen-len(rBytes):intLen], rBytes)
			copy(signature[2*intLen-len(sBytes):], sBytes)
		}
	}

	return &Signature{
		Format: algorithm,
		Blob:   signature,
	}, nil
}
// NewPublicKey converts an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey
// or ed25519.PublicKey into the corresponding PublicKey. ECDSA keys must use
// P-256, P-384 or P-521, and Ed25519 keys must have the standard size. Any
// other type yields an error.
func NewPublicKey(key any) (PublicKey, error) {
	switch k := key.(type) {
	case *rsa.PublicKey:
		return (*rsaPublicKey)(k), nil
	case *dsa.PublicKey:
		return (*dsaPublicKey)(k), nil
	case *ecdsa.PublicKey:
		if supportedEllipticCurve(k.Curve) {
			return (*ecdsaPublicKey)(k), nil
		}
		return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
	case ed25519.PublicKey:
		if size := len(k); size != ed25519.PublicKeySize {
			return nil, fmt.Errorf("ssh: invalid size %d for Ed25519 public key", size)
		}
		return ed25519PublicKey(k), nil
	}
	return nil, fmt.Errorf("ssh: unsupported key type %T", key)
}
// ParsePrivateKey builds a Signer from a PEM encoded private key. It accepts
// the same key formats as ParseRawPrivateKey. For an encrypted key it returns
// a PassphraseMissingError.
func ParsePrivateKey(pemBytes []byte) (Signer, error) {
	raw, err := ParseRawPrivateKey(pemBytes)
	if err == nil {
		return NewSignerFromKey(raw)
	}
	return nil, err
}
// ParsePrivateKeyWithPassphrase builds a Signer from a PEM encoded private
// key that is protected by passphrase. It accepts the same key formats as
// ParseRawPrivateKeyWithPassphrase.
func ParsePrivateKeyWithPassphrase(pemBytes, passphrase []byte) (Signer, error) {
	raw, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase)
	if err == nil {
		return NewSignerFromKey(raw)
	}
	return nil, err
}
// encryptedBlock reports whether a PEM block is marked as encrypted, i.e.
// whether its Proc-Type header mentions ENCRYPTED (RFC 1421, section
// 4.6.1.1).
func encryptedBlock(block *pem.Block) bool {
	procType := block.Headers["Proc-Type"]
	return strings.Contains(procType, "ENCRYPTED")
}
// A PassphraseMissingError indicates that parsing this private key requires a
// passphrase. Use ParsePrivateKeyWithPassphrase.
type PassphraseMissingError struct {
// PublicKey will be set if the private key format includes an unencrypted
// public key along with the encrypted private key. In this file only the
// OpenSSH key parser fills it in; other paths leave it nil.
PublicKey PublicKey
}
// Error implements the error interface with a fixed message.
func (*PassphraseMissingError) Error() string {
return "ssh: this private key is passphrase protected"
}
// ParseRawPrivateKey decodes the first PEM block in pemBytes into a private
// key. It supports RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1,
// PKCS#8, OpenSSL, and OpenSSH formats. If the private key is encrypted, it
// returns a PassphraseMissingError.
func ParseRawPrivateKey(pemBytes []byte) (any, error) {
	block, _ := pem.Decode(pemBytes)
	switch {
	case block == nil:
		return nil, errors.New("ssh: no key found")
	case encryptedBlock(block):
		return nil, &PassphraseMissingError{}
	}

	der := block.Bytes
	switch keyType := block.Type; keyType {
	case "RSA PRIVATE KEY":
		return x509.ParsePKCS1PrivateKey(der)
	case "PRIVATE KEY":
		// RFC5208 - https://tools.ietf.org/html/rfc5208
		return x509.ParsePKCS8PrivateKey(der)
	case "EC PRIVATE KEY":
		return x509.ParseECPrivateKey(der)
	case "DSA PRIVATE KEY":
		return ParseDSAPrivateKey(der)
	case "OPENSSH PRIVATE KEY":
		return parseOpenSSHPrivateKey(der, unencryptedOpenSSHKey)
	default:
		return nil, fmt.Errorf("ssh: unsupported key type %q", keyType)
	}
}
// ParseRawPrivateKeyWithPassphrase decodes the first PEM block in pemBytes
// and decrypts it with passphrase. If the passphrase is wrong, it returns
// x509.IncorrectPasswordError. A legacy PEM block that is not encrypted is
// rejected with an error.
func ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase []byte) (any, error) {
	block, _ := pem.Decode(pemBytes)
	if block == nil {
		return nil, errors.New("ssh: no key found")
	}

	// OpenSSH keys carry their own encryption envelope.
	if block.Type == "OPENSSH PRIVATE KEY" {
		return parseOpenSSHPrivateKey(block.Bytes, passphraseProtectedOpenSSHKey(passphrase))
	}

	if !(encryptedBlock(block) && x509.IsEncryptedPEMBlock(block)) {
		return nil, errors.New("ssh: not an encrypted key")
	}

	der, err := x509.DecryptPEMBlock(block, passphrase)
	switch {
	case err == x509.IncorrectPasswordError:
		return nil, err
	case err != nil:
		return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
	}

	var key any
	switch block.Type {
	case "RSA PRIVATE KEY":
		key, err = x509.ParsePKCS1PrivateKey(der)
	case "EC PRIVATE KEY":
		key, err = x509.ParseECPrivateKey(der)
	case "DSA PRIVATE KEY":
		key, err = ParseDSAPrivateKey(der)
	default:
		err = fmt.Errorf("ssh: unsupported key type %q", block.Type)
	}

	// Because of deficiencies in the format, DecryptPEMBlock does not always
	// detect an incorrect password. In these cases the decrypted DER bytes
	// are random noise. If parsing the key returns an asn1.StructuralError
	// we report x509.IncorrectPasswordError instead.
	if _, structural := err.(asn1.StructuralError); structural {
		return nil, x509.IncorrectPasswordError
	}
	return key, err
}
// ParseDSAPrivateKey decodes a DSA private key from its ASN.1 DER encoding,
// as specified by the OpenSSL DSA man page. The input must contain exactly
// one key; trailing bytes are rejected.
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
	var raw struct {
		Version            int
		P, Q, G, Pub, Priv *big.Int
	}

	trailing, err := asn1.Unmarshal(der, &raw)
	switch {
	case err != nil:
		return nil, errors.New("ssh: failed to parse DSA key: " + err.Error())
	case len(trailing) > 0:
		return nil, errors.New("ssh: garbage after DSA key")
	}

	key := new(dsa.PrivateKey)
	key.P, key.Q, key.G = raw.P, raw.Q, raw.G
	key.Y = raw.Pub
	key.X = raw.Priv
	return key, nil
}
// unencryptedOpenSSHKey is the decrypt callback for OpenSSH keys that are
// expected to be stored in the clear. It returns privKeyBlock unchanged when
// both cipher and KDF are "none", a PassphraseMissingError when the key is
// protected, and an error when an unprotected key still carries KDF options.
func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) {
	unprotected := kdfName == "none" && cipherName == "none"
	switch {
	case !unprotected:
		return nil, &PassphraseMissingError{}
	case kdfOpts != "":
		return nil, errors.New("ssh: invalid openssh private key")
	}
	return privKeyBlock, nil
}
// passphraseProtectedOpenSSHKey returns a decrypt callback that unlocks an
// OpenSSH private key block with passphrase. Only the bcrypt KDF and the
// aes256-ctr / aes256-cbc ciphers are supported. The block is decrypted in
// place and returned.
func passphraseProtectedOpenSSHKey(passphrase []byte) openSSHDecryptFunc {
	return func(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) {
		switch {
		case kdfName == "none" || cipherName == "none":
			return nil, errors.New("ssh: key is not password protected")
		case kdfName != "bcrypt":
			return nil, fmt.Errorf("ssh: unknown KDF %q, only supports %q", kdfName, "bcrypt")
		}

		var kdf struct {
			Salt   string
			Rounds uint32
		}
		if err := Unmarshal([]byte(kdfOpts), &kdf); err != nil {
			return nil, err
		}

		// The KDF output is split into a 32-byte AES key and a 16-byte IV.
		const keyLen, ivLen = 32, 16
		derived, err := bcrypt_pbkdf.Key(passphrase, []byte(kdf.Salt), int(kdf.Rounds), keyLen+ivLen)
		if err != nil {
			return nil, err
		}
		block, err := aes.NewCipher(derived[:keyLen])
		if err != nil {
			return nil, err
		}
		iv := derived[keyLen:]

		switch cipherName {
		case "aes256-ctr":
			cipher.NewCTR(block, iv).XORKeyStream(privKeyBlock, privKeyBlock)
		case "aes256-cbc":
			if len(privKeyBlock)%block.BlockSize() != 0 {
				return nil, fmt.Errorf("ssh: invalid encrypted private key length, not a multiple of the block size")
			}
			cipher.NewCBCDecrypter(block, iv).CryptBlocks(privKeyBlock, privKeyBlock)
		default:
			return nil, fmt.Errorf("ssh: unknown cipher %q, only supports %q or %q", cipherName, "aes256-ctr", "aes256-cbc")
		}

		return privKeyBlock, nil
	}
}
// unencryptedOpenSSHMarshaler is the encrypt callback for storing a key in
// the clear: it pads the block to a multiple of 8 bytes and reports "none"
// for both cipher and KDF, with empty KDF options.
func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) {
	const blockSize = 8
	padded := generateOpenSSHPadding(privKeyBlock, blockSize)
	return padded, "none", "none", "", nil
}
// passphraseProtectedOpenSSHMarshaler returns an encrypt callback that
// protects a private key block with passphrase. It draws a random 16-byte
// salt, derives key and IV with bcrypt (16 rounds), pads the block to the AES
// block size and encrypts it with aes256-ctr.
func passphraseProtectedOpenSSHMarshaler(passphrase []byte) openSSHEncryptFunc {
	return func(privKeyBlock []byte) ([]byte, string, string, string, error) {
		salt := make([]byte, 16)
		if _, err := rand.Read(salt); err != nil {
			return nil, "", "", "", err
		}
		kdfOpts := struct {
			Salt   []byte
			Rounds uint32
		}{Salt: salt, Rounds: 16}

		// The derived secret is a 32-byte AES key followed by the IV.
		derived, err := bcrypt_pbkdf.Key(passphrase, salt, int(kdfOpts.Rounds), 32+aes.BlockSize)
		if err != nil {
			return nil, "", "", "", err
		}

		// Pad to the AES block size before encrypting.
		padded := generateOpenSSHPadding(privKeyBlock, aes.BlockSize)

		block, err := aes.NewCipher(derived[:32])
		if err != nil {
			return nil, "", "", "", err
		}
		encrypted := make([]byte, len(padded))
		cipher.NewCTR(block, derived[32:]).XORKeyStream(encrypted, padded)

		return encrypted, "aes256-ctr", "bcrypt", string(Marshal(kdfOpts)), nil
	}
}
// privateKeyAuthMagic is the NUL-terminated prefix that every OpenSSH private
// key blob must start with; parseOpenSSHPrivateKey rejects input without it.
const privateKeyAuthMagic = "openssh-key-v1\x00"
// openSSHDecryptFunc turns the (possibly encrypted) private key block of an
// OpenSSH key file into plaintext, given the cipher and KDF metadata stored
// next to it.
type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
// openSSHEncryptFunc pads and optionally encrypts a private key block and
// reports the cipher and KDF metadata to store next to it.
type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error)
// openSSHEncryptedPrivateKey is the outer container of an OpenSSH private key
// file, i.e. the data that follows privateKeyAuthMagic. The field order is
// the wire order and must not change.
type openSSHEncryptedPrivateKey struct {
// CipherName, KdfName and KdfOpts describe how PrivKeyBlock is protected;
// they are passed to the decrypt callback unchanged.
CipherName string
KdfName string
KdfOpts string
// NumKeys is the number of keys in the file; the parser only accepts 1.
NumKeys uint32
// PubKey is the public key in SSH wire format, stored unencrypted.
PubKey []byte
// PrivKeyBlock is the (possibly encrypted) private section.
PrivKeyBlock []byte
}
// openSSHPrivateKey is the header of the decrypted private section. The field
// order is the wire order and must not change.
type openSSHPrivateKey struct {
// Check1 and Check2 must be equal; the parser treats a mismatch as a wrong
// passphrase (encrypted keys) or a malformed key (unencrypted keys).
Check1 uint32
Check2 uint32
// Keytype selects how Rest is interpreted.
Keytype string
// Rest holds the key-type specific fields, comment and padding.
Rest []byte `ssh:"rest"`
}
// openSSHRSAPrivateKey is the RSA-specific part of the private section. The
// field order is the wire order and must not change.
type openSSHRSAPrivateKey struct {
// N and E form the public key.
N *big.Int
E *big.Int
// D is the private exponent.
D *big.Int
// Iqmp is written from the key's precomputed Qinv value when marshaling; the
// parser in this file does not read it back.
Iqmp *big.Int
// P and Q are the two primes.
P *big.Int
Q *big.Int
Comment string
// Pad is the trailing padding, validated by checkOpenSSHKeyPadding.
Pad []byte `ssh:"rest"`
}
// openSSHEd25519PrivateKey is the Ed25519-specific part of the private
// section. The field order is the wire order and must not change.
type openSSHEd25519PrivateKey struct {
// Pub is the public key.
Pub []byte
// Priv is the private key; the parser requires ed25519.PrivateKeySize bytes.
Priv []byte
Comment string
// Pad is the trailing padding, validated by checkOpenSSHKeyPadding.
Pad []byte `ssh:"rest"`
}
// openSSHECDSAPrivateKey is the ECDSA-specific part of the private section.
// The field order is the wire order and must not change.
type openSSHECDSAPrivateKey struct {
// Curve is the SSH curve identifier ("nistp256", "nistp384" or "nistp521").
Curve string
// Pub is the encoded public curve point.
Pub []byte
// D is the private scalar.
D *big.Int
Comment string
// Pad is the trailing padding, validated by checkOpenSSHKeyPadding.
Pad []byte `ssh:"rest"`
}
// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
// as the decrypt function to parse an unencrypted private key. See
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
//
// On success it returns an *rsa.PrivateKey, an *ed25519.PrivateKey (note: a
// pointer) or an *ecdsa.PrivateKey. If decrypt reports a
// *PassphraseMissingError, the embedded public key is parsed and attached to
// that error before it is returned.
func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
// The blob must start with the "openssh-key-v1\0" magic.
if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic {
return nil, errors.New("ssh: invalid openssh private key format")
}
remaining := key[len(privateKeyAuthMagic):]
var w openSSHEncryptedPrivateKey
if err := Unmarshal(remaining, &w); err != nil {
return nil, err
}
if w.NumKeys != 1 {
// We only support single key files, and so does OpenSSH.
// https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171
return nil, errors.New("ssh: multi-key files are not supported")
}
privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock)
if err != nil {
// A missing passphrase is reported together with the unencrypted public
// key so that callers can still identify the key.
if err, ok := err.(*PassphraseMissingError); ok {
pub, errPub := ParsePublicKey(w.PubKey)
if errPub != nil {
return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub)
}
err.PublicKey = pub
}
return nil, err
}
var pk1 openSSHPrivateKey
// The two check integers only match if decryption produced sensible data,
// so for an encrypted key a mismatch is reported as a wrong passphrase.
if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 {
if w.CipherName != "none" {
return nil, x509.IncorrectPasswordError
}
return nil, errors.New("ssh: malformed OpenSSH key")
}
switch pk1.Keytype {
case KeyAlgoRSA:
var key openSSHRSAPrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if err := checkOpenSSHKeyPadding(key.Pad); err != nil {
return nil, err
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N,
E: int(key.E.Int64()),
},
D: key.D,
Primes: []*big.Int{key.P, key.Q},
}
// Validate performs the standard library's consistency checks on the key
// before the CRT values are precomputed.
if err := pk.Validate(); err != nil {
return nil, err
}
pk.Precompute()
return pk, nil
case KeyAlgoED25519:
var key openSSHEd25519PrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if len(key.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
if err := checkOpenSSHKeyPadding(key.Pad); err != nil {
return nil, err
}
// Copy the key material out of the parsed buffer.
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, key.Priv)
return &pk, nil
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
var key openSSHECDSAPrivateKey
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if err := checkOpenSSHKeyPadding(key.Pad); err != nil {
return nil, err
}
var curve elliptic.Curve
switch key.Curve {
case "nistp256":
curve = elliptic.P256()
case "nistp384":
curve = elliptic.P384()
case "nistp521":
curve = elliptic.P521()
default:
return nil, errors.New("ssh: unhandled elliptic curve: " + key.Curve)
}
X, Y := elliptic.Unmarshal(curve, key.Pub)
if X == nil || Y == nil {
return nil, errors.New("ssh: failed to unmarshal public key")
}
// The private scalar must be smaller than the group order.
if key.D.Cmp(curve.Params().N) >= 0 {
return nil, errors.New("ssh: scalar is out of range")
}
// Recompute the public point from the scalar and compare it with the
// stored one to reject inconsistent key pairs.
x, y := curve.ScalarBaseMult(key.D.Bytes())
if x.Cmp(X) != 0 || y.Cmp(Y) != 0 {
return nil, errors.New("ssh: public key does not match private key")
}
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: X,
Y: Y,
},
D: key.D,
}, nil
default:
return nil, errors.New("ssh: unhandled key type")
}
}
// marshalOpenSSHPrivateKey encodes key, together with comment, as a PEM block
// of type "OPENSSH PRIVATE KEY". Supported key types are *rsa.PrivateKey,
// *ecdsa.PrivateKey (P-256, P-384, P-521), ed25519.PrivateKey and
// *ed25519.PrivateKey. The encrypt callback pads and optionally encrypts the
// private section and supplies the cipher/KDF metadata stored in the header.
//
// NOTE(review): the RSA branch reads k.Primes[0], k.Primes[1] and
// k.Precomputed.Qinv without checks; it presumably expects a two-prime key on
// which Precompute has been called. Confirm against callers.
func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) {
var w openSSHEncryptedPrivateKey
var pk1 openSSHPrivateKey
// Random check bytes.
// The same random value is stored twice; the parser compares the two copies
// after decryption.
var check uint32
if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil {
return nil, err
}
pk1.Check1 = check
pk1.Check2 = check
w.NumKeys = 1
// Use a []byte directly on ed25519 keys.
if k, ok := key.(*ed25519.PrivateKey); ok {
key = *k
}
switch k := key.(type) {
case *rsa.PrivateKey:
E := new(big.Int).SetInt64(int64(k.PublicKey.E))
// Marshal public key:
// E and N are in reversed order in the public and private key.
pubKey := struct {
KeyType string
E *big.Int
N *big.Int
}{
KeyAlgoRSA,
E, k.PublicKey.N,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHRSAPrivateKey{
N: k.PublicKey.N,
E: E,
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comment: comment,
}
pk1.Keytype = KeyAlgoRSA
pk1.Rest = Marshal(key)
case ed25519.PrivateKey:
pub := make([]byte, ed25519.PublicKeySize)
priv := make([]byte, ed25519.PrivateKeySize)
// The public key is taken from the bytes of k after the first 32.
copy(pub, k[32:])
copy(priv, k)
// Marshal public key.
pubKey := struct {
KeyType string
Pub []byte
}{
KeyAlgoED25519, pub,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHEd25519PrivateKey{
Pub: pub,
Priv: priv,
Comment: comment,
}
pk1.Keytype = KeyAlgoED25519
pk1.Rest = Marshal(key)
case *ecdsa.PrivateKey:
var curve, keyType string
// Translate the Go curve name into the SSH curve identifier and key type.
switch name := k.Curve.Params().Name; name {
case "P-256":
curve = "nistp256"
keyType = KeyAlgoECDSA256
case "P-384":
curve = "nistp384"
keyType = KeyAlgoECDSA384
case "P-521":
curve = "nistp521"
keyType = KeyAlgoECDSA521
default:
return nil, errors.New("ssh: unhandled elliptic curve " + name)
}
pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y)
// Marshal public key.
pubKey := struct {
KeyType string
Curve string
Pub []byte
}{
keyType, curve, pub,
}
w.PubKey = Marshal(pubKey)
// Marshal private key.
key := openSSHECDSAPrivateKey{
Curve: curve,
Pub: pub,
D: k.D,
Comment: comment,
}
pk1.Keytype = keyType
pk1.Rest = Marshal(key)
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
}
var err error
// Add padding and encrypt the key if necessary.
w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1))
if err != nil {
return nil, err
}
b := Marshal(w)
// The PEM payload is the magic prefix followed by the marshaled container.
block := &pem.Block{
Type: "OPENSSH PRIVATE KEY",
Bytes: append([]byte(privateKeyAuthMagic), b...),
}
return block, nil
}
// checkOpenSSHKeyPadding verifies that pad is the byte sequence 1, 2, 3, ...
// An empty pad is valid.
func checkOpenSSHKeyPadding(pad []byte) error {
	for idx := 0; idx < len(pad); idx++ {
		if want := idx + 1; int(pad[idx]) != want {
			return errors.New("ssh: padding not as expected")
		}
	}
	return nil
}
// generateOpenSSHPadding appends the bytes 1, 2, 3, ... to block until its
// length is a multiple of blockSize and returns the result. A block that is
// already aligned is returned unchanged.
func generateOpenSSHPadding(block []byte, blockSize int) []byte {
	next := byte(1)
	for len(block)%blockSize != 0 {
		block = append(block, next)
		next++
	}
	return block
}
// FingerprintLegacyMD5 returns the key's fingerprint in the legacy
// presentation described by RFC 4716 section 4: the MD5 digest of the
// marshaled key as colon-separated, two-digit lowercase hex bytes.
func FingerprintLegacyMD5(pubKey PublicKey) string {
	sum := md5.Sum(pubKey.Marshal())
	parts := make([]string, 0, len(sum))
	for _, b := range sum {
		parts = append(parts, hex.EncodeToString([]byte{b}))
	}
	return strings.Join(parts, ":")
}
// FingerprintSHA256 returns the key's fingerprint as "SHA256:" followed by
// the unpadded base64 encoding of the SHA-256 digest of the marshaled key.
// This format was introduced from OpenSSH 6.8.
// https://www.openssh.com/txt/release-6.8
// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding)
func FingerprintSHA256(pubKey PublicKey) string {
	digest := sha256.Sum256(pubKey.Marshal())
	return "SHA256:" + base64.RawStdEncoding.EncodeToString(digest[:])
}
================================================
FILE: gossh/crypto/ssh/knownhosts/knownhosts.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package knownhosts implements a parser for the OpenSSH known_hosts
// host key database, and provides utility functions for writing
// OpenSSH compliant known_hosts files.
package knownhosts
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"gossh/crypto/ssh"
)
// See the sshd manpage
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
// background.

// addr is a host and port pair, both kept as strings.
type addr struct{ host, port string }

// String renders the address as host:port, putting the host in square
// brackets when it contains a colon itself.
func (a *addr) String() string {
	if !strings.Contains(a.host, ":") {
		return a.host + ":" + a.port
	}
	return "[" + a.host + "]:" + a.port
}
// matcher decides whether a known_hosts entry applies to an address. In this
// file it is implemented by hostPatterns; hashed host entries supply another
// implementation via newHashedHost.
type matcher interface {
match(addr) bool
}
// hostPattern is a single, optionally negated, host pattern from a
// known_hosts line.
type hostPattern struct {
// negate is set for patterns that were prefixed with '!'.
negate bool
// addr holds the host pattern (which may contain '*' and '?') and the port.
addr addr
}
// String renders the pattern as its address, prefixed with '!' if negated.
func (p *hostPattern) String() string {
	if p.negate {
		return "!" + p.addr.String()
	}
	return p.addr.String()
}
// hostPatterns is the list of patterns from one known_hosts line.
type hostPatterns []hostPattern

// match reports whether a is matched by at least one pattern and by no
// negated pattern. A matching negated pattern rejects the address outright.
func (ps hostPatterns) match(a addr) bool {
	found := false
	for i := range ps {
		p := &ps[i]
		switch {
		case !p.match(a):
			// This pattern does not apply to a.
		case p.negate:
			return false
		default:
			found = true
		}
	}
	return found
}
// wildcardMatch reports whether str matches pat, where '*' in pat stands for
// a run of characters and '?' for exactly one character.
//
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
func wildcardMatch(pat []byte, str []byte) bool {
	for len(pat) > 0 {
		// Any remaining pattern, including a lone '*', needs input to consume.
		if len(str) == 0 {
			return false
		}

		if pat[0] == '*' {
			if len(pat) == 1 {
				return true
			}
			// Try to match the rest of the pattern at every suffix of str.
			tail := pat[1:]
			for j := 0; j < len(str); j++ {
				if wildcardMatch(tail, str[j:]) {
					return true
				}
			}
			return false
		}

		if pat[0] != '?' && pat[0] != str[0] {
			return false
		}
		pat, str = pat[1:], str[1:]
	}
	return len(str) == 0
}
// match reports whether a has the pattern's exact port and a host matching
// the pattern's wildcard expression. Negation is not applied here.
func (p *hostPattern) match(a addr) bool {
	hostMatches := wildcardMatch([]byte(p.addr.host), []byte(a.host))
	return hostMatches && p.addr.port == a.port
}
// keyDBLine is one parsed, non-revoked entry of a known_hosts file.
type keyDBLine struct {
// cert is true for lines carrying the @cert-authority marker.
cert bool
// matcher decides which addresses the entry applies to.
matcher matcher
// knownKey is the key together with the file position it was read from.
knownKey KnownKey
}
// serialize renders a key as its type name, a space, and the standard base64
// encoding of its marshaled form.
func serialize(k ssh.PublicKey) string {
	keyType := k.Type()
	encoded := base64.StdEncoding.EncodeToString(k.Marshal())
	return keyType + " " + encoded
}
// match reports whether the entry applies to address a by delegating to the
// entry's matcher.
func (l *keyDBLine) match(a addr) bool {
return l.matcher.match(a)
}
// hostKeyDB is the in-memory form of one or more known_hosts files.
type hostKeyDB struct {
// Serialized version of revoked keys
// The map key is the marshaled public key converted to a string.
revoked map[string]*KnownKey
// lines holds the non-revoked entries in the order they were parsed.
lines []keyDBLine
}
// newHostKeyDB returns an empty database with its revocation map initialised.
func newHostKeyDB() *hostKeyDB {
	return &hostKeyDB{revoked: map[string]*KnownKey{}}
}
// keyEq reports whether two public keys have identical marshaled forms.
func keyEq(a, b ssh.PublicKey) bool {
	left, right := a.Marshal(), b.Marshal()
	return bytes.Equal(left, right)
}
// IsHostAuthority can be used as a callback in ssh.CertChecker. It reports
// whether remote is listed as a @cert-authority key for address, which must
// be in host:port form; an unparsable address yields false.
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return false
	}
	target := addr{host: host, port: port}

	for i := range db.lines {
		line := &db.lines[i]
		if !line.cert {
			continue
		}
		if keyEq(line.knownKey.Key, remote) && line.match(target) {
			return true
		}
	}
	return false
}
// IsRevoked can be used as a callback in ssh.CertChecker. It reports whether
// the marshaled form of key appears among the @revoked entries.
func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
	serialized := string(key.Marshal())
	_, revoked := db.revoked[serialized]
	return revoked
}
// markerCert is the optional first word of a known_hosts line that marks the
// key as a certificate authority.
const markerCert = "@cert-authority"
// markerRevoked is the optional first word of a known_hosts line that marks
// the key as revoked.
const markerRevoked = "@revoked"
// nextWord splits line at the first space or tab. It returns the text before
// the separator and the remainder with surrounding whitespace trimmed. If
// line contains no separator, the whole line is the word and the remainder
// is nil.
func nextWord(line []byte) (string, []byte) {
	sep := bytes.IndexAny(line, "\t ")
	if sep < 0 {
		return string(line), nil
	}
	word, remainder := line[:sep], line[sep:]
	return string(word), bytes.TrimSpace(remainder)
}
// parseLine splits a known_hosts line into its optional marker
// (@cert-authority or @revoked), the host pattern field and the public key.
// The key type word is skipped because the key blob carries the type itself.
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
	if first, remainder := nextWord(line); first == markerCert || first == markerRevoked {
		marker, line = first, remainder
	}

	host, line = nextWord(line)
	if len(line) == 0 {
		return "", "", nil, errors.New("knownhosts: missing host pattern")
	}

	// ignore the keytype as it's in the key blob anyway.
	_, line = nextWord(line)
	if len(line) == 0 {
		return "", "", nil, errors.New("knownhosts: missing key type pattern")
	}

	encoded, _ := nextWord(line)
	blob, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", "", nil, err
	}

	if key, err = ssh.ParsePublicKey(blob); err != nil {
		return "", "", nil, err
	}
	return marker, host, key, nil
}
// parseLine parses one known_hosts line and records it in the database.
// Revoked keys go into the revocation map; every other entry is appended to
// db.lines with a hashed-host matcher (pattern starting with '|') or a
// hostname matcher. filename and linenum are stored with the key.
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
	marker, pattern, key, err := parseLine(line)
	if err != nil {
		return err
	}

	known := KnownKey{Key: key, Filename: filename, Line: linenum}
	if marker == markerRevoked {
		db.revoked[string(key.Marshal())] = &known
		return nil
	}

	var m matcher
	if pattern[0] == '|' {
		m, err = newHashedHost(pattern)
	} else {
		m, err = newHostnameMatcher(pattern)
	}
	if err != nil {
		return err
	}

	db.lines = append(db.lines, keyDBLine{
		cert:     marker == markerCert,
		matcher:  m,
		knownKey: known,
	})
	return nil
}
// newHostnameMatcher builds a matcher from a comma-separated list of host
// patterns. Empty list items are skipped, a leading '!' negates a pattern,
// and a pattern without an explicit port defaults to port 22. A bracketed
// pattern must be a valid [host]:port; otherwise an error is returned.
func newHostnameMatcher(pattern string) (matcher, error) {
	var hps hostPatterns
	for _, p := range strings.Split(pattern, ",") {
		if len(p) == 0 {
			continue
		}

		negate := p[0] == '!'
		if negate {
			p = p[1:]
		}
		if len(p) == 0 {
			return nil, errors.New("knownhosts: negation without following hostname")
		}

		host, port, err := net.SplitHostPort(p)
		if err != nil {
			if p[0] == '[' {
				return nil, err
			}
			// No port given: the whole item is the host.
			host, port = p, "22"
		}

		hps = append(hps, hostPattern{
			negate: negate,
			addr:   addr{host: host, port: port},
		})
	}
	return hps, nil
}
// KnownKey represents a key declared in a known_hosts file.
type KnownKey struct {
// Key is the parsed public key.
Key ssh.PublicKey
// Filename is the name of the file the key was read from.
Filename string
// Line is the 1-based line number of the declaration within Filename.
Line int
}
// String renders the key as "filename:line: type base64-key".
func (k *KnownKey) String() string {
	key := serialize(k.Key)
	return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, key)
}
// KeyError is returned if we did not find the key in the host key
// database, or there was a mismatch. Typically, in batch
// applications, this should be interpreted as failure. Interactive
// applications can offer an interactive prompt to the user.
type KeyError struct {
// Want holds the accepted host keys. For each key algorithm,
// there can be one hostkey. If Want is empty, the host is
// unknown. If Want is non-empty, there was a mismatch, which
// can signify a MITM attack. The order of the entries is not
// defined, because they are collected from a map.
Want []KnownKey
}
// Error distinguishes a key mismatch (Want is non-empty) from an unknown
// host (Want is empty).
func (u *KeyError) Error() string {
	if len(u.Want) > 0 {
		return "knownhosts: key mismatch"
	}
	return "knownhosts: key is unknown"
}
// RevokedError is returned if we found a key that was revoked.
type RevokedError struct {
// Revoked is the @revoked entry that matched the presented key.
Revoked KnownKey
}
// Error implements the error interface with a fixed message.
func (r *RevokedError) Error() string {
return "knownhosts: key is revoked"
}
// check validates remoteKey against the host database. It should not be used
// for verifying certificates. A revoked key yields a RevokedError. Otherwise
// the key is looked up under address when that is non-empty, and under the
// remote network address when it is not.
func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
	if entry := db.revoked[string(remoteKey.Marshal())]; entry != nil {
		return &RevokedError{Revoked: *entry}
	}

	// The remote address must always be well-formed, even when it ends up
	// unused.
	host, port, err := net.SplitHostPort(remote.String())
	if err != nil {
		return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
	}

	if address != "" {
		// Give preference to the hostname if available.
		host, port, err = net.SplitHostPort(address)
		if err != nil {
			return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
		}
	}

	return db.checkAddr(addr{host, port}, remoteKey)
}
// checkAddr checks if we can find the given public key for the given
// address. For each key type only the first matching line counts. It returns
// a *KeyError with an empty Want when no line matches the address, and a
// *KeyError listing the known keys when the presented key's type is unknown
// or its value differs.
func (db *hostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error {
	// TODO(hanwen): are these the right semantics? What if there
	// is just a key for the IP address, but not for the
	// hostname?

	// Algorithm => first known key of that algorithm.
	known := make(map[string]KnownKey)
	for i := range db.lines {
		line := &db.lines[i]
		if !line.match(a) {
			continue
		}
		algo := line.knownKey.Key.Type()
		if _, seen := known[algo]; !seen {
			known[algo] = line.knownKey
		}
	}

	keyErr := &KeyError{}
	for _, k := range known {
		keyErr.Want = append(keyErr.Want, k)
	}

	// Unknown remote host.
	if len(known) == 0 {
		return keyErr
	}

	// If the remote host starts using a different, unknown key type, we
	// also interpret that as a mismatch.
	expected, found := known[remoteKey.Type()]
	if found && keyEq(expected.Key, remoteKey) {
		return nil
	}
	return keyErr
}
// Read parses known_hosts content from r line by line and adds the entries to
// the database. Blank lines and lines starting with '#' are skipped. filename
// is only used for error messages and key positions. Parsing stops at the
// first invalid line.
func (db *hostKeyDB) Read(r io.Reader, filename string) error {
	scanner := bufio.NewScanner(r)
	for lineNum := 1; scanner.Scan(); lineNum++ {
		line := bytes.TrimSpace(scanner.Bytes())
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		if err := db.parseLine(line, filename, lineNum); err != nil {
			return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
		}
	}
	return scanner.Err()
}
// New creates a host key callback from the given OpenSSH host key
// files. The returned callback is for use in
// ssh.ClientConfig.HostKeyCallback. By preference, the key check
// operates on the hostname if available, i.e. if a server changes its
// IP address, the host key check will still succeed, even though a
// record of the new IP address is not available.
//
// It returns an error if any file cannot be opened or contains an invalid
// line. Each file is closed as soon as it has been read.
func New(files ...string) (ssh.HostKeyCallback, error) {
	db := newHostKeyDB()

	// load reads a single file. Running it as a function makes the deferred
	// Close fire per file instead of piling up until New returns.
	load := func(fn string) error {
		f, err := os.Open(fn)
		if err != nil {
			return err
		}
		defer f.Close()
		return db.Read(f, fn)
	}

	for _, fn := range files {
		if err := load(fn); err != nil {
			return nil, err
		}
	}

	var certChecker ssh.CertChecker
	certChecker.IsHostAuthority = db.IsHostAuthority
	certChecker.IsRevoked = db.IsRevoked
	certChecker.HostKeyFallback = db.check

	return certChecker.CheckHostKey, nil
}
// Normalize converts an address into the form used in known_hosts: the bare
// host for port 22, and [host]:port for any other port. An address without a
// parsable port is treated as a host on port 22. A port-22 host that contains
// a colon and is not already bracketed gets square brackets.
func Normalize(address string) string {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		host, port = address, "22"
	}

	switch {
	case port != "22":
		return "[" + host + "]:" + port
	case strings.Contains(host, ":") && !strings.HasPrefix(host, "["):
		return "[" + host + "]"
	default:
		return host
	}
}
// Line returns a line to append to a known_hosts file. Every address is
// passed through Normalize, the results are joined with commas, and the
// serialized key follows after a single space.
func Line(addresses []string, key ssh.PublicKey) string {
	hosts := make([]string, len(addresses))
	for i, a := range addresses {
		hosts[i] = Normalize(a)
	}
	return strings.Join(hosts, ",") + " " + serialize(key)
}
// HashHostname hashes the given hostname. The hostname is not
// normalized before hashing.
func HashHostname(hostname string) string {
// TODO(hanwen): check if we can safely normalize this always.
salt := make([]byte, sha1.Size)
_, err := rand.Read(salt)
if err != nil {
panic(fmt.Sprintf("crypto/rand failure %v", err))
}
hash := hashHost(hostname, salt)
return encodeHash(sha1HashType, salt, hash)
}
// decodeHash splits a hashed host entry of the form "|type|salt|hash"
// into its parts. salt and hash are base64 (standard encoding) decoded.
// An error is returned when the entry does not start with '|', when it
// does not split into exactly four '|'-separated pieces (the first being
// empty), or when either base64 field is invalid.
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
	if !strings.HasPrefix(encoded, "|") {
		return "", nil, nil, errors.New("knownhosts: hashed host must start with '|'")
	}
	parts := strings.Split(encoded, "|")
	if n := len(parts); n != 4 {
		return "", nil, nil, fmt.Errorf("knownhosts: got %d components, want 3", n)
	}
	hashType = parts[1]
	salt, err = base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return hashType, salt, nil, err
	}
	hash, err = base64.StdEncoding.DecodeString(parts[3])
	return hashType, salt, hash, err
}
// encodeHash renders a hashed host entry as "|typ|salt|hash", with salt
// and hash base64 encoded using the standard alphabet.
func encodeHash(typ string, salt []byte, hash []byte) string {
	enc := base64.StdEncoding
	return "|" + typ + "|" + enc.EncodeToString(salt) + "|" + enc.EncodeToString(hash)
}
// hashHost returns the HMAC-SHA1 of hostname keyed with salt.
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
func hashHost(hostname string, salt []byte) []byte {
	keyed := hmac.New(sha1.New, salt)
	keyed.Write([]byte(hostname))
	digest := keyed.Sum(nil)
	return digest
}
// hashedHost is the decoded form of a hashed known_hosts host entry. It is
// built by newHashedHost and compared against addresses by its match method.
type hashedHost struct {
// salt is the HMAC key passed to hashHost when matching.
salt []byte
// hash is the expected hashHost output for a matching host.
hash []byte
}
// sha1HashType is the type field written by HashHostname and the only value
// newHashedHost accepts.
const sha1HashType = "1"
// newHashedHost decodes a "|type|salt|hash" host entry into a hashedHost.
// It fails when the entry cannot be decoded or when its type field is not
// sha1HashType.
func newHashedHost(encoded string) (*hashedHost, error) {
	typ, salt, hash, err := decodeHash(encoded)
	switch {
	case err != nil:
		return nil, err
	case typ != sha1HashType:
		// The type field seems for future algorithm agility, but it's
		// actually hardcoded in openssh currently, see
		// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
		return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
	}
	return &hashedHost{salt: salt, hash: hash}, nil
}
// match reports whether hashing the normalized string form of a with this
// entry's salt reproduces the stored hash.
func (h *hashedHost) match(a addr) bool {
	candidate := hashHost(Normalize(a.String()), h.salt)
	return bytes.Equal(candidate, h.hash)
}
================================================
FILE: gossh/crypto/ssh/mac.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Message authentication support
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"hash"
)
// macMode describes one MAC algorithm entry of the macModes table.
type macMode struct {
// keySize is the key length associated with the algorithm
// (presumably in bytes; the table uses 20, 32 and 64).
keySize int
// etm is set for the "-etm@openssh.com" table entries.
etm bool
// new builds a hash.Hash keyed with key.
new func(key []byte) hash.Hash
}
// truncatingMAC wraps a hash.Hash and shortens the digest produced by Sum
// to length bytes. All other hash.Hash methods are forwarded to the
// wrapped hash, except Size, which reports the truncated length.
type truncatingMAC struct {
	length int
	hmac   hash.Hash
}

// Write feeds data to the wrapped hash.
func (m truncatingMAC) Write(data []byte) (int, error) {
	n, err := m.hmac.Write(data)
	return n, err
}

// Sum appends the wrapped hash's digest to in and returns the result cut
// off after the first length digest bytes.
func (m truncatingMAC) Sum(in []byte) []byte {
	full := m.hmac.Sum(in)
	end := len(in) + m.length
	return full[:end]
}

// Reset resets the wrapped hash.
func (m truncatingMAC) Reset() {
	m.hmac.Reset()
}

// Size returns the truncated digest length.
func (m truncatingMAC) Size() int {
	return m.length
}

// BlockSize returns the block size of the wrapped hash.
func (m truncatingMAC) BlockSize() int {
	return m.hmac.BlockSize()
}
// macModes maps MAC algorithm names to their key size, encrypt-then-MAC
// flag and constructor. "hmac-sha1-96" wraps HMAC-SHA1 in a truncatingMAC
// that keeps the first 12 digest bytes.
var macModes = map[string]*macMode{
	"hmac-sha2-512-etm@openssh.com": {
		keySize: 64,
		etm:     true,
		new:     func(key []byte) hash.Hash { return hmac.New(sha512.New, key) },
	},
	"hmac-sha2-256-etm@openssh.com": {
		keySize: 32,
		etm:     true,
		new:     func(key []byte) hash.Hash { return hmac.New(sha256.New, key) },
	},
	"hmac-sha2-512": {
		keySize: 64,
		etm:     false,
		new:     func(key []byte) hash.Hash { return hmac.New(sha512.New, key) },
	},
	"hmac-sha2-256": {
		keySize: 32,
		etm:     false,
		new:     func(key []byte) hash.Hash { return hmac.New(sha256.New, key) },
	},
	"hmac-sha1": {
		keySize: 20,
		etm:     false,
		new:     func(key []byte) hash.Hash { return hmac.New(sha1.New, key) },
	},
	"hmac-sha1-96": {
		keySize: 20,
		etm:     false,
		new: func(key []byte) hash.Hash {
			return truncatingMAC{length: 12, hmac: hmac.New(sha1.New, key)}
		},
	},
}
================================================
FILE: gossh/crypto/ssh/messages.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"reflect"
"strconv"
"strings"
)
// These are SSH message type numbers. They are scattered around several
// documents but many were taken from [SSH-PARAMETERS].
//
// Only the numbers without a dedicated message struct are grouped here; the
// remaining ones are declared below, next to the struct that mirrors each
// message.
const (
msgIgnore = 2
msgUnimplemented = 3
msgDebug = 4
msgNewKeys = 21
)
// SSH messages:
//
// These structures mirror the wire format of the corresponding SSH messages.
// They are marshaled using reflection with the marshal and unmarshal functions
// in this file. The only wrinkle is that a final member of type []byte with a
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
//
// The "sshtype" tag on the first field of each struct carries the message
// type byte that Marshal prepends and Unmarshal expects.

// See RFC 4253, section 11.1.
const msgDisconnect = 1
// disconnectMsg is the message that signals a disconnect. It is also
// the error type returned from mux.Wait()
type disconnectMsg struct {
Reason uint32 `sshtype:"1"`
Message string
Language string
}
// Error implements the error interface; the text contains the numeric
// reason and the message, but not the language tag.
func (d *disconnectMsg) Error() string {
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
}
// See RFC 4253, section 7.1.
const msgKexInit = 20
// kexInitMsg mirrors the message with type byte 20. Each []string field is
// written by Marshal as one length-prefixed, comma-separated name list and
// read back by parseNameList.
type kexInitMsg struct {
Cookie [16]byte `sshtype:"20"`
KexAlgos []string
ServerHostKeyAlgos []string
CiphersClientServer []string
CiphersServerClient []string
MACsClientServer []string
MACsServerClient []string
CompressionClientServer []string
CompressionServerClient []string
LanguagesClientServer []string
LanguagesServerClient []string
FirstKexFollows bool
Reserved uint32
}
// See RFC 4253, section 8.
// Diffie-Hellman
//
// Several key exchange messages below share a type number (30 and 31 are
// each used by more than one struct). decode in this file maps 30 to
// kexDHInitMsg and 31 to kexDHReplyMsg; the other variants have to be
// unmarshaled explicitly by code that knows which exchange is in progress.
const msgKexDHInit = 30
// kexDHInitMsg mirrors the message with type byte 30 carrying one big integer.
type kexDHInitMsg struct {
X *big.Int `sshtype:"30"`
}
const msgKexECDHInit = 30
// kexECDHInitMsg mirrors the message with type byte 30 carrying a byte string.
type kexECDHInitMsg struct {
ClientPubKey []byte `sshtype:"30"`
}
const msgKexECDHReply = 31
// kexECDHReplyMsg mirrors the message with type byte 31 made of three byte
// strings.
type kexECDHReplyMsg struct {
HostKey []byte `sshtype:"31"`
EphemeralPubKey []byte
Signature []byte
}
const msgKexDHReply = 31
// kexDHReplyMsg mirrors the message with type byte 31 made of a byte string,
// a big integer and a byte string.
type kexDHReplyMsg struct {
HostKey []byte `sshtype:"31"`
Y *big.Int
Signature []byte
}
// See RFC 4419, section 5.
const msgKexDHGexGroup = 31
// kexDHGexGroupMsg mirrors the message with type byte 31 carrying two big
// integers.
type kexDHGexGroupMsg struct {
P *big.Int `sshtype:"31"`
G *big.Int
}
const msgKexDHGexInit = 32
// kexDHGexInitMsg mirrors the message with type byte 32.
type kexDHGexInitMsg struct {
X *big.Int `sshtype:"32"`
}
const msgKexDHGexReply = 33
// kexDHGexReplyMsg mirrors the message with type byte 33.
type kexDHGexReplyMsg struct {
HostKey []byte `sshtype:"33"`
Y *big.Int
Signature []byte
}
const msgKexDHGexRequest = 34
// kexDHGexRequestMsg mirrors the message with type byte 34 made of three
// uint32 values.
type kexDHGexRequestMsg struct {
MinBits uint32 `sshtype:"34"`
// PreferedBits keeps its existing spelling; renaming the field would
// change the identifier other code refers to.
PreferedBits uint32
MaxBits uint32
}
// See RFC 4253, section 10.
const msgServiceRequest = 5
// serviceRequestMsg mirrors the message with type byte 5: a single string.
type serviceRequestMsg struct {
Service string `sshtype:"5"`
}
// See RFC 4253, section 10.
const msgServiceAccept = 6
// serviceAcceptMsg mirrors the message with type byte 6: a single string.
type serviceAcceptMsg struct {
Service string `sshtype:"6"`
}
// See RFC 8308, section 2.3
const msgExtInfo = 7
// extInfoMsg mirrors the message with type byte 7. Payload receives all
// bytes after NumExtensions without further parsing.
type extInfoMsg struct {
NumExtensions uint32 `sshtype:"7"`
Payload []byte `ssh:"rest"`
}
// See RFC 4252, section 5.
const msgUserAuthRequest = 50
// userAuthRequestMsg mirrors the message with type byte 50. Payload receives
// the unparsed bytes that follow the three strings.
type userAuthRequestMsg struct {
User string `sshtype:"50"`
Service string
Method string
Payload []byte `ssh:"rest"`
}
// Used for debug printouts of packets.
// The struct has no fields and therefore no "sshtype" tag; decode returns it
// directly for msgUserAuthSuccess without calling Unmarshal.
type userAuthSuccessMsg struct {
}
// See RFC 4252, section 5.1
const msgUserAuthFailure = 51
// userAuthFailureMsg mirrors the message with type byte 51: a name list
// followed by a boolean.
type userAuthFailureMsg struct {
Methods []string `sshtype:"51"`
PartialSuccess bool
}
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
// userAuthBannerMsg mirrors the message with type byte 53.
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61
// userAuthInfoRequestMsg mirrors the message with type byte 60. Prompts
// receives the raw bytes after NumPrompts; they are not decoded here.
type userAuthInfoRequestMsg struct {
Name string `sshtype:"60"`
Instruction string
Language string
NumPrompts uint32
Prompts []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpen = 90
// channelOpenMsg mirrors the message with type byte 90. TypeSpecificData
// receives the unparsed bytes after MaxPacketSize.
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersID uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
const msgChannelExtendedData = 95
const msgChannelData = 94
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91
// channelOpenConfirmMsg mirrors the message with type byte 91.
type channelOpenConfirmMsg struct {
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenFailure = 92
// channelOpenFailureMsg mirrors the message with type byte 92.
type channelOpenFailureMsg struct {
PeersID uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
}
const msgChannelRequest = 98
// channelRequestMsg mirrors the message with type byte 98.
// RequestSpecificData receives the unparsed bytes after WantReply.
type channelRequestMsg struct {
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
}
// See RFC 4254, section 5.4.
const msgChannelSuccess = 99
// channelRequestSuccessMsg mirrors the message with type byte 99.
type channelRequestSuccessMsg struct {
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
// channelRequestFailureMsg mirrors the message with type byte 100.
type channelRequestFailureMsg struct {
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
// channelCloseMsg mirrors the message with type byte 97.
type channelCloseMsg struct {
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
// channelEOFMsg mirrors the message with type byte 96.
type channelEOFMsg struct {
PeersID uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
const msgGlobalRequest = 80
// globalRequestMsg mirrors the message with type byte 80. Data receives the
// unparsed bytes after WantReply.
type globalRequestMsg struct {
Type string `sshtype:"80"`
WantReply bool
Data []byte `ssh:"rest"`
}
// See RFC 4254, section 4
const msgRequestSuccess = 81
// globalRequestSuccessMsg mirrors the message with type byte 81; the whole
// body after the type byte ends up in Data.
type globalRequestSuccessMsg struct {
Data []byte `ssh:"rest" sshtype:"81"`
}
// See RFC 4254, section 4
const msgRequestFailure = 82
// globalRequestFailureMsg mirrors the message with type byte 82; the whole
// body after the type byte ends up in Data.
type globalRequestFailureMsg struct {
Data []byte `ssh:"rest" sshtype:"82"`
}
// See RFC 4254, section 5.2
const msgChannelWindowAdjust = 93
// windowAdjustMsg mirrors the message with type byte 93: two uint32 values.
type windowAdjustMsg struct {
PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32
}
// See RFC 4252, section 7
//
// Type numbers 60 and 61 are shared by several authentication messages in
// this file. decode maps 60 to userAuthPubKeyOkMsg and 61 to
// userAuthGSSAPIToken.
const msgUserAuthPubKeyOk = 60
// userAuthPubKeyOkMsg mirrors the message with type byte 60 made of a string
// and a byte string.
type userAuthPubKeyOkMsg struct {
Algo string `sshtype:"60"`
PubKey []byte
}
// See RFC 4462, section 3
const msgUserAuthGSSAPIResponse = 60
// userAuthGSSAPIResponse mirrors the message with type byte 60 carrying one
// byte string.
type userAuthGSSAPIResponse struct {
SupportMech []byte `sshtype:"60"`
}
const msgUserAuthGSSAPIToken = 61
// userAuthGSSAPIToken mirrors the message with type byte 61.
type userAuthGSSAPIToken struct {
Token []byte `sshtype:"61"`
}
const msgUserAuthGSSAPIMIC = 66
// userAuthGSSAPIMIC mirrors the message with type byte 66.
type userAuthGSSAPIMIC struct {
MIC []byte `sshtype:"66"`
}
// See RFC 4462, section 3.9
const msgUserAuthGSSAPIErrTok = 64
// userAuthGSSAPIErrTok mirrors the message with type byte 64.
type userAuthGSSAPIErrTok struct {
ErrorToken []byte `sshtype:"64"`
}
// See RFC 4462, section 3.8
const msgUserAuthGSSAPIError = 65
// userAuthGSSAPIError mirrors the message with type byte 65.
type userAuthGSSAPIError struct {
MajorStatus uint32 `sshtype:"65"`
MinorStatus uint32
Message string
LanguageTag string
}
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
const msgPing = 192
// pingMsg mirrors the message with type byte 192. It has the same single
// string field as pongMsg, which lets mux.onePacket answer a ping by
// converting it with pongMsg(msg).
type pingMsg struct {
Data string `sshtype:"192"`
}
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
const msgPong = 193
// pongMsg mirrors the message with type byte 193.
type pongMsg struct {
Data string `sshtype:"193"`
}
// typeTags returns the possible type bytes for the given reflect.Type, which
// should be a struct. The values are read from the "sshtype" tag of the
// struct's first field, where several values are separated by a '|'
// character. Entries that are not decimal numbers are ignored.
//
// A struct without fields (such as userAuthSuccessMsg) has no first field
// to inspect; it yields no tags instead of panicking in Field(0).
func typeTags(structType reflect.Type) (tags []byte) {
	if structType.NumField() == 0 {
		return nil
	}
	tagStr := structType.Field(0).Tag.Get("sshtype")
	for _, tag := range strings.Split(tagStr, "|") {
		n, err := strconv.Atoi(tag)
		if err != nil {
			// Missing tag (empty string) or malformed entry.
			continue
		}
		tags = append(tags, byte(n))
	}
	return tags
}
// fieldError builds the unmarshal error for field number field of struct
// type t. A non-empty problem is appended after ": ".
func fieldError(t reflect.Type, field int, problem string) error {
	suffix := ""
	if problem != "" {
		suffix = ": " + problem
	}
	return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), suffix)
}
// errShortRead is returned by Unmarshal when the input ends before a field
// could be read completely.
var errShortRead = errors.New("ssh: short read")
// Unmarshal parses data in SSH wire format into a structure. The out
// argument should be a pointer to struct. If the first member of the
// struct has the "sshtype" tag set to a '|'-separated set of numbers
// in decimal, the packet must start with one of those numbers.
//
// Fields are decoded in declaration order. A []byte field tagged
// `ssh:"rest"` receives all remaining input. []byte fields are set to
// sub-slices of data rather than copies, so they share memory with the
// caller's buffer.
//
// Errors: the result of parseError is returned for empty input and for
// bytes left over after the last field; a formatted error is returned when
// the leading type byte is not one of the expected values; errShortRead is
// returned when the input ends early; and a fieldError is returned for
// string fields that cannot be parsed and for unsupported field types.
func Unmarshal(data []byte, out any) error {
v := reflect.ValueOf(out).Elem()
structType := v.Type()
expectedTypes := typeTags(structType)
// expectedType is only used for error reporting; it stays 0 when the
// struct carries no type tag.
var expectedType byte
if len(expectedTypes) > 0 {
expectedType = expectedTypes[0]
}
if len(data) == 0 {
return parseError(expectedType)
}
if len(expectedTypes) > 0 {
goodType := false
for _, e := range expectedTypes {
// A tag value of 0 never matches.
if e > 0 && data[0] == e {
goodType = true
break
}
}
if !goodType {
return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
}
// Consume the type byte.
data = data[1:]
}
var ok bool
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
t := field.Type()
switch t.Kind() {
case reflect.Bool:
// One byte; any non-zero value is true.
if len(data) < 1 {
return errShortRead
}
field.SetBool(data[0] != 0)
data = data[1:]
case reflect.Array:
// Fixed-size byte arrays are copied without a length prefix.
if t.Elem().Kind() != reflect.Uint8 {
return fieldError(structType, i, "array of unsupported type")
}
if len(data) < t.Len() {
return errShortRead
}
for j, n := 0, t.Len(); j < n; j++ {
field.Index(j).Set(reflect.ValueOf(data[j]))
}
data = data[t.Len():]
case reflect.Uint64:
var u64 uint64
if u64, data, ok = parseUint64(data); !ok {
return errShortRead
}
field.SetUint(u64)
case reflect.Uint32:
var u32 uint32
if u32, data, ok = parseUint32(data); !ok {
return errShortRead
}
field.SetUint(uint64(u32))
case reflect.Uint8:
if len(data) < 1 {
return errShortRead
}
field.SetUint(uint64(data[0]))
data = data[1:]
case reflect.String:
// Length-prefixed string; unlike the other kinds a failure here is
// reported as a fieldError.
var s []byte
if s, data, ok = parseString(data); !ok {
return fieldError(structType, i, "")
}
field.SetString(string(s))
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
if structType.Field(i).Tag.Get("ssh") == "rest" {
// Take everything that is left, without a length prefix.
field.Set(reflect.ValueOf(data))
data = nil
} else {
var s []byte
if s, data, ok = parseString(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(s))
}
case reflect.String:
// []string is a comma-separated name list.
var nl []string
if nl, data, ok = parseNameList(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(nl))
default:
return fieldError(structType, i, "slice of unsupported type")
}
case reflect.Ptr:
// *big.Int is the only supported pointer type.
if t == bigIntType {
var n *big.Int
if n, data, ok = parseInt(data); !ok {
return errShortRead
}
field.Set(reflect.ValueOf(n))
} else {
return fieldError(structType, i, "pointer to unsupported type")
}
default:
return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
}
}
// Trailing bytes that no field consumed make the packet invalid.
if len(data) != 0 {
return parseError(expectedType)
}
return nil
}
// Marshal serializes the message in msg to SSH wire format. The msg
// argument should be a struct or pointer to struct. If the first
// member has the "sshtype" tag set to a number in decimal, that
// number is prepended to the result. If the last member has the
// "ssh" tag set to "rest", its contents are appended to the output.
//
// The output buffer starts with a capacity of 64 bytes and grows as
// needed.
func Marshal(msg any) []byte {
	return marshalStruct(make([]byte, 0, 64), msg)
}
// marshalStruct appends the SSH wire encoding of msg (a struct or pointer
// to struct) to out and returns the extended slice. The first type tag of
// the struct, if any, is written first, followed by every field in
// declaration order. It panics on arrays, slices or pointers of unsupported
// element types; field kinds not listed in the switch are skipped silently.
func marshalStruct(out []byte, msg any) []byte {
v := reflect.Indirect(reflect.ValueOf(msg))
msgTypes := typeTags(v.Type())
if len(msgTypes) > 0 {
// When several type numbers are tagged, the first one is emitted.
out = append(out, msgTypes[0])
}
for i, n := 0, v.NumField(); i < n; i++ {
field := v.Field(i)
switch t := field.Type(); t.Kind() {
case reflect.Bool:
// Booleans are a single byte, 0 or 1. (This v shadows the struct value
// only inside this case.)
var v uint8
if field.Bool() {
v = 1
}
out = append(out, v)
case reflect.Array:
// Byte arrays are written raw, without a length prefix.
if t.Elem().Kind() != reflect.Uint8 {
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
}
for j, l := 0, t.Len(); j < l; j++ {
out = append(out, uint8(field.Index(j).Uint()))
}
case reflect.Uint32:
out = appendU32(out, uint32(field.Uint()))
case reflect.Uint64:
out = appendU64(out, uint64(field.Uint()))
case reflect.Uint8:
out = append(out, uint8(field.Uint()))
case reflect.String:
// Length prefix followed by the string bytes.
s := field.String()
out = appendInt(out, len(s))
out = append(out, s...)
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
// A "rest" field is appended as-is; any other []byte gets a length prefix.
if v.Type().Field(i).Tag.Get("ssh") != "rest" {
out = appendInt(out, field.Len())
}
out = append(out, field.Bytes()...)
case reflect.String:
// Name list: reserve four bytes for the length, write the
// comma-separated names, then patch the length in afterwards.
offset := len(out)
out = appendU32(out, 0)
if n := field.Len(); n > 0 {
for j := 0; j < n; j++ {
f := field.Index(j)
if j != 0 {
out = append(out, ',')
}
out = append(out, f.String()...)
}
// overwrite length value
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
}
default:
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
}
case reflect.Ptr:
if t == bigIntType {
// Copy the field into a typed *big.Int via reflection.
var n *big.Int
nValue := reflect.ValueOf(&n)
nValue.Elem().Set(field)
needed := intLength(n)
oldLength := len(out)
// Grow the buffer manually so marshalInt can write in place.
if cap(out)-len(out) < needed {
newOut := make([]byte, len(out), 2*(len(out)+needed))
copy(newOut, out)
out = newOut
}
out = out[:oldLength+needed]
marshalInt(out[oldLength:], n)
} else {
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
}
}
}
return out
}
// bigOne is the constant 1, shared by the two's-complement conversions in
// parseInt, intLength and marshalInt. It must not be modified.
var bigOne = big.NewInt(1)
// parseString reads a string with a 4-byte big-endian length prefix from
// the front of in. On success out is the string body, rest is everything
// after it (both are sub-slices of in) and ok is true. If in is too short
// for the prefix or for the announced length, all results are zero.
func parseString(in []byte) (out, rest []byte, ok bool) {
	const prefixLen = 4
	if len(in) < prefixLen {
		return nil, nil, false
	}
	n := binary.BigEndian.Uint32(in)
	body := in[prefixLen:]
	if uint32(len(body)) < n {
		return nil, nil, false
	}
	return body[:n], body[n:], true
}
var (
// comma is the separator parseNameList splits on.
comma = []byte{','}
// emptyNameList is the shared, non-nil result parseNameList returns for
// a zero-length name list.
emptyNameList = []string{}
)
// parseNameList reads a length-prefixed, comma-separated list of names
// from the front of in. rest is the input after the list. An empty list
// yields the shared emptyNameList value. ok is false when the underlying
// string cannot be read.
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
	var contents []byte
	if contents, rest, ok = parseString(in); !ok {
		return nil, rest, false
	}
	if len(contents) == 0 {
		return emptyNameList, rest, true
	}
	names := bytes.Split(contents, comma)
	out = make([]string, 0, len(names))
	for _, name := range names {
		out = append(out, string(name))
	}
	return out, rest, true
}
// parseInt reads a length-prefixed big-endian two's-complement integer
// from the front of in. A zero-length body decodes to 0. rest is the input
// after the integer; ok is false when the underlying string cannot be
// read.
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
	var contents []byte
	if contents, rest, ok = parseString(in); !ok {
		return nil, rest, false
	}
	negative := len(contents) > 0 && contents[0]&0x80 != 0
	if !negative {
		return new(big.Int).SetBytes(contents), rest, true
	}
	// Undo two's complement: invert every byte, add one, then flip the
	// sign of the resulting magnitude.
	inverted := make([]byte, len(contents))
	for i, b := range contents {
		inverted[i] = ^b
	}
	out = new(big.Int).SetBytes(inverted)
	out.Add(out, bigOne)
	out.Neg(out)
	return out, rest, true
}
// parseUint32 reads a big-endian uint32 from the front of in and returns
// it with the remaining bytes. The last result is false (and the others
// zero) when in holds fewer than four bytes.
func parseUint32(in []byte) (uint32, []byte, bool) {
	const size = 4
	if len(in) >= size {
		return binary.BigEndian.Uint32(in), in[size:], true
	}
	return 0, nil, false
}
// parseUint64 reads a big-endian uint64 from the front of in and returns
// it with the remaining bytes. The last result is false (and the others
// zero) when in holds fewer than eight bytes.
func parseUint64(in []byte) (uint64, []byte, bool) {
	const size = 8
	if len(in) >= size {
		return binary.BigEndian.Uint64(in), in[size:], true
	}
	return 0, nil, false
}
// intLength returns the number of bytes marshalInt writes for n: four
// length bytes plus the two's-complement body. Zero has an empty body.
// When the relevant bit length is a multiple of eight, one extra padding
// byte (0x00 for positive, 0xff for negative values) is accounted for.
func intLength(n *big.Int) int {
	const lengthPrefix = 4
	var magnitude *big.Int
	switch n.Sign() {
	case 0:
		// A zero is the zero length string.
		return lengthPrefix
	case -1:
		// For negative numbers the body size follows from -n - 1.
		magnitude = new(big.Int).Neg(n)
		magnitude.Sub(magnitude, bigOne)
	default:
		magnitude = n
	}
	bits := magnitude.BitLen()
	size := lengthPrefix + (bits+7)/8
	if bits%8 == 0 {
		// The number will need a padding byte.
		size++
	}
	return size
}
// marshalUint32 writes n in big-endian order into the first four bytes of
// to and returns the slice that follows them.
func marshalUint32(to []byte, n uint32) []byte {
	const size = 4
	binary.BigEndian.PutUint32(to, n)
	return to[size:]
}
// marshalUint64 writes n in big-endian order into the first eight bytes of
// to and returns the slice that follows them.
func marshalUint64(to []byte, n uint64) []byte {
	const size = 8
	binary.BigEndian.PutUint64(to, n)
	return to[size:]
}
// marshalInt writes n into to as a 4-byte big-endian length followed by
// the big-endian two's-complement body, and returns the slice after the
// written bytes. to must have room for intLength(n) bytes. Zero is
// encoded with an empty body.
func marshalInt(to []byte, n *big.Int) []byte {
	header := to
	to = to[4:]

	var (
		body    []byte
		needPad bool
		padByte byte
	)
	switch n.Sign() {
	case -1:
		// A negative number has to be converted to two's-complement
		// form: subtract 1 from the magnitude and invert. If the most
		// significant bit is then clear, a leading 0xff keeps the
		// number negative.
		v := new(big.Int).Neg(n)
		v.Sub(v, bigOne)
		body = v.Bytes()
		for i := range body {
			body[i] ^= 0xff
		}
		if len(body) == 0 || body[0]&0x80 == 0 {
			needPad, padByte = true, 0xff
		}
	case 1:
		body = n.Bytes()
		if len(body) > 0 && body[0]&0x80 != 0 {
			// A leading 0x00 stops the value from looking negative.
			needPad, padByte = true, 0x00
		}
	}

	length := 0
	if needPad {
		to[0] = padByte
		to = to[1:]
		length++
	}
	written := copy(to, body)
	to = to[written:]
	length += written

	binary.BigEndian.PutUint32(header, uint32(length))
	return to
}
// writeInt writes the wire encoding of n (see marshalInt) to w. The
// result of the Write call is not checked.
func writeInt(w io.Writer, n *big.Int) {
	encoded := make([]byte, intLength(n))
	marshalInt(encoded, n)
	w.Write(encoded)
}
// writeString writes s to w preceded by its length as a 4-byte big-endian
// integer, using two separate Write calls. The results of the Write calls
// are not checked.
func writeString(w io.Writer, s []byte) {
	var prefix [4]byte
	binary.BigEndian.PutUint32(prefix[:], uint32(len(s)))
	w.Write(prefix[:])
	w.Write(s)
}
// stringLength returns the encoded size of a string with n body bytes:
// the body plus the 4-byte length prefix.
func stringLength(n int) int {
	const lengthPrefix = 4
	return n + lengthPrefix
}
// marshalString writes s into to as a 4-byte big-endian length followed by
// the bytes of s, and returns the slice after the written bytes. to must
// have room for stringLength(len(s)) bytes.
func marshalString(to []byte, s []byte) []byte {
	binary.BigEndian.PutUint32(to, uint32(len(s)))
	payload := to[4:]
	copy(payload, s)
	return payload[len(s):]
}
// bigIntType is the reflect.Type of *big.Int, the only pointer field type
// Marshal and Unmarshal support.
var bigIntType = reflect.TypeOf((*big.Int)(nil))
// Decode a packet into its corresponding message.
//
// The first byte of packet selects the message struct, which is then filled
// in by Unmarshal and returned as a pointer. packet must not be empty: its
// first byte is read without a length check. For type numbers shared by
// several structs only the one listed in the switch is produced. Message
// types without a case yield the result of unexpectedMessageError.
func decode(packet []byte) (any, error) {
var msg any
switch packet[0] {
case msgDisconnect:
msg = new(disconnectMsg)
case msgServiceRequest:
msg = new(serviceRequestMsg)
case msgServiceAccept:
msg = new(serviceAcceptMsg)
case msgExtInfo:
msg = new(extInfoMsg)
case msgKexInit:
msg = new(kexInitMsg)
case msgKexDHInit:
msg = new(kexDHInitMsg)
case msgKexDHReply:
msg = new(kexDHReplyMsg)
case msgUserAuthRequest:
msg = new(userAuthRequestMsg)
case msgUserAuthSuccess:
// userAuthSuccessMsg has no fields, so there is nothing to unmarshal.
return new(userAuthSuccessMsg), nil
case msgUserAuthFailure:
msg = new(userAuthFailureMsg)
case msgUserAuthPubKeyOk:
msg = new(userAuthPubKeyOkMsg)
case msgGlobalRequest:
msg = new(globalRequestMsg)
case msgRequestSuccess:
msg = new(globalRequestSuccessMsg)
case msgRequestFailure:
msg = new(globalRequestFailureMsg)
case msgChannelOpen:
msg = new(channelOpenMsg)
case msgChannelData:
msg = new(channelDataMsg)
case msgChannelOpenConfirm:
msg = new(channelOpenConfirmMsg)
case msgChannelOpenFailure:
msg = new(channelOpenFailureMsg)
case msgChannelWindowAdjust:
msg = new(windowAdjustMsg)
case msgChannelEOF:
msg = new(channelEOFMsg)
case msgChannelClose:
msg = new(channelCloseMsg)
case msgChannelRequest:
msg = new(channelRequestMsg)
case msgChannelSuccess:
msg = new(channelRequestSuccessMsg)
case msgChannelFailure:
msg = new(channelRequestFailureMsg)
case msgUserAuthGSSAPIToken:
msg = new(userAuthGSSAPIToken)
case msgUserAuthGSSAPIMIC:
msg = new(userAuthGSSAPIMIC)
case msgUserAuthGSSAPIErrTok:
msg = new(userAuthGSSAPIErrTok)
case msgUserAuthGSSAPIError:
msg = new(userAuthGSSAPIError)
default:
return nil, unexpectedMessageError(0, packet[0])
}
if err := Unmarshal(packet, msg); err != nil {
return nil, err
}
return msg, nil
}
// packetTypeNames maps message type bytes to the name of the struct that
// mirrors the message. It has no entries for the GSSAPI messages that
// decode handles.
var packetTypeNames = map[byte]string{
msgDisconnect: "disconnectMsg",
msgServiceRequest: "serviceRequestMsg",
msgServiceAccept: "serviceAcceptMsg",
msgExtInfo: "extInfoMsg",
msgKexInit: "kexInitMsg",
msgKexDHInit: "kexDHInitMsg",
msgKexDHReply: "kexDHReplyMsg",
msgUserAuthRequest: "userAuthRequestMsg",
msgUserAuthSuccess: "userAuthSuccessMsg",
msgUserAuthFailure: "userAuthFailureMsg",
msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg",
msgGlobalRequest: "globalRequestMsg",
msgRequestSuccess: "globalRequestSuccessMsg",
msgRequestFailure: "globalRequestFailureMsg",
msgChannelOpen: "channelOpenMsg",
msgChannelData: "channelDataMsg",
msgChannelOpenConfirm: "channelOpenConfirmMsg",
msgChannelOpenFailure: "channelOpenFailureMsg",
msgChannelWindowAdjust: "windowAdjustMsg",
msgChannelEOF: "channelEOFMsg",
msgChannelClose: "channelCloseMsg",
msgChannelRequest: "channelRequestMsg",
msgChannelSuccess: "channelRequestSuccessMsg",
msgChannelFailure: "channelRequestFailureMsg",
}
================================================
FILE: gossh/crypto/ssh/mux.go
================================================
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/binary"
"fmt"
"io"
"log"
"sync"
"sync/atomic"
)
// debugMux, if set, causes messages in the connection protocol to be
// logged. It is a compile-time constant; when true, newMux additionally
// gives every mux a distinct channel ID offset.
const debugMux = false
// chanList is a thread safe channel list. Slots freed by remove are set to
// nil and reused by add.
type chanList struct {
// protects concurrent access to chans
sync.Mutex
// chans are indexed by the local id of the channel, which the
// other side should send in the PeersId field.
chans []*channel
// This is a debugging aid: it offsets all IDs by this
// amount. This helps distinguish otherwise identical
// server/client muxes
offset uint32
}
// add stores ch in the first free slot of the list, appending a new slot
// when none is free, and returns the channel ID: the slot index plus the
// list's offset.
func (c *chanList) add(ch *channel) uint32 {
	c.Lock()
	defer c.Unlock()

	slot := -1
	for i, existing := range c.chans {
		if existing == nil {
			slot = i
			break
		}
	}
	if slot < 0 {
		slot = len(c.chans)
		c.chans = append(c.chans, ch)
	} else {
		c.chans[slot] = ch
	}
	return uint32(slot) + c.offset
}
// getChan returns the channel for the given ID, or nil when the ID (after
// subtracting the list's offset) is outside the list or its slot is empty.
func (c *chanList) getChan(id uint32) *channel {
	idx := id - c.offset

	c.Lock()
	defer c.Unlock()
	if idx >= uint32(len(c.chans)) {
		return nil
	}
	return c.chans[idx]
}
// remove clears the slot of the channel with the given ID so that add can
// reuse it. IDs outside the list are ignored.
func (c *chanList) remove(id uint32) {
	idx := id - c.offset

	c.Lock()
	defer c.Unlock()
	if idx < uint32(len(c.chans)) {
		c.chans[idx] = nil
	}
}
// dropAll empties the list and returns the channels it held, skipping
// empty slots. The result is nil when the list held no channels.
func (c *chanList) dropAll() []*channel {
	c.Lock()
	defer c.Unlock()

	var live []*channel
	for _, ch := range c.chans {
		if ch != nil {
			live = append(live, ch)
		}
	}
	c.chans = nil
	return live
}
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
type mux struct {
// conn is the transport packets are read from and written to.
conn packetConn
// chanList holds the channels known to this mux, indexed by local ID.
chanList chanList
// incomingChannels receives channels opened by the peer
// (see handleChannelOpen); it is closed when loop exits.
incomingChannels chan NewChannel
// globalSentMu is held by SendRequest from sending a request that wants
// a reply until that reply has been received.
globalSentMu sync.Mutex
// globalResponses carries global request success/failure messages from
// handleGlobalPacket to SendRequest; it is closed when loop exits.
globalResponses chan any
// incomingRequests receives global requests sent by the peer; it is
// closed when loop exits.
incomingRequests chan *Request
// errCond guards err and is broadcast once loop has stored the error
// that ended it.
errCond *sync.Cond
err error
}
// When debugging, each new chanList instantiation has a different
// offset. globalOff is the counter newMux increments atomically to hand
// out those offsets; it is only touched when debugMux is true.
var globalOff uint32
// Wait blocks until loop has terminated and stored its error, then
// returns that error.
func (m *mux) Wait() error {
	lock := m.errCond.L
	lock.Lock()
	defer lock.Unlock()

	for m.err == nil {
		m.errCond.Wait()
	}
	return m.err
}
// newMux returns a mux that runs over the given connection. It starts the
// packet processing loop in a new goroutine before returning; use Wait to
// learn when (and why) that loop has ended.
func newMux(p packetConn) *mux {
	m := new(mux)
	m.conn = p
	m.incomingChannels = make(chan NewChannel, chanSize)
	// One slot: SendRequest allows a single outstanding request that
	// expects a reply.
	m.globalResponses = make(chan any, 1)
	m.incomingRequests = make(chan *Request, chanSize)
	m.errCond = newCond()

	if debugMux {
		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
	}

	go m.loop()
	return m
}
// sendMessage marshals msg and writes it to the connection as one packet.
// With debugMux enabled the message is logged first.
func (m *mux) sendMessage(msg any) error {
	packet := Marshal(msg)
	if debugMux {
		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
	}
	err := m.conn.writePacket(packet)
	return err
}
// SendRequest sends a global request with the given name and payload.
// Without wantReply it returns (false, nil, nil) as soon as the request
// has been written. With wantReply it waits for the peer's answer and
// returns whether the request succeeded together with the answer's data;
// io.EOF is returned if the response channel is closed first. Requests
// that want a reply are serialized through globalSentMu.
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
	if wantReply {
		m.globalSentMu.Lock()
		defer m.globalSentMu.Unlock()
	}

	req := globalRequestMsg{
		Type:      name,
		WantReply: wantReply,
		Data:      payload,
	}
	if err := m.sendMessage(req); err != nil {
		return false, nil, err
	}
	if !wantReply {
		return false, nil, nil
	}

	reply, open := <-m.globalResponses
	if !open {
		return false, nil, io.EOF
	}
	if failure, isFailure := reply.(*globalRequestFailureMsg); isFailure {
		return false, failure.Data, nil
	}
	if success, isSuccess := reply.(*globalRequestSuccessMsg); isSuccess {
		return true, success.Data, nil
	}
	return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", reply)
}
// ackRequest must be called after processing a global request that
// has WantReply set. It sends a success message when ok is true and a
// failure message otherwise, attaching data in both cases.
func (m *mux) ackRequest(ok bool, data []byte) error {
	if !ok {
		return m.sendMessage(globalRequestFailureMsg{Data: data})
	}
	return m.sendMessage(globalRequestSuccessMsg{Data: data})
}
// Close closes the underlying connection and returns the error from
// doing so.
func (m *mux) Close() error {
return m.conn.Close()
}
// loop runs the connection machine. It will process packets until an
// error is encountered. To synchronize on loop exit, use mux.Wait.
//
// On exit it closes every known channel, the three mux channels and the
// connection, and only then publishes the error to waiters.
func (m *mux) loop() {
	var err error
	for {
		if err = m.onePacket(); err != nil {
			break
		}
	}

	for _, ch := range m.chanList.dropAll() {
		ch.close()
	}
	close(m.incomingChannels)
	close(m.incomingRequests)
	close(m.globalResponses)
	m.conn.Close()

	cond := m.errCond
	cond.L.Lock()
	m.err = err
	cond.Broadcast()
	cond.L.Unlock()

	if debugMux {
		log.Println("loop exit", err)
	}
}
// onePacket reads and processes one packet. Channel-open, global request
// and ping packets are handled by the mux itself; every other packet is
// treated as a channel packet and routed by the channel ID that follows
// the type byte.
func (m *mux) onePacket() error {
	packet, err := m.conn.readPacket()
	if err != nil {
		return err
	}
	msgType := packet[0]

	if debugMux {
		switch msgType {
		case msgChannelData, msgChannelExtendedData:
			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
		default:
			p, _ := decode(packet)
			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, msgType, p, len(packet))
		}
	}

	switch msgType {
	case msgChannelOpen:
		return m.handleChannelOpen(packet)
	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
		return m.handleGlobalPacket(packet)
	case msgPing:
		var ping pingMsg
		if err := Unmarshal(packet, &ping); err != nil {
			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
		}
		// The pong echoes the ping's data.
		return m.sendMessage(pongMsg(ping))
	}

	// assume a channel packet: type byte plus a 4-byte channel ID.
	if len(packet) < 5 {
		return parseError(msgType)
	}
	id := binary.BigEndian.Uint32(packet[1:])
	if ch := m.chanList.getChan(id); ch != nil {
		return ch.handlePacket(packet)
	}
	return m.handleUnknownChannelPacket(id, packet)
}
// handleGlobalPacket decodes a global request, success or failure packet.
// Requests are queued on incomingRequests; success and failure messages
// are passed to globalResponses. It panics if decode yields any other
// message type.
func (m *mux) handleGlobalPacket(packet []byte) error {
	decoded, err := decode(packet)
	if err != nil {
		return err
	}

	switch msg := decoded.(type) {
	case *globalRequestMsg:
		req := &Request{
			Type:      msg.Type,
			WantReply: msg.WantReply,
			Payload:   msg.Data,
			mux:       m,
		}
		m.incomingRequests <- req
	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
		m.globalResponses <- msg
	default:
		panic(fmt.Sprintf("not a global message %#v", msg))
	}
	return nil
}
// handleChannelOpen schedules a channel to be Accept()ed. A request whose
// maximum packet size is below minPacketLength or above 1<<31 is answered
// with a channel open failure instead.
func (m *mux) handleChannelOpen(packet []byte) error {
	var open channelOpenMsg
	if err := Unmarshal(packet, &open); err != nil {
		return err
	}

	sizeOK := open.MaxPacketSize >= minPacketLength && open.MaxPacketSize <= 1<<31
	if !sizeOK {
		return m.sendMessage(channelOpenFailureMsg{
			PeersID:  open.PeersID,
			Reason:   ConnectionFailed,
			Message:  "invalid request",
			Language: "en_US.UTF-8",
		})
	}

	ch := m.newChannel(open.ChanType, channelInbound, open.TypeSpecificData)
	ch.remoteId = open.PeersID
	ch.maxRemotePayload = open.MaxPacketSize
	ch.remoteWin.add(open.PeersWindow)
	m.incomingChannels <- ch
	return nil
}
// OpenChannel opens a channel of the given type, passing extra as
// type-specific data. On success it returns the channel together with
// the channel's stream of incoming requests.
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
	ch, err := m.openChannel(chanType, extra)
	if err == nil {
		return ch, ch.incomingRequests, nil
	}
	return nil, nil, err
}
// openChannel creates an outbound channel, sends the channel open message
// and waits for the peer's answer on the channel's message queue. A
// confirmation yields the channel; a failure message is converted into an
// *OpenChannelError.
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
	ch := m.newChannel(chanType, channelOutbound, extra)
	ch.maxIncomingPayload = channelMaxPacket

	request := channelOpenMsg{
		ChanType:         chanType,
		PeersWindow:      ch.myWindow,
		MaxPacketSize:    ch.maxIncomingPayload,
		TypeSpecificData: extra,
		PeersID:          ch.localId,
	}
	if err := m.sendMessage(request); err != nil {
		return nil, err
	}

	reply := <-ch.msg
	switch msg := reply.(type) {
	case *channelOpenConfirmMsg:
		return ch, nil
	case *channelOpenFailureMsg:
		return nil, &OpenChannelError{msg.Reason, msg.Message}
	default:
		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
	}
}
// handleUnknownChannelPacket deals with a packet addressed to a channel ID
// the mux does not know. Channel requests are tolerated: RFC 4254 section
// 5.4 says unrecognized channel requests should receive a failure
// response, which is sent when the request wants a reply. Any other
// message is an error.
func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
	decoded, err := decode(packet)
	if err != nil {
		return err
	}

	req, isRequest := decoded.(*channelRequestMsg)
	if !isRequest {
		return fmt.Errorf("ssh: invalid channel %d", id)
	}
	if !req.WantReply {
		return nil
	}
	return m.sendMessage(channelRequestFailureMsg{
		PeersID: req.PeersID,
	})
}
================================================
FILE: gossh/crypto/ssh/server.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"strings"
)
// The Permissions type holds fine-grained permissions that are
// specific to a user or a specific authentication method for a user.
// The Permissions value for a successful authentication attempt is
// available in ServerConn, so it can be used to pass information from
// the user-authentication phase to the application layer.
//
// Both fields are plain string-to-string maps.
type Permissions struct {
// CriticalOptions indicate restrictions to the default
// permissions, and are typically used in conjunction with
// user certificates. The standard for SSH certificates
// defines "force-command" (only allow the given command to
// execute) and "source-address" (only allow connections from
// the given address). The SSH package currently only enforces
// the "source-address" critical option. It is up to server
// implementations to enforce other critical options, such as
// "force-command", by checking them after the SSH handshake
// is successful. In general, SSH servers should reject
// connections that specify critical options that are unknown
// or not supported.
CriticalOptions map[string]string
// Extensions are extra functionality that the server may
// offer on authenticated connections. Lack of support for an
// extension does not preclude authenticating a user. Common
// extensions are "permit-agent-forwarding",
// "permit-X11-forwarding". The Go SSH library currently does
// not act on any extension, and it is up to server
// implementations to honor them. Extensions can be used to
// pass data from the authentication callbacks to the server
// application layer.
Extensions map[string]string
}
// GSSAPIWithMICConfig bundles the GSS-API server implementation and the
// login-authorization callback used for gssapi-with-mic authentication
// (RFC 4462 section 3).
type GSSAPIWithMICConfig struct {
// AllowLogin must be set. It is called when gssapi-with-mic
// authentication is selected (RFC 4462 section 3). The srcName comes from
// the results of the GSS-API authentication; its format is username@DOMAIN.
// GSSAPI only guarantees to the server who the user is, not whether they
// may log in or with what permissions. This callback is called after the
// user identity is established with GSSAPI to decide whether the user can
// log in and with which permissions. If the user is allowed to log in, it
// should return a nil error.
AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error)
// Server must be set. It is the implementation of the GSSAPIServer
// interface. See the GSSAPIServer interface for details.
Server GSSAPIServer
}
// ServerConfig holds server specific configuration data.
type ServerConfig struct {
// Config contains configuration shared between client and server.
Config
// PublicKeyAuthAlgorithms specifies the supported client public key
// authentication algorithms. Note that this should not include certificate
// types since those use the underlying algorithm. This list is sent to the
// client if it supports the server-sig-algs extension. Order is irrelevant.
// If unspecified then a default set of algorithms is used.
PublicKeyAuthAlgorithms []string
// hostKeys holds the server's host keys; entries are added or replaced
// through AddHostKey.
hostKeys []Signer
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
// To determine NoClientAuth at runtime, set NoClientAuth to true
// and the optional NoClientAuthCallback to a non-nil value.
NoClientAuth bool
// NoClientAuthCallback, if non-nil, is called when a user
// attempts to authenticate with auth method "none".
// NoClientAuth must also be set to true for this to be used, or
// this func is unused.
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts is unlimited. If set to zero, the number of attempts is limited
// to 6.
MaxAuthTries int
// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback, if non-nil, is called when a client
// offers a public key for authentication. It must return a nil error
// if the given public key can be used to authenticate the
// given user. For example, see CertChecker.Authenticate. A
// call to this function does not guarantee that the key
// offered is in fact used to authenticate. To record any data
// depending on the public key, store it inside a
// Permissions.Extensions entry.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
// 4256). The client object's Challenge function should be
// used to query the user. The callback may offer multiple
// Challenge rounds. To avoid information leaks, the client
// should be presented a challenge even if the user is
// unknown.
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
// AuthLogCallback, if non-nil, is called to log all authentication
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)
// ServerVersion is the version identification string to announce in
// the public handshake.
// If empty, a reasonable default is used.
// Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-".
ServerVersion string
// BannerCallback, if present, is called and the returned string is sent to
// the client after key exchange has completed but before authentication.
BannerCallback func(conn ConnMetadata) string
// GSSAPIWithMICConfig includes the gssapi server and callback, which, if
// both are non-nil, are used when gssapi-with-mic authentication is
// selected (RFC 4462 section 3).
GSSAPIWithMICConfig *GSSAPIWithMICConfig
}
// AddHostKey registers key as a host key of the server. A previously
// registered key whose public key has the same type is overwritten in place;
// otherwise key is appended. Each server config must have at least one host
// key.
func (s *ServerConfig) AddHostKey(key Signer) {
	for idx := 0; idx < len(s.hostKeys); idx++ {
		if s.hostKeys[idx].PublicKey().Type() != key.PublicKey().Type() {
			continue
		}
		s.hostKeys[idx] = key
		return
	}
	s.hostKeys = append(s.hostKeys, key)
}
// cachedPubKey contains the results of querying whether a public key is
// acceptable for a user.
type cachedPubKey struct {
	user       string
	pubKeyData []byte
	result     error
	perms      *Permissions
}

// maxCachedPubKeys is the number of cache entries we store.
//
// Applications frequently (mis)use PublicKeyCallback by remembering the last
// key it was called with and treating that key as the one that
// authenticated the connection. With a multi-entry cache a client can query
// key A, query key B and then sign with A without the callback being
// invoked again for A, so that assumption breaks (CVE-2024-45337). Keeping a
// single entry guarantees that the most recent callback invocation was for
// the key that is actually used to authenticate.
const maxCachedPubKeys = 1

// pubKeyCache caches tests for public keys. Since SSH clients
// will query whether a public key is acceptable before attempting to
// authenticate with it, we end up with duplicate queries for public
// key validity. The cache only applies to a single ServerConn and is
// managed as a FIFO of at most maxCachedPubKeys entries.
type pubKeyCache struct {
	keys []cachedPubKey
}

// get returns the cached result for the given user/key pair. The boolean
// reports whether an entry was found.
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
	for _, k := range c.keys {
		if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
			return k, true
		}
	}
	return cachedPubKey{}, false
}

// add stores candidate in the cache, evicting the oldest entry first when
// the cache is already full.
func (c *pubKeyCache) add(candidate cachedPubKey) {
	if len(c.keys) >= maxCachedPubKeys {
		c.keys = c.keys[1:]
	}
	c.keys = append(c.keys, candidate)
}
// ServerConn is an authenticated SSH connection, as seen from the
// server.
type ServerConn struct {
// Conn is the underlying SSH connection; its methods are promoted onto
// ServerConn through embedding.
Conn
// If the succeeding authentication callback returned a
// non-nil Permissions pointer, it is stored here.
Permissions *Permissions
}
// NewServerConn starts a new SSH server with c as the underlying
// transport. It starts with a handshake and, if the handshake is
// unsuccessful, it closes the connection and returns an error. The
// Request and NewChannel channels must be serviced, or the connection
// will hang.
//
// The returned error may be of type *ServerAuthError for
// authentication errors.
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
	// Defaults are applied to a copy of the struct, not to *config itself.
	fullConf := *config
	fullConf.SetDefaults()
	if fullConf.MaxAuthTries == 0 {
		fullConf.MaxAuthTries = 6
	}

	// abort closes the transport and hands err back to the caller.
	abort := func(err error) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
		c.Close()
		return nil, nil, nil, err
	}

	if len(fullConf.PublicKeyAuthAlgorithms) > 0 {
		for _, algo := range fullConf.PublicKeyAuthAlgorithms {
			if contains(supportedPubKeyAuthAlgos, algo) {
				continue
			}
			return abort(fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo))
		}
	} else {
		fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
	}

	// Refuse key exchanges that a server is not allowed to offer.
	for _, kex := range fullConf.KeyExchanges {
		if _, forbidden := serverForbiddenKexAlgos[kex]; forbidden {
			return abort(fmt.Errorf("ssh: unsupported key exchange %s for server", kex))
		}
	}

	conn := &connection{
		sshConn: sshConn{conn: c},
	}
	perms, err := conn.serverHandshake(&fullConf)
	if err != nil {
		return abort(err)
	}
	return &ServerConn{conn, perms}, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
}
// signAndMarshal signs data with k, using the underlying algorithm of algo,
// and returns the signature serialized in SSH wire format. algo is the
// negotiated algorithm and may be a certificate type.
func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) {
	signature, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
	if err == nil {
		return Marshal(signature), nil
	}
	return nil, err
}
// serverHandshake performs version exchange, key exchange and user
// authentication. On success it installs the connection's mux and returns
// the Permissions produced by authentication.
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
	if len(config.hostKeys) == 0 {
		return nil, errors.New("ssh: server has no host keys")
	}

	// At least one way of authenticating must be configured unless
	// unauthenticated access was explicitly allowed.
	gss := config.GSSAPIWithMICConfig
	gssUsable := gss != nil && gss.AllowLogin != nil && gss.Server != nil
	hasCallback := config.PasswordCallback != nil ||
		config.PublicKeyCallback != nil ||
		config.KeyboardInteractiveCallback != nil
	if !config.NoClientAuth && !hasCallback && !gssUsable {
		return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
	}

	if config.ServerVersion == "" {
		s.serverVersion = []byte(packageVersion)
	} else {
		s.serverVersion = []byte(config.ServerVersion)
	}

	var err error
	s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
	if err != nil {
		return nil, err
	}

	s.transport = newServerTransport(
		newTransport(s.sshConn.conn, config.Rand, false /* not client */),
		s.clientVersion, s.serverVersion, config)
	if err = s.transport.waitSession(); err != nil {
		return nil, err
	}

	// Key exchange has just completed, so the session ID is established.
	s.sessionID = s.transport.getSessionID()

	// The first packet after key exchange must request the userauth service.
	packet, err := s.transport.readPacket()
	if err != nil {
		return nil, err
	}
	var request serviceRequestMsg
	if err = Unmarshal(packet, &request); err != nil {
		return nil, err
	}
	if request.Service != serviceUserAuth {
		return nil, errors.New("ssh: requested service '" + request.Service + "' before authenticating")
	}

	accept := serviceAcceptMsg{Service: serviceUserAuth}
	if err = s.transport.writePacket(Marshal(&accept)); err != nil {
		return nil, err
	}

	perms, err := s.serverAuthenticate(config)
	if err != nil {
		return nil, err
	}
	s.mux = newMux(s.transport)
	return perms, nil
}
// checkSourceAddress reports whether addr is permitted by sourceAddrs, a
// comma-separated list whose entries are either single IP addresses or CIDR
// ranges. It returns nil as soon as one entry matches. A nil or non-TCP
// address, an entry that parses as neither an IP nor a CIDR, or the absence
// of any matching entry all produce an error.
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
	if addr == nil {
		return errors.New("ssh: no address known for client, but source-address match required")
	}

	tcpAddr, isTCP := addr.(*net.TCPAddr)
	if !isTCP {
		return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
	}

	entries := strings.Split(sourceAddrs, ",")
	for idx := range entries {
		entry := entries[idx]

		// A plain IP address must match exactly.
		if single := net.ParseIP(entry); single != nil {
			if single.Equal(tcpAddr.IP) {
				return nil
			}
			continue
		}

		// Anything else has to be a valid CIDR range.
		_, network, parseErr := net.ParseCIDR(entry)
		if parseErr != nil {
			return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", entry, parseErr)
		}
		if network.Contains(tcpAddr.IP) {
			return nil
		}
	}

	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
// gssExchangeToken runs the token exchange and MIC verification steps of
// gssapi-with-mic authentication on connection s, starting from the client's
// first token.
//
// It returns three values:
//   - authErr: an authentication failure, i.e. an error from
//     AcceptSecContext, VerifyMIC or the AllowLogin callback;
//   - perms: the Permissions returned by AllowLogin;
//   - err: a failure to read, write or decode a packet on the transport.
//
// DeleteSecContext is invoked on the GSS-API server on every return path.
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
gssAPIServer := gssapiConfig.Server
defer gssAPIServer.DeleteSecContext()
var srcName string
// Feed tokens to the GSS-API server until it reports that no further
// round is needed.
for {
var (
outToken []byte
needContinue bool
)
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
if err != nil {
// A rejected token is an authentication failure, not a transport error.
return err, nil, nil
}
// Forward any token the server produced to the client.
if len(outToken) != 0 {
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{
Token: outToken,
})); err != nil {
return nil, nil, err
}
}
if !needContinue {
break
}
// Another round is required: read the client's next token.
packet, err := s.transport.readPacket()
if err != nil {
return nil, nil, err
}
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
return nil, nil, err
}
token = userAuthGSSAPITokenReq.Token
}
// After the token exchange the client must send its MIC.
packet, err := s.transport.readPacket()
if err != nil {
return nil, nil, err
}
userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{}
if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil {
return nil, nil, err
}
// The MIC is verified against data built from the session ID and the
// user, service and method of the original auth request.
mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method)
if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil {
return err, nil, nil
}
// Identity is established; let the application decide whether srcName may
// log in.
perms, authErr = gssapiConfig.AllowLogin(s, srcName)
return authErr, perms, nil
}
// isAlgoCompatible reports whether a signature in format sigFormat may be
// used with the selected public key algorithm algo.
//
// Normally the underlying algorithm of algo must equal sigFormat. Any pair
// of RSA names is also accepted, for the sake of old clients:
//
//   - OpenSSH 7.2-7.7 may use signature format rsa-sha2-256 or rsa-sha2-512
//     with the algorithm ssh-rsa-cert-v01@openssh.com for certificate
//     authentication;
//   - gpg-agent < 2.2.6 may use algorithm rsa-sha2-256 or rsa-sha2-512 with
//     signature format ssh-rsa.
func isAlgoCompatible(algo, sigFormat string) bool {
	bothRSA := isRSA(algo) && isRSA(sigFormat)
	return bothRSA || underlyingAlgo(algo) == sigFormat
}
// ServerAuthError represents server authentication errors and is
// sometimes returned by NewServerConn. It collects the authentication
// errors that occurred, and is returned if all of the authentication
// methods provided by the user failed to authenticate.
type ServerAuthError struct {
	// Errors contains authentication errors returned by the authentication
	// callback methods. The first entry is typically ErrNoAuth.
	Errors []error
}

// Error renders the collected errors as a bracketed, comma-separated list.
func (l ServerAuthError) Error() string {
	var out strings.Builder
	out.WriteString("[")
	for idx, authErr := range l.Errors {
		if idx > 0 {
			out.WriteString(", ")
		}
		out.WriteString(authErr.Error())
	}
	out.WriteString("]")
	return out.String()
}
// ErrNoAuth is the error value recorded for an authentication attempt when
// no authentication method has succeeded yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
// serverAuthenticate runs the user-authentication loop for the connection.
// It repeatedly reads a userauth request from the transport and dispatches
// on the requested method ("none", "password", "keyboard-interactive",
// "publickey", "gssapi-with-mic") until one attempt succeeds, the failure
// limit configured through MaxAuthTries is reached, or a transport or
// protocol error occurs.
//
// On success it writes msgUserAuthSuccess and returns the Permissions
// produced by the successful attempt (which may be nil). Every attempt's
// outcome is appended to a list that is returned inside a *ServerAuthError
// if the client closes the connection (io.EOF) before authenticating.
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions
authFailures := 0
var authErrs []error
var displayedBanner bool
userAuthLoop:
for {
// Too many failed attempts: send a disconnect message and return it as
// the error.
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
discMsg := &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
}
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
}
var userAuthReq userAuthRequestMsg
// EOF while waiting for the next request is reported as a
// ServerAuthError carrying every error collected so far.
if packet, err := s.transport.readPacket(); err != nil {
if err == io.EOF {
return nil, &ServerAuthError{Errors: authErrs}
}
return nil, err
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
return nil, err
}
if userAuthReq.Service != serviceSSH {
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
s.user = userAuthReq.User
// The banner callback is invoked at most once per connection, on the
// first request; an empty banner is not sent.
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
// Reset the per-attempt state. authErr stays ErrNoAuth unless the
// selected method overwrites it.
perms = nil
authErr := ErrNoAuth
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s)
} else {
authErr = nil
}
}
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password":
if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
break
}
// The payload must be a zero byte followed by exactly one string (the
// password); anything else is a parse error that aborts authentication.
payload := userAuthReq.Payload
if len(payload) < 1 || payload[0] != 0 {
return nil, parseError(msgUserAuthRequest)
}
payload = payload[1:]
password, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
perms, authErr = config.PasswordCallback(s, password)
case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configured")
break
}
prompter := &sshClientKeyboardInteractive{s}
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey":
if config.PublicKeyCallback == nil {
authErr = errors.New("ssh: publickey auth not configured")
break
}
payload := userAuthReq.Payload
if len(payload) < 1 {
return nil, parseError(msgUserAuthRequest)
}
// A leading zero byte marks a query ("would this key be acceptable?")
// rather than a signed authentication attempt.
isQuery := payload[0] == 0
payload = payload[1:]
algoBytes, payload, ok := parseString(payload)
if !ok {
return nil, parseError(msgUserAuthRequest)
}
algo := string(algoBytes)
if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
break
}
pubKeyData, payload, ok := parseString(payload)
if !ok {
return nil, parseError(msgUserAuthRequest)
}
pubKey, err := ParsePublicKey(pubKeyData)
if err != nil {
return nil, err
}
// Consult the per-connection cache first; PublicKeyCallback is only
// invoked when there is no cached verdict for this user/key pair.
candidate, ok := cache.get(s.user, pubKeyData)
if !ok {
candidate.user = s.user
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
// Enforce the source-address critical option, if the callback
// accepted the key and returned one.
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
candidate.result = checkSourceAddress(
s.RemoteAddr(),
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
}
cache.add(candidate)
}
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
if candidate.result == nil {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: pubKeyData,
}
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
return nil, err
}
// A positive answer to a query is neither a success nor a failure:
// skip the bookkeeping below and wait for the next request.
continue userAuthLoop
}
authErr = candidate.result
} else {
sig, payload, ok := parseSignature(payload)
if !ok || len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
// Ensure the declared public key algo is compatible with the
// decoded one. This check will ensure we don't accept e.g.
// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
// key type. The algorithm and public key type must be
// consistent: both must be certificate algorithms, or neither.
if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
pubKey.Type(), algo)
break
}
// Ensure the public key algo and signature algo
// are supported. Compare the private key
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
break
}
if !isAlgoCompatible(algo, sig.Format) {
authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
break
}
// A signature that does not verify aborts authentication outright
// instead of being counted as a failed attempt.
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
if err := pubKey.Verify(signedData, sig); err != nil {
return nil, err
}
authErr = candidate.result
perms = candidate.perms
}
case "gssapi-with-mic":
if config.GSSAPIWithMICConfig == nil {
authErr = errors.New("ssh: gssapi-with-mic auth not configured")
break
}
gssapiConfig := config.GSSAPIWithMICConfig
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
if err != nil {
return nil, parseError(msgUserAuthRequest)
}
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication.
if userAuthRequestGSSAPI.N == 0 {
authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported")
break
}
// Look for the Kerberos V5 mechanism among the OIDs the client offered.
var i uint32
present := false
for i = 0; i < userAuthRequestGSSAPI.N; i++ {
if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) {
present = true
break
}
}
if !present {
authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism")
break
}
// Initial server response, see RFC 4462 section 3.3.
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{
SupportMech: krb5OID,
})); err != nil {
return nil, err
}
// Exchange token, see RFC 4462 section 3.4.
packet, err := s.transport.readPacket()
if err != nil {
return nil, err
}
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
return nil, err
}
authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID,
userAuthReq)
if err != nil {
return nil, err
}
default:
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
}
// Record the outcome of this attempt (nil on success) and report it to
// the optional log callback.
authErrs = append(authErrs, authErr)
if config.AuthLogCallback != nil {
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}
if authErr == nil {
break userAuthLoop
}
authFailures++
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the
// final SSH_MSG_USERAUTH_FAILURE message, since there are
// no more authentication methods which can be attempted,
// and this message may cause the client to re-attempt
// authentication while we send the disconnect message.
// Continue, and trigger the disconnect at the start of
// the loop.
//
// The SSH specification is somewhat confusing about this,
// RFC 4252 Section 5.1 requires each authentication failure
// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
// message, but Section 4 says the server should disconnect
// after some number of attempts, but it isn't explicit which
// message should take precedence (i.e. should there be a failure
// message then a disconnect message, or if we are going to
// disconnect, should we only send that message.)
//
// Either way, OpenSSH disconnects immediately after the last
// failed authentication attempt, and given they are typically
// considered the golden implementation it seems reasonable
// to match that behavior.
continue
}
// Tell the client which methods it may still try, based on which
// callbacks are configured.
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
if config.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
if config.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
}
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
config.GSSAPIWithMICConfig.AllowLogin != nil {
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
}
if len(failureMsg.Methods) == 0 {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
return nil, err
}
}
// The loop is only left through a successful attempt.
if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
return nil, err
}
return perms, nil
}
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
// asking the client on the other side of a ServerConn.
type sshClientKeyboardInteractive struct {
	*connection
}

// Challenge sends one round of keyboard-interactive prompts to the client
// and returns its answers, one per question. questions and echos must have
// the same length, and the client must answer exactly len(questions)
// prompts with no trailing data.
func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
	if len(questions) != len(echos) {
		return nil, errors.New("ssh: echos and questions must have equal length")
	}

	// Each prompt is encoded as the question string followed by its echo flag.
	var encoded []byte
	for idx, question := range questions {
		encoded = appendString(encoded, question)
		encoded = appendBool(encoded, echos[idx])
	}

	request := &userAuthInfoRequestMsg{
		Name:        name,
		Instruction: instruction,
		NumPrompts:  uint32(len(questions)),
		Prompts:     encoded,
	}
	if err := c.transport.writePacket(Marshal(request)); err != nil {
		return nil, err
	}

	packet, err := c.transport.readPacket()
	if err != nil {
		return nil, err
	}
	if msgType := packet[0]; msgType != msgUserAuthInfoResponse {
		return nil, unexpectedMessageError(msgUserAuthInfoResponse, msgType)
	}

	count, body, ok := parseUint32(packet[1:])
	if !ok || int(count) != len(questions) {
		return nil, parseError(msgUserAuthInfoResponse)
	}

	for remaining := count; remaining > 0; remaining-- {
		answer, rest, ok := parseString(body)
		if !ok {
			return nil, parseError(msgUserAuthInfoResponse)
		}
		answers = append(answers, string(answer))
		body = rest
	}

	if len(body) != 0 {
		return nil, errors.New("ssh: junk at end of message")
	}
	return answers, nil
}
================================================
FILE: gossh/crypto/ssh/session.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session implements an interactive session described in
// "RFC 4254, section 6".
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
)
// Signal is the name of a signal as carried in an SSH "signal" channel
// request; Session.Signal sends the string value as-is.
type Signal string
// POSIX signals as listed in RFC 4254 Section 6.10.
const (
SIGABRT Signal = "ABRT"
SIGALRM Signal = "ALRM"
SIGFPE Signal = "FPE"
SIGHUP Signal = "HUP"
SIGILL Signal = "ILL"
SIGINT Signal = "INT"
SIGKILL Signal = "KILL"
SIGPIPE Signal = "PIPE"
SIGQUIT Signal = "QUIT"
SIGSEGV Signal = "SEGV"
SIGTERM Signal = "TERM"
SIGUSR1 Signal = "USR1"
SIGUSR2 Signal = "USR2"
)
// signals maps a Signal name to an integer.
// NOTE(review): the values look like the conventional POSIX signal numbers;
// confirm against the code that consumes this map. SIGUSR1 and SIGUSR2 have
// no entry here.
var signals = map[Signal]int{
SIGABRT: 6,
SIGALRM: 14,
SIGFPE: 8,
SIGHUP: 1,
SIGILL: 4,
SIGINT: 2,
SIGKILL: 9,
SIGPIPE: 13,
SIGQUIT: 3,
SIGSEGV: 11,
SIGTERM: 15,
}
// TerminalModes maps a terminal mode opcode (see the constants below) to
// its uint32 argument. Session.RequestPty encodes every entry into the mode
// list of the pty request.
type TerminalModes map[uint8]uint32
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
const (
// tty_OP_END terminates the encoded mode list.
tty_OP_END = 0
VINTR = 1
VQUIT = 2
VERASE = 3
VKILL = 4
VEOF = 5
VEOL = 6
VEOL2 = 7
VSTART = 8
VSTOP = 9
VSUSP = 10
VDSUSP = 11
VREPRINT = 12
VWERASE = 13
VLNEXT = 14
VFLUSH = 15
VSWTCH = 16
VSTATUS = 17
VDISCARD = 18
IGNPAR = 30
PARMRK = 31
INPCK = 32
ISTRIP = 33
INLCR = 34
IGNCR = 35
ICRNL = 36
IUCLC = 37
IXON = 38
IXANY = 39
IXOFF = 40
IMAXBEL = 41
IUTF8 = 42 // RFC 8160
ISIG = 50
ICANON = 51
XCASE = 52
ECHO = 53
ECHOE = 54
ECHOK = 55
ECHONL = 56
NOFLSH = 57
TOSTOP = 58
IEXTEN = 59
ECHOCTL = 60
ECHOKE = 61
PENDIN = 62
OPOST = 70
OLCUC = 71
ONLCR = 72
OCRNL = 73
ONOCR = 74
ONLRET = 75
CS7 = 90
CS8 = 91
PARENB = 92
PARODD = 93
TTY_OP_ISPEED = 128
TTY_OP_OSPEED = 129
)
// A Session represents a connection to a remote command or shell.
type Session struct {
// Stdin specifies the remote process's standard input.
// If Stdin is nil, the remote process reads from an empty
// bytes.Buffer.
Stdin io.Reader
// Stdout and Stderr specify the remote process's standard
// output and error.
//
// If either is nil, Run connects the corresponding file
// descriptor to an instance of io.Discard. There is a
// fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote
// command to block.
Stdout io.Writer
Stderr io.Writer
ch Channel // the channel backing this session
started bool // true once Start, Run or Shell is invoked.
// copyFuncs are the stream-copying functions; start runs each one in its
// own goroutine.
copyFuncs []func() error
errors chan error // one send per copyFunc
// true if pipe method is active
stdinpipe, stdoutpipe, stderrpipe bool
// stdinPipeWriter is non-nil if StdinPipe has not been called
// and Stdin was specified by the user; it is the write end of
// a pipe connecting Session.Stdin to the stdin channel.
stdinPipeWriter io.WriteCloser
// exitStatus delivers the result of the remote command; Wait receives
// from it.
exitStatus chan error
}
// SendRequest forwards an out-of-band channel request to the SSH channel
// that backs the session and returns that channel's answer unchanged.
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	accepted, err := s.ch.SendRequest(name, wantReply, payload)
	return accepted, err
}
// Close closes the SSH channel that backs the session.
func (s *Session) Close() error {
	closeErr := s.ch.Close()
	return closeErr
}
// setenvRequest is the payload of an "env" channel request
// (RFC 4254 Section 6.4).
type setenvRequest struct {
	Name  string
	Value string
}

// Setenv sets an environment variable that will be applied to any
// command executed by Shell or Run. It fails when the request cannot be
// sent or when the remote side refuses it.
func (s *Session) Setenv(name, value string) error {
	payload := Marshal(&setenvRequest{Name: name, Value: value})
	accepted, err := s.ch.SendRequest("env", true, payload)
	if err != nil {
		return err
	}
	if !accepted {
		return errors.New("ssh: setenv failed")
	}
	return nil
}
// ptyRequestMsg is the payload of a "pty-req" channel request
// (RFC 4254 Section 6.2).
type ptyRequestMsg struct {
	Term     string
	Columns  uint32
	Rows     uint32
	Width    uint32
	Height   uint32
	Modelist string
}

// RequestPty requests the association of a pty with the session on the
// remote host. h and w are the terminal size in rows and columns; the pixel
// dimensions sent along are derived as eight pixels per cell. It fails when
// the request cannot be sent or when the remote side refuses it.
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
	// Encode every mode as an opcode byte followed by its uint32 argument,
	// then terminate the list.
	var modeList []byte
	for opcode, argument := range termmodes {
		entry := struct {
			Key byte
			Val uint32
		}{Key: opcode, Val: argument}
		modeList = append(modeList, Marshal(&entry)...)
	}
	modeList = append(modeList, tty_OP_END)

	payload := Marshal(&ptyRequestMsg{
		Term:     term,
		Columns:  uint32(w),
		Rows:     uint32(h),
		Width:    uint32(w * 8),
		Height:   uint32(h * 8),
		Modelist: string(modeList),
	})
	accepted, err := s.ch.SendRequest("pty-req", true, payload)
	switch {
	case err != nil:
		return err
	case !accepted:
		return errors.New("ssh: pty-req failed")
	}
	return nil
}
// subsystemRequestMsg is the payload of a "subsystem" channel request
// (RFC 4254 Section 6.5).
type subsystemRequestMsg struct {
	Subsystem string
}

// RequestSubsystem requests the association of a subsystem with the session
// on the remote host. A subsystem is a predefined command that runs in the
// background when the ssh session is initiated. It fails when the request
// cannot be sent or when the remote side refuses it.
func (s *Session) RequestSubsystem(subsystem string) error {
	payload := Marshal(&subsystemRequestMsg{Subsystem: subsystem})
	accepted, err := s.ch.SendRequest("subsystem", true, payload)
	if err != nil {
		return err
	}
	if !accepted {
		return errors.New("ssh: subsystem request failed")
	}
	return nil
}
// ptyWindowChangeMsg is the payload of a "window-change" channel request
// (RFC 4254 Section 6.7).
type ptyWindowChangeMsg struct {
	Columns uint32
	Rows    uint32
	Width   uint32
	Height  uint32
}

// WindowChange informs the remote host about a terminal window dimension
// change to h rows and w columns. No reply is requested, so only a failure
// to send is reported.
func (s *Session) WindowChange(h, w int) error {
	payload := Marshal(&ptyWindowChangeMsg{
		Columns: uint32(w),
		Rows:    uint32(h),
		Width:   uint32(w * 8),
		Height:  uint32(h * 8),
	})
	if _, err := s.ch.SendRequest("window-change", false, payload); err != nil {
		return err
	}
	return nil
}
// signalMsg is the payload of a "signal" channel request
// (RFC 4254 Section 6.9).
type signalMsg struct {
	Signal string
}

// Signal sends the given signal to the remote process.
// sig is one of the SIG* constants. No reply is requested, so only a
// failure to send is reported.
func (s *Session) Signal(sig Signal) error {
	payload := Marshal(&signalMsg{Signal: string(sig)})
	if _, err := s.ch.SendRequest("signal", false, payload); err != nil {
		return err
	}
	return nil
}
// execMsg is the payload of an "exec" channel request
// (RFC 4254 Section 6.5).
type execMsg struct {
	Command string
}

// Start runs cmd on the remote host. Typically, the remote
// server passes cmd to the shell for interpretation.
// A Session only accepts one call to Run, Start or Shell.
func (s *Session) Start(cmd string) error {
	if s.started {
		return errors.New("ssh: session already started")
	}

	accepted, err := s.ch.SendRequest("exec", true, Marshal(&execMsg{Command: cmd}))
	if err != nil {
		return err
	}
	if !accepted {
		return fmt.Errorf("ssh: command %v failed", cmd)
	}
	return s.start()
}
// Run starts cmd on the remote host and waits for it to finish. Typically,
// the remote server passes cmd to the shell for interpretation.
// A Session only accepts one call to Run, Start, Shell, Output,
// or CombinedOutput.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the remote server does not send an exit status, an error of type
// *ExitMissingError is returned. If the command completes
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Run(cmd string) error {
	if startErr := s.Start(cmd); startErr != nil {
		return startErr
	}
	return s.Wait()
}
// Output runs cmd on the remote host and returns its standard output.
// It refuses to run when Stdout has already been set. The bytes captured so
// far are returned even when Run reports an error.
func (s *Session) Output(cmd string) ([]byte, error) {
	if s.Stdout != nil {
		return nil, errors.New("ssh: Stdout already set")
	}
	captured := new(bytes.Buffer)
	s.Stdout = captured
	runErr := s.Run(cmd)
	return captured.Bytes(), runErr
}
// singleWriter is a bytes.Buffer whose writes are serialized by a mutex.
type singleWriter struct {
	b  bytes.Buffer
	mu sync.Mutex
}

// Write appends p to the buffer while holding the mutex.
func (w *singleWriter) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	written, err := w.b.Write(p)
	return written, err
}
// CombinedOutput runs cmd on the remote host and returns its combined
// standard output and standard error. It refuses to run when Stdout or
// Stderr has already been set. The bytes captured so far are returned even
// when Run reports an error.
func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
	switch {
	case s.Stdout != nil:
		return nil, errors.New("ssh: Stdout already set")
	case s.Stderr != nil:
		return nil, errors.New("ssh: Stderr already set")
	}

	// Both streams share one mutex-guarded buffer.
	combined := new(singleWriter)
	s.Stdout, s.Stderr = combined, combined
	runErr := s.Run(cmd)
	return combined.b.Bytes(), runErr
}
// Shell sends a "shell" request to start a login shell on the remote host
// and, once it is accepted, begins copying the session's standard streams.
// A Session only accepts one call to Run, Start, Shell, Output, or
// CombinedOutput.
func (s *Session) Shell() error {
	if s.started {
		return errors.New("ssh: session already started")
	}

	accepted, err := s.ch.SendRequest("shell", true, nil)
	switch {
	case err != nil:
		return err
	case !accepted:
		return errors.New("ssh: could not start shell")
	}
	return s.start()
}
// start marks the session as started, registers the copy functions for
// stdin, stdout and stderr, and launches each of them in its own
// goroutine. Every goroutine reports its result on s.errors, which is
// buffered to hold one value per copy function. It always returns nil.
func (s *Session) start() error {
	s.started = true

	// Order matters only in that it fixes the order of s.copyFuncs.
	s.stdin()
	s.stdout()
	s.stderr()

	s.errors = make(chan error, len(s.copyFuncs))
	for i := range s.copyFuncs {
		copyFn := s.copyFuncs[i]
		go func() {
			s.errors <- copyFn()
		}()
	}
	return nil
}
// Wait blocks until the remote command has exited and all stream copies
// started by the session have finished.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the remote server does not send an exit status, an error of type
// *ExitMissingError is returned. If the command completes
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Wait() error {
	if !s.started {
		return errors.New("ssh: session not started")
	}

	exitErr := <-s.exitStatus

	// Closing the pipe's write end lets the stdin copy function return.
	if s.stdinPipeWriter != nil {
		s.stdinPipeWriter.Close()
	}

	// Collect one result per copy function, remembering the first failure.
	var firstCopyErr error
	for i := 0; i < len(s.copyFuncs); i++ {
		err := <-s.errors
		if firstCopyErr == nil {
			firstCopyErr = err
		}
	}

	// The exit status takes precedence over copy errors.
	if exitErr != nil {
		return exitErr
	}
	return firstCopyErr
}
// wait consumes channel requests from reqs until the channel is closed and
// translates the exit information they carry into an error:
//
//   - nil when an "exit-status" of 0 was received;
//   - *ExitMissingError when neither "exit-status" nor "exit-signal" arrived;
//   - *ExitError otherwise. If only a signal was reported, the status is
//     128 plus the signal's number from the signals table (or plain 128 for
//     an unknown signal name).
//
// A malformed "exit-status" or "exit-signal" payload makes wait return an
// error immediately, without draining the remaining requests.
func (s *Session) wait(reqs <-chan *Request) error {
	wm := Waitmsg{status: -1}
	// Wait for msg channel to be closed before returning.
	for msg := range reqs {
		switch msg.Type {
		case "exit-status":
			// The payload is a single uint32 (RFC 4254 section 6.10).
			// Guard the length: Uint32 panics on a short slice, and this
			// code runs in its own goroutine, so a truncated payload from
			// the peer would otherwise take the whole process down.
			if len(msg.Payload) < 4 {
				return errors.New("ssh: malformed exit-status request")
			}
			wm.status = int(binary.BigEndian.Uint32(msg.Payload))
		case "exit-signal":
			var sigval struct {
				Signal     string
				CoreDumped bool
				Error      string
				Lang       string
			}
			if err := Unmarshal(msg.Payload, &sigval); err != nil {
				return err
			}

			// Must sanitize strings?
			wm.signal = sigval.Signal
			wm.msg = sigval.Error
			wm.lang = sigval.Lang
		default:
			// This handles keepalives and matches
			// OpenSSH's behaviour.
			if msg.WantReply {
				msg.Reply(false, nil)
			}
		}
	}
	if wm.status == 0 {
		return nil
	}
	if wm.status == -1 {
		// exit-status was never sent from server
		if wm.signal == "" {
			// signal was not sent either.  RFC 4254
			// section 6.10 recommends against this
			// behavior, but it is allowed, so we let
			// clients handle it.
			return &ExitMissingError{}
		}
		wm.status = 128
		if num, ok := signals[Signal(wm.signal)]; ok {
			wm.status += num
		}
	}

	return &ExitError{wm}
}
// ExitMissingError is the error reported when a session ends cleanly but
// the server sent neither an exit status nor an exit signal.
type ExitMissingError struct{}

// Error implements the error interface.
func (*ExitMissingError) Error() string {
	return "wait: remote command exited without exit status or exit signal"
}
// stdin registers the copy function that feeds the session's standard
// input to the channel. It does nothing when the caller has taken over
// stdin with StdinPipe.
func (s *Session) stdin() {
	if s.stdinpipe {
		return
	}
	var stdin io.Reader
	if s.Stdin == nil {
		// No input configured: an empty buffer makes the copy end at once.
		stdin = new(bytes.Buffer)
	} else {
		// Relay s.Stdin through a pipe. Wait closes the write end
		// (s.stdinPipeWriter), which ends the copy below even if s.Stdin
		// itself never reaches EOF.
		r, w := io.Pipe()
		go func() {
			_, err := io.Copy(w, s.Stdin)
			w.CloseWithError(err)
		}()
		stdin, s.stdinPipeWriter = r, w
	}
	s.copyFuncs = append(s.copyFuncs, func() error {
		_, err := io.Copy(s.ch, stdin)
		// Half-close the channel once input is exhausted. A CloseWrite
		// error is reported only if the copy succeeded and the error is
		// not io.EOF.
		if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
			err = err1
		}
		return err
	})
}
// stdout registers the copy function that drains the channel's data stream
// into s.Stdout, defaulting s.Stdout to io.Discard. It does nothing when
// the caller has taken over stdout with StdoutPipe.
func (s *Session) stdout() {
	if s.stdoutpipe {
		return
	}
	if s.Stdout == nil {
		s.Stdout = io.Discard
	}

	drain := func() error {
		_, err := io.Copy(s.Stdout, s.ch)
		return err
	}
	s.copyFuncs = append(s.copyFuncs, drain)
}
// stderr registers the copy function that drains the channel's stderr
// stream into s.Stderr, defaulting s.Stderr to io.Discard. It does nothing
// when the caller has taken over stderr with StderrPipe.
func (s *Session) stderr() {
	if s.stderrpipe {
		return
	}
	if s.Stderr == nil {
		s.Stderr = io.Discard
	}

	drain := func() error {
		_, err := io.Copy(s.Stderr, s.ch.Stderr())
		return err
	}
	s.copyFuncs = append(s.copyFuncs, drain)
}
// sessionStdin is the io.WriteCloser handed out by StdinPipe. Writes go to
// the embedded Writer; Close is rerouted to the channel's CloseWrite.
type sessionStdin struct {
	io.Writer
	ch Channel
}

// Close calls CloseWrite on the underlying channel and returns its result.
func (s *sessionStdin) Close() error {
	return s.ch.CloseWrite()
}
// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
// It must be called before the session is started and only if s.Stdin has
// not been set. Closing the returned writer half-closes the channel.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
	switch {
	case s.Stdin != nil:
		return nil, errors.New("ssh: Stdin already set")
	case s.started:
		return nil, errors.New("ssh: StdinPipe after process started")
	}

	s.stdinpipe = true
	return &sessionStdin{Writer: s.ch, ch: s.ch}, nil
}
// StdoutPipe returns a pipe that will be connected to the
// remote command's standard output when the command starts.
// It must be called before the session is started and only if s.Stdout
// has not been set.
//
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StdoutPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StdoutPipe() (io.Reader, error) {
	switch {
	case s.Stdout != nil:
		return nil, errors.New("ssh: Stdout already set")
	case s.started:
		return nil, errors.New("ssh: StdoutPipe after process started")
	}

	s.stdoutpipe = true
	return s.ch, nil
}
// StderrPipe returns a pipe that will be connected to the
// remote command's standard error when the command starts.
// It must be called before the session is started and only if s.Stderr
// has not been set.
//
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StderrPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StderrPipe() (io.Reader, error) {
	switch {
	case s.Stderr != nil:
		return nil, errors.New("ssh: Stderr already set")
	case s.started:
		return nil, errors.New("ssh: StderrPipe after process started")
	}

	s.stderrpipe = true
	return s.ch.Stderr(), nil
}
// newSession wraps ch in a Session and starts a goroutine that consumes
// reqs and delivers the resulting exit error on the session's buffered
// exitStatus channel. The returned error is always nil.
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
	sess := &Session{ch: ch}
	sess.exitStatus = make(chan error, 1)

	go func(done chan<- error) {
		done <- sess.wait(reqs)
	}(sess.exitStatus)

	return sess, nil
}
// An ExitError reports unsuccessful completion of a remote command. The
// embedded Waitmsg carries the exit details.
type ExitError struct {
	Waitmsg
}

// Error returns the textual form of the embedded Waitmsg.
func (e *ExitError) Error() string {
	return e.String()
}

// Waitmsg stores the information about an exited remote command
// as reported by Wait.
type Waitmsg struct {
	status int
	signal string
	msg    string
	lang   string
}

// ExitStatus returns the exit status of the remote command.
func (w Waitmsg) ExitStatus() int { return w.status }

// Signal returns the exit signal of the remote command if
// it was terminated violently.
func (w Waitmsg) Signal() string { return w.signal }

// Msg returns the exit message given by the remote command.
func (w Waitmsg) Msg() string { return w.msg }

// Lang returns the language tag. See RFC 3066.
func (w Waitmsg) Lang() string { return w.lang }

// String renders the exit status, followed by the signal and the reason
// when they are non-empty.
func (w Waitmsg) String() string {
	text := fmt.Sprintf("Process exited with status %v", w.status)
	if w.signal != "" {
		text = fmt.Sprintf("%s from signal %v", text, w.signal)
	}
	if w.msg != "" {
		text = fmt.Sprintf("%s. Reason was: %v", text, w.msg)
	}
	return text
}
================================================
FILE: gossh/crypto/ssh/ssh_gss.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"encoding/asn1"
"errors"
)
// krb5OID holds the ASN.1 encoding of krb5Mesh; it is filled in by init.
var krb5OID []byte

func init() {
	// The Marshal error is deliberately discarded.
	// NOTE(review): encoding this fixed OID is not expected to fail — confirm.
	krb5OID, _ = asn1.Marshal(krb5Mesh)
}
// GSSAPIClient provides the API to plug in GSSAPI authentication for client logins.
type GSSAPIClient interface {
	// InitSecContext initiates the establishment of a security context for GSS-API between the
	// ssh client and ssh server. Initially the token parameter should be specified as nil.
	// The routine may return an outputToken which should be transferred to
	// the ssh server, where the ssh server will present it to
	// AcceptSecContext. If no token needs to be sent, InitSecContext will indicate this by setting
	// needContinue to false. To complete the context
	// establishment, one or more reply tokens may be required from the ssh
	// server; if so, InitSecContext will return a needContinue which is true.
	// In this case, InitSecContext should be called again when the
	// reply token is received from the ssh server, passing the reply
	// token to InitSecContext via the token parameter.
	// See RFC 2743 section 2.2.1 and RFC 4462 section 3.4.
	InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error)
	// GetMIC generates a cryptographic MIC for the SSH2 message, and places
	// the MIC in a token for transfer to the ssh server.
	// The contents of the MIC field are obtained by calling GSS_GetMIC()
	// over the following, using the GSS-API context that was just
	// established:
	//  string    session identifier
	//  byte      SSH_MSG_USERAUTH_REQUEST
	//  string    user name
	//  string    service
	//  string    "gssapi-with-mic"
	// See RFC 2743 section 2.3.1 and RFC 4462 section 3.5.
	GetMIC(micFiled []byte) ([]byte, error)
	// DeleteSecContext releases the security context.
	// Whenever possible, it should be possible for
	// DeleteSecContext() calls to be successfully processed even
	// if other calls cannot succeed, thereby enabling context-related
	// resources to be released.
	// In addition to deleting established security contexts,
	// gss_delete_sec_context must also be able to delete "half-built"
	// security contexts resulting from an incomplete sequence of
	// InitSecContext()/AcceptSecContext() calls.
	// See RFC 2743 section 2.2.3.
	DeleteSecContext() error
}
// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins.
type GSSAPIServer interface {
	// AcceptSecContext allows a remotely initiated security context between the application
	// and a remote peer to be established by the ssh client. The routine may return an
	// outputToken which should be transferred to the ssh client,
	// where the ssh client will present it to InitSecContext.
	// If no token needs to be sent, AcceptSecContext will indicate this
	// by setting needContinue to false. To
	// complete the context establishment, one or more reply tokens may be
	// required from the ssh client. If so, AcceptSecContext
	// will return a needContinue which is true, in which case it
	// should be called again when the reply token is received from the ssh
	// client, passing the token to AcceptSecContext via the
	// token parameter.
	// The srcName return value is the authenticated username.
	// See RFC 2743 section 2.2.2 and RFC 4462 section 3.4.
	AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error)
	// VerifyMIC verifies that the cryptographic MIC contained in micToken
	// matches the message micField received from the ssh client.
	// See RFC 2743 section 2.3.2.
	VerifyMIC(micField []byte, micToken []byte) error
	// DeleteSecContext releases the security context.
	// Whenever possible, it should be possible for
	// DeleteSecContext() calls to be successfully processed even
	// if other calls cannot succeed, thereby enabling context-related
	// resources to be released.
	// In addition to deleting established security contexts,
	// gss_delete_sec_context must also be able to delete "half-built"
	// security contexts resulting from an incomplete sequence of
	// InitSecContext()/AcceptSecContext() calls.
	// See RFC 2743 section 2.2.3.
	DeleteSecContext() error
}
var (
	// krb5Mesh is the object identifier of the Kerberos V5 GSS-API mechanism.
	// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,
	// so we also support the krb5 mechanism only.
	// See RFC 1964 section 1.
	krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2}
)
// userAuthRequestGSSAPI is the method-specific part of a GSS-API
// SSH_MSG_USERAUTH_REQUEST, which the client sends to initiate the
// GSS-API authentication method.
// See RFC 4462 section 3.2.
type userAuthRequestGSSAPI struct {
	N    uint32                  // number of mechanism OIDs announced
	OIDS []asn1.ObjectIdentifier // the decoded mechanism OIDs
}
// parseGSSAPIPayload decodes the method-specific payload of a GSS-API
// user authentication request: a uint32 mechanism count followed by that
// many SSH strings, each holding one DER-encoded OID (RFC 4462 section 3.2).
//
// It returns an error if the count cannot be read, if the count cannot
// possibly fit in the remaining payload, or if any mechanism string or its
// OID fails to parse.
func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) {
	n, rest, ok := parseUint32(payload)
	if !ok {
		return nil, errors.New("parse uint32 failed")
	}

	// Each mechanism occupies at least a 4-byte string length plus the
	// 2-byte tag/length header of a DER OID. A count that cannot fit in the
	// remaining bytes would fail further down anyway; rejecting it before
	// the allocation stops a peer from requesting gigabytes of memory with
	// a packet of only a few bytes. The arithmetic is done in uint64 so it
	// cannot overflow on 32-bit platforms.
	const minMechSize = 4 + 2
	if uint64(n)*minMechSize > uint64(len(rest)) {
		return nil, errors.New("invalid mechanism count")
	}

	s := &userAuthRequestGSSAPI{
		N:    n,
		OIDS: make([]asn1.ObjectIdentifier, n),
	}
	for i := range s.OIDS {
		var desiredMech []byte
		desiredMech, rest, ok = parseString(rest)
		if !ok {
			return nil, errors.New("parse string failed")
		}
		// asn1.Unmarshal returns what is left of desiredMech, not of the
		// payload, so its remainder must not be stored in rest: doing so
		// would discard every mechanism after the first one.
		if _, err := asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil {
			return nil, err
		}
	}
	return s, nil
}
// buildMIC assembles the data over which the GSS-API MIC is computed: the
// session identifier, the SSH_MSG_USERAUTH_REQUEST message number, and then
// the user name, service and authentication method, each as an SSH string.
// See RFC 4462 section 3.6.
func buildMIC(sessionID string, username string, service string, authMethod string) []byte {
	mic := appendString([]byte{}, sessionID)
	mic = append(mic, msgUserAuthRequest)
	for _, field := range [...]string{username, service, authMethod} {
		mic = appendString(mic, field)
	}
	return mic
}
================================================
FILE: gossh/crypto/ssh/streamlocal.go
================================================
package ssh
import (
"errors"
"io"
"net"
)
// streamLocalChannelOpenDirectMsg is the payload of an SSH_MSG_CHANNEL_OPEN
// message of type "direct-streamlocal@openssh.com".
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
type streamLocalChannelOpenDirectMsg struct {
	socketPath string // path of the Unix socket to connect to
	reserved0  string // left at its zero value by dialStreamLocal
	reserved1  uint32 // left at its zero value by dialStreamLocal
}
// forwardedStreamLocalPayload is the payload of an SSH_MSG_CHANNEL_OPEN
// message of type "forwarded-streamlocal@openssh.com".
type forwardedStreamLocalPayload struct {
	SocketPath string // used as the name of the local address when matching forwards
	Reserved0  string
}
// streamLocalChannelForwardMsg is the payload of an SSH2_MSG_GLOBAL_REQUEST
// message of type "streamlocal-forward@openssh.com" or
// "cancel-streamlocal-forward@openssh.com".
type streamLocalChannelForwardMsg struct {
	socketPath string // path of the Unix socket being forwarded
}
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
// It asks the peer to forward socketPath and, if the peer agrees, returns
// a listener that receives the forwarded connections.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
	c.handleForwardsOnce.Do(c.handleForwards)

	req := Marshal(&streamLocalChannelForwardMsg{socketPath})
	accepted, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, req)
	switch {
	case err != nil:
		return nil, err
	case !accepted:
		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
	}

	// Register the forward so incoming channels for this path are routed
	// to the new listener.
	in := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
	return &unixListener{socketPath: socketPath, conn: c, in: in}, nil
}
// dialStreamLocal opens a "direct-streamlocal@openssh.com" channel to the
// Unix socket at socketPath on the remote side. Requests arriving on the
// new channel are discarded in a background goroutine.
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
	payload := Marshal(&streamLocalChannelOpenDirectMsg{socketPath: socketPath})
	ch, reqs, err := c.OpenChannel("direct-streamlocal@openssh.com", payload)
	if err != nil {
		return nil, err
	}
	go DiscardRequests(reqs)
	return ch, nil
}
// unixListener is the net.Listener returned by ListenUnix.
type unixListener struct {
	socketPath string // remote socket path that was requested

	conn *Client        // client the forward was registered with
	in   <-chan forward // delivers forwarded connections; closed when the forward is removed
}
// Accept waits for and returns the next connection to the listener.
// It returns io.EOF once the listener's channel has been closed.
func (l *unixListener) Accept() (net.Conn, error) {
	fwd, open := <-l.in
	if !open {
		return nil, io.EOF
	}

	ch, reqs, err := fwd.newCh.Accept()
	if err != nil {
		return nil, err
	}
	go DiscardRequests(reqs)

	conn := &chanConn{
		Channel: ch,
		laddr:   &net.UnixAddr{Name: l.socketPath, Net: "unix"},
		raddr:   &net.UnixAddr{Name: "@", Net: "unix"},
	}
	return conn, nil
}
// Close closes the listener: it drops the forward registration (which
// closes the channel feeding Accept) and then asks the peer to cancel the
// forward.
func (l *unixListener) Close() error {
	// this also closes the listener.
	l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})

	req := Marshal(&streamLocalChannelForwardMsg{l.socketPath})
	ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, req)
	if err != nil {
		return err
	}
	if !ok {
		return errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
	}
	return nil
}
// Addr returns the listener's network address: a "unix" address whose
// name is the socket path the listener was created with.
func (l *unixListener) Addr() net.Addr {
	return &net.UnixAddr{
		Name: l.socketPath,
		Net:  "unix",
	}
}
================================================
FILE: gossh/crypto/ssh/tcpip.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
)
// Listen requests the remote peer open a listening socket on
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
// n must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) {
	if n == "unix" {
		return c.ListenUnix(addr)
	}
	if n != "tcp" && n != "tcp4" && n != "tcp6" {
		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
	}

	tcpAddr, err := net.ResolveTCPAddr(n, addr)
	if err != nil {
		return nil, err
	}
	return c.ListenTCP(tcpAddr)
}
// Automatic port allocation is broken with OpenSSH before 6.0. See
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
// rather than the actual port number. This means you can never open
// two different listeners with auto allocated ports. We work around
// this by trying explicit ports until we succeed.

// openSSHPrefix precedes the version number in an OpenSSH version string.
const openSSHPrefix = "OpenSSH_"

// portRandomizer picks the candidate ports for autoPortListenWorkaround.
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))

// isBrokenOpenSSHVersion returns true if the given version string
// specifies a version of OpenSSH that is known to have a bug in port
// forwarding, i.e. its major version is below 6.
//
// It returns false for strings that do not mention OpenSSH and for
// strings whose major version cannot be parsed (for example
// "OpenSSH_for_Windows_8.1"): an unknown version is not a known-broken one.
func isBrokenOpenSSHVersion(versionStr string) bool {
	start := strings.Index(versionStr, openSSHPrefix)
	if start < 0 {
		return false
	}
	start += len(openSSHPrefix)

	// The major version is the run of decimal digits right after the prefix.
	end := start
	for end < len(versionStr) && versionStr[end] >= '0' && versionStr[end] <= '9' {
		end++
	}

	major, err := strconv.Atoi(versionStr[start:end])
	if err != nil {
		// No digits at all, or a number too large to represent.
		return false
	}
	return major < 6
}
// autoPortListenWorkaround simulates automatic port allocation by
// trying random ports repeatedly. Up to ten ports in the range
// [1024, 61024) are attempted; on success the chosen port is also written
// back to laddr.Port.
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
	const tries = 10

	var lastErr error
	for attempt := 0; attempt < tries; attempt++ {
		candidate := *laddr
		candidate.Port = 1024 + portRandomizer.Intn(60000)

		ln, err := c.ListenTCP(&candidate)
		if err == nil {
			laddr.Port = candidate.Port
			return ln, nil
		}
		lastErr = err
	}
	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, lastErr)
}
// channelForwardMsg is the payload of the "tcpip-forward" and
// "cancel-tcpip-forward" global requests.
// See RFC 4254 section 7.1.
type channelForwardMsg struct {
	addr  string // address to bind on the remote side
	rport uint32 // port to bind on the remote side
}
// handleForwards starts goroutines handling forwarded connections, one
// per supported channel type. It is run at most once per client, through
// handleForwardsOnce, so that no goroutines are launched until a listener
// is actually requested.
func (c *Client) handleForwards() {
	for _, chanType := range [...]string{
		"forwarded-tcpip",
		"forwarded-streamlocal@openssh.com",
	} {
		go c.forwards.handleChannels(c.HandleChannelOpen(chanType))
	}
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
//
// When laddr.Port is 0 the port chosen by the remote side is written back
// to laddr.Port.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
	c.handleForwardsOnce.Do(c.handleForwards)

	autoPort := laddr.Port == 0
	if autoPort && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
		return c.autoPortListenWorkaround(laddr)
	}

	req := channelForwardMsg{
		laddr.IP.String(),
		uint32(laddr.Port),
	}
	// send message
	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&req))
	if err != nil {
		return nil, err
	}
	if !ok {
		return nil, errors.New("ssh: tcpip-forward request denied by peer")
	}

	// If the original port was 0, then the remote side will
	// supply a real port number in the response.
	if autoPort {
		var reply struct {
			Port uint32
		}
		if err := Unmarshal(resp, &reply); err != nil {
			return nil, err
		}
		laddr.Port = int(reply.Port)
	}

	// Register this forward, using the port number we obtained.
	in := c.forwards.add(laddr)
	return &tcpListener{laddr: laddr, conn: c, in: in}, nil
}
// forwardList stores a mapping between remote
// forward requests and the listeners serving them.
// The embedded mutex guards entries.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry
}
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a listener.
type forwardEntry struct {
	laddr net.Addr     // address the forward was registered for
	c     chan forward // delivers incoming connections to the listener
}
// forward represents an incoming forwarded connection. The
// arguments to add/remove/lookup should be the address as specified in
// the original forward-request.
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
// add registers a forward for addr and returns the channel on which
// connections for that address will be delivered. The channel has a
// buffer of one.
func (l *forwardList) add(addr net.Addr) chan forward {
	l.Lock()
	defer l.Unlock()

	in := make(chan forward, 1)
	l.entries = append(l.entries, forwardEntry{laddr: addr, c: in})
	return in
}
// forwardedTCPPayload is the payload of a "forwarded-tcpip" channel open
// request.
// See RFC 4254, section 7.2
type forwardedTCPPayload struct {
	Addr       string // address that was connected to on the remote side
	Port       uint32 // port that was connected to on the remote side
	OriginAddr string // address the connection originated from
	OriginPort uint32 // port the connection originated from
}
// parseTCPAddr converts an address/port pair received from the remote
// side into a *net.TCPAddr. The port must lie in 1..65535 and addr must be
// a literal IP address; host names are rejected.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}

	parsed := net.ParseIP(addr)
	if parsed == nil {
		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
	}
	return &net.TCPAddr{IP: parsed, Port: int(port)}, nil
}
// handleChannels services channel-open requests arriving on in until in is
// closed. For each request it derives the local and remote addresses from
// the channel's extra data and hands the channel to the matching
// registered forward; requests with an unparsable payload or without a
// matching forward are rejected. It panics on a channel type other than
// "forwarded-tcpip" and "forwarded-streamlocal@openssh.com".
func (l *forwardList) handleChannels(in <-chan NewChannel) {
	for ch := range in {
		var (
			laddr net.Addr
			raddr net.Addr
			err   error
		)
		switch channelType := ch.ChannelType(); channelType {
		case "forwarded-tcpip":
			var payload forwardedTCPPayload
			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
				continue
			}

			// RFC 4254 section 7.2 specifies that incoming
			// addresses should list the address, in string
			// format. It is implied that this should be an IP
			// address, as it would be impossible to connect to it
			// otherwise.
			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
			if err != nil {
				ch.Reject(ConnectionFailed, err.Error())
				continue
			}
			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
			if err != nil {
				ch.Reject(ConnectionFailed, err.Error())
				continue
			}

		case "forwarded-streamlocal@openssh.com":
			var payload forwardedStreamLocalPayload
			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
				continue
			}
			laddr = &net.UnixAddr{
				Name: payload.SocketPath,
				Net:  "unix",
			}
			// The payload carries no origin address, so a fixed
			// placeholder is used for the remote side.
			raddr = &net.UnixAddr{
				Name: "@",
				Net:  "unix",
			}
		default:
			// handleForwards in this file subscribes only to the two
			// types above.
			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
		}
		if ok := l.forward(laddr, raddr, ch); !ok {
			// Section 7.2, implementations MUST reject spurious incoming
			// connections.
			ch.Reject(Prohibited, "no forward for address")
			continue
		}

	}
}
// remove deletes the first forward entry whose network and string form
// match addr and closes the channel feeding its listener. It does nothing
// if no entry matches.
func (l *forwardList) remove(addr net.Addr) {
	l.Lock()
	defer l.Unlock()

	for i := range l.entries {
		entry := l.entries[i]
		if addr.Network() != entry.laddr.Network() || addr.String() != entry.laddr.String() {
			continue
		}
		l.entries = append(l.entries[:i], l.entries[i+1:]...)
		close(entry.c)
		return
	}
}
// closeAll closes the channel of every registered forward and then
// empties the list.
func (l *forwardList) closeAll() {
	l.Lock()
	defer l.Unlock()

	for i := range l.entries {
		close(l.entries[i].c)
	}
	l.entries = nil
}
// forward delivers ch to the first registered forward whose address
// matches laddr and reports whether such a forward exists. The send
// happens while the list is locked and blocks when the entry's
// one-element buffer is already full.
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
	l.Lock()
	defer l.Unlock()

	for i := range l.entries {
		entry := l.entries[i]
		if laddr.Network() != entry.laddr.Network() || laddr.String() != entry.laddr.String() {
			continue
		}
		entry.c <- forward{newCh: ch, raddr: raddr}
		return true
	}
	return false
}
// tcpListener is the net.Listener returned by ListenTCP.
type tcpListener struct {
	laddr *net.TCPAddr // address the forward was registered for

	conn *Client        // client the forward was registered with
	in   <-chan forward // delivers forwarded connections; closed when the forward is removed
}
// Accept waits for and returns the next connection to the listener.
// It returns io.EOF once the listener's channel has been closed.
func (l *tcpListener) Accept() (net.Conn, error) {
	fwd, open := <-l.in
	if !open {
		return nil, io.EOF
	}

	ch, reqs, err := fwd.newCh.Accept()
	if err != nil {
		return nil, err
	}
	go DiscardRequests(reqs)

	conn := &chanConn{Channel: ch, laddr: l.laddr, raddr: fwd.raddr}
	return conn, nil
}
// Close closes the listener: it drops the forward registration (which
// closes the channel feeding Accept) and then asks the peer to cancel the
// forward.
func (l *tcpListener) Close() error {
	cancel := channelForwardMsg{
		l.laddr.IP.String(),
		uint32(l.laddr.Port),
	}

	// this also closes the listener.
	l.conn.forwards.remove(l.laddr)

	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&cancel))
	if err != nil {
		return err
	}
	if !ok {
		return errors.New("ssh: cancel-tcpip-forward failed")
	}
	return nil
}
// Addr returns the listener's network address, i.e. the address the
// forward was registered for.
func (l *tcpListener) Addr() net.Addr {
	return l.laddr
}
// DialContext initiates a connection to the addr from the remote host.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// See func Dial for additional information.
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
	// Fail fast if the context is already done.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	type connErr struct {
		conn net.Conn
		err  error
	}
	ch := make(chan connErr)
	go func() {
		// Dial itself takes no context, so it keeps running after
		// cancellation; its result is then discarded below.
		conn, err := c.Dial(n, addr)
		select {
		case ch <- connErr{conn, err}:
		case <-ctx.Done():
			// The caller has given up: close the connection nobody will
			// receive.
			if conn != nil {
				conn.Close()
			}
		}
	}()
	select {
	case res := <-ch:
		return res.conn, res.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Dial initiates a connection to the addr from the remote host.
// n must be "tcp", "tcp4", "tcp6", or "unix".
//
// For the TCP networks addr is a host:port pair with a numeric port, and
// the resulting connection has a zero LocalAddr() and RemoteAddr(). For
// "unix" addr is a socket path, which becomes the connection's
// RemoteAddr(); LocalAddr() is the placeholder "@".
func (c *Client) Dial(n, addr string) (net.Conn, error) {
	switch n {
	case "tcp", "tcp4", "tcp6":
		// Parse the address into host and numeric port.
		host, portString, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, err
		}
		port, err := strconv.ParseUint(portString, 10, 16)
		if err != nil {
			return nil, err
		}

		ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
		if err != nil {
			return nil, err
		}

		// Use a zero address for local and remote address.
		zero := &net.TCPAddr{IP: net.IPv4zero, Port: 0}
		return &chanConn{Channel: ch, laddr: zero, raddr: zero}, nil

	case "unix":
		ch, err := c.dialStreamLocal(addr)
		if err != nil {
			return nil, err
		}

		local := &net.UnixAddr{Name: "@", Net: "unix"}
		remote := &net.UnixAddr{Name: addr, Net: "unix"}
		return &chanConn{Channel: ch, laddr: local, raddr: remote}, nil

	default:
		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
	}
}
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
// as the local address for the connection; otherwise the IPv4 zero address
// with port 0 is used.
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
	local := laddr
	if local == nil {
		local = &net.TCPAddr{IP: net.IPv4zero, Port: 0}
	}

	ch, err := c.dial(local.IP.String(), local.Port, raddr.IP.String(), raddr.Port)
	if err != nil {
		return nil, err
	}

	conn := &chanConn{Channel: ch, laddr: local, raddr: raddr}
	return conn, nil
}
// channelOpenDirectMsg is the payload of a "direct-tcpip" channel open
// request.
// See RFC 4254 section 7.2.
type channelOpenDirectMsg struct {
	raddr string // host to connect to
	rport uint32 // port to connect to
	laddr string // originator address
	lport uint32 // originator port
}
// dial opens a "direct-tcpip" channel from laddr:lport to raddr:rport.
// Requests arriving on the new channel are discarded in a background
// goroutine.
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
	open := channelOpenDirectMsg{
		raddr: raddr,
		rport: uint32(rport),
		laddr: laddr,
		lport: uint32(lport),
	}

	ch, reqs, err := c.OpenChannel("direct-tcpip", Marshal(&open))
	if err != nil {
		return nil, err
	}
	go DiscardRequests(reqs)
	return ch, nil
}
// tcpChan wraps a Channel.
// NOTE(review): nothing in this part of the file uses it; it may exist
// only for compatibility — confirm before removing.
type tcpChan struct {
	Channel // the backing channel
}
// chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
// It pairs a Channel with the two addresses reported to callers.
type chanConn struct {
	Channel
	laddr, raddr net.Addr
}

// LocalAddr returns the local network address.
func (t *chanConn) LocalAddr() net.Addr {
	return t.laddr
}

// RemoteAddr returns the remote network address.
func (t *chanConn) RemoteAddr() net.Addr {
	return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection. It calls SetReadDeadline and returns that error if
// there is one, otherwise the result of SetWriteDeadline. Since
// SetReadDeadline always fails for this type, so does SetDeadline.
func (t *chanConn) SetDeadline(deadline time.Time) error {
	if err := t.SetReadDeadline(deadline); err != nil {
		return err
	}
	return t.SetWriteDeadline(deadline)
}
// SetReadDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error and
// the deadline argument is ignored.
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
	// for compatibility with previous version,
	// the error message contains "tcpChan"
	return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type.  It always returns an error and
// the deadline argument is ignored.
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
	return errors.New("ssh: tcpChan: deadline not supported")
}
================================================
FILE: gossh/crypto/ssh/terminal/terminal.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Deprecated: this package moved to golang.org/x/term.
package terminal
import (
"io"
"gossh/term"
)
// EscapeCodes contains escape sequences that can be written to the terminal in
// order to achieve different styles of text.
// It is an alias for term.EscapeCodes.
type EscapeCodes = term.EscapeCodes
// Terminal contains the state for running a VT100 terminal that is capable of
// reading lines of input.
// It is an alias for term.Terminal.
type Terminal = term.Terminal
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
// It forwards to term.NewTerminal.
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
	return term.NewTerminal(c, prompt)
}
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
// to valid line data. It indicates that bracketed paste mode is enabled and
// that the returned line consists only of pasted data. Programs may wish to
// interpret pasted data more literally than typed data.
// It is the same value as term.ErrPasteIndicator.
var ErrPasteIndicator = term.ErrPasteIndicator
// State contains the state of a terminal.
// It is an alias for term.State.
type State = term.State
// IsTerminal returns whether the given file descriptor is a terminal.
// It forwards to term.IsTerminal.
func IsTerminal(fd int) bool {
	return term.IsTerminal(fd)
}
// ReadPassword reads a line of input from a terminal without local echo.  This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
// It forwards to term.ReadPassword.
func ReadPassword(fd int) ([]byte, error) {
	return term.ReadPassword(fd)
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// It forwards to term.MakeRaw.
func MakeRaw(fd int) (*State, error) {
	return term.MakeRaw(fd)
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state, such as one returned by MakeRaw or GetState.
// It forwards to term.Restore.
func Restore(fd int, oldState *State) error {
	return term.Restore(fd, oldState)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
// It forwards to term.GetState.
func GetState(fd int) (*State, error) {
	return term.GetState(fd)
}
// GetSize returns the dimensions of the given terminal.
// It forwards to term.GetSize.
func GetSize(fd int) (width, height int, err error) {
	return term.GetSize(fd)
}
================================================
FILE: gossh/crypto/ssh/transport.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"bytes"
"errors"
"io"
"log"
)
// debugTransport, if set, will print packet types as they go over the
// wire.  No message decoding is done, to minimize the impact on timing.
const debugTransport = false
// Names of cipher algorithms that are referred to by identifier.
const (
	gcm128CipherID = "aes128-gcm@openssh.com"
	gcm256CipherID = "aes256-gcm@openssh.com"
	aes128cbcID    = "aes128-cbc"
	tripledescbcID = "3des-cbc"
)
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
	// writePacket encrypts and sends a packet of data to the remote peer.
	writePacket(packet []byte) error

	// readPacket reads a packet from the connection. The read is blocking,
	// i.e. if error is nil, then the returned byte slice is
	// always non-empty.
	readPacket() ([]byte, error)

	// Close closes the write-side of the connection.
	Close() error
}
// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
	// reader and writer hold the per-direction cipher state.
	reader connectionState
	writer connectionState

	bufReader *bufio.Reader
	bufWriter *bufio.Writer
	rand      io.Reader
	isClient  bool // selects "client" rather than "server" in debug output
	io.Closer

	// strictMode is switched on by setStrictMode.
	// NOTE(review): presumably this is strict key exchange — confirm
	// against the packet read/write code.
	strictMode bool
	// initialKEXDone is switched on by setInitialKEXDone.
	initialKEXDone bool
}
// packetCipher represents a combination of SSH encryption/MAC
// protocol. A single instance should be used for one direction only.
type packetCipher interface {
// writeCipherPacket encrypts the packet and writes it to w. The
// contents of the packet are generally scrambled. seqnum is the
// sequence number of the packet being written and rand is the
// randomness source supplied by the transport.
writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
// readCipherPacket reads and decrypts a packet of data from r. The
// returned packet may be overwritten by future calls of
// readPacket, so callers that keep it must copy it.
readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
}
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
type connectionState struct {
// packetCipher is the cipher currently used for this direction.
packetCipher
// seqNum is the sequence number passed to the cipher for the next packet.
seqNum uint32
// dir holds the key-derivation tags for this direction.
dir direction
// pendingKeyChange carries the cipher prepared by prepareKeyChange; it is
// installed when a msgNewKeys packet is read or written.
pendingKeyChange chan packetCipher
}
// setStrictMode enables strict KEX mode on the transport. It is only accepted
// while the read sequence number is exactly 1; otherwise an error is returned
// and the mode is left unchanged.
func (t *transport) setStrictMode() error {
	if t.reader.seqNum == 1 {
		t.strictMode = true
		return nil
	}
	return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
}
// setInitialKEXDone records that the first key exchange has completed. From
// then on readPacket skips msgIgnore and msgDebug packets even in strict mode.
func (t *transport) setInitialKEXDone() {
t.initialKEXDone = true
}
// prepareKeyChange derives the new read and write ciphers from the negotiated
// algorithms and the key exchange result, and queues each one on its
// direction's pendingKeyChange channel. The ciphers are not installed here:
// the switch happens when a msgNewKeys packet is read or written,
// respectively.
//
// The read cipher is queued before the write cipher is built, so a failure
// while building the write cipher leaves the read cipher queued.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
	readCipher, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
	if err != nil {
		return err
	}
	t.reader.pendingKeyChange <- readCipher

	writeCipher, err := newPacketCipher(t.writer.dir, algs.w, kexResult)
	if err != nil {
		return err
	}
	t.writer.pendingKeyChange <- writeCipher

	return nil
}
// printPacket logs the direction ("read" or "write"), the local role
// ("client" or "server") and the first byte of p. Empty packets are not
// logged.
func (t *transport) printPacket(p []byte, write bool) {
	if len(p) == 0 {
		return
	}

	who, what := "server", "read"
	if t.isClient {
		who = "client"
	}
	if write {
		what = "write"
	}

	log.Println(what, who, p[0])
}
// readPacket reads and decrypts the next packet, skipping msgIgnore and
// msgDebug packets. It returns as soon as the underlying read fails or a
// packet that must be delivered to the caller has been read.
func (t *transport) readPacket() (p []byte, err error) {
for {
p, err = t.reader.readPacket(t.bufReader, t.strictMode)
if err != nil {
break
}
// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
// Stop (and deliver p) when the packet is empty, when strict mode is on
// and the initial KEX has not finished, or when the packet is neither
// msgIgnore nor msgDebug; otherwise loop and read the next packet.
if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
break
}
}
if debugTransport {
t.printPacket(p, false)
}
return p, err
}
// readPacket reads one packet from r using the current cipher and returns a
// private copy of it. strictMode controls whether the sequence number is
// reset after a key change.
//
// Two message types are interpreted here: msgNewKeys installs the pending
// cipher, and msgDisconnect is converted into an error.
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
// The sequence number advances even when the read failed.
s.seqNum++
if err == nil && len(packet) == 0 {
err = errors.New("ssh: zero length packet")
}
if len(packet) > 0 {
switch packet[0] {
case msgNewKeys:
// Non-blocking receive: a msgNewKeys without a cipher queued by
// prepareKeyChange is rejected.
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default:
return nil, errors.New("ssh: got bogus newkeys message")
}
case msgDisconnect:
// Transform a disconnect message into an
// error. Since this is lowest level at which
// we interpret message types, doing it here
// ensures that we don't have to handle it
// elsewhere.
var msg disconnectMsg
if err := Unmarshal(packet, &msg); err != nil {
return nil, err
}
return nil, &msg
}
}
// The packet may point to an internal buffer, so copy the
// packet out here.
fresh := make([]byte, len(packet))
copy(fresh, packet)
return fresh, err
}
// writePacket encrypts and sends packet through the write-side connection
// state, logging it first when debugTransport is enabled.
func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}
// writePacket encrypts packet with the current cipher, writes it to w and
// flushes w. The sequence number is advanced only after a successful flush.
// If packet is a msgNewKeys message, the cipher queued by prepareKeyChange is
// installed afterwards, and in strict mode the sequence number is reset to
// zero.
//
// It panics if a msgNewKeys packet is written while no cipher is pending.
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
// Decide before encrypting: the cipher may scramble the packet contents.
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
if err != nil {
return err
}
if err = w.Flush(); err != nil {
return err
}
s.seqNum++
if changeKeys {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default:
panic("ssh: no key material for msgNewKeys")
}
}
return err
}
// newTransport wraps rwc in a transport. Both directions start out with the
// unencrypted "none" stream cipher and a one-slot channel for the cipher of
// the next key change. rand is the randomness source later passed to the
// write cipher, and isClient selects which key direction is used for reading
// and which for writing.
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
	// A server reads with the client's keys and writes with its own; a
	// client does the opposite.
	readDir, writeDir := clientKeys, serverKeys
	if isClient {
		readDir, writeDir = serverKeys, clientKeys
	}

	return &transport{
		bufReader: bufio.NewReader(rwc),
		bufWriter: bufio.NewWriter(rwc),
		rand:      rand,
		isClient:  isClient,
		Closer:    rwc,
		reader: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			dir:              readDir,
			pendingKeyChange: make(chan packetCipher, 1),
		},
		writer: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			dir:              writeDir,
			pendingKeyChange: make(chan packetCipher, 1),
		},
	}
}
// direction holds the tags that newPacketCipher passes to
// generateKeyMaterial when deriving the IV, the cipher key and the MAC key for
// one direction of the connection.
type direction struct {
ivTag []byte
keyTag []byte
macKeyTag []byte
}
// serverKeys and clientKeys are the tag sets for the two directions. A client
// transport reads with serverKeys and writes with clientKeys; a server
// transport does the opposite (see newTransport).
var (
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// newPacketCipher builds the packet cipher for one direction. The IV, the
// cipher key and (for non-AEAD ciphers) the MAC key are derived from the key
// exchange result as described in RFC 4253, section 6.4. d should be either
// serverKeys (for server->client keys) or clientKeys (for client->server
// keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
	mode := cipherModes[algs.Cipher]

	iv := make([]byte, mode.ivSize)
	generateKeyMaterial(iv, d.ivTag, kex)

	key := make([]byte, mode.keySize)
	generateKeyMaterial(key, d.keyTag, kex)

	// AEAD ciphers authenticate on their own and take no separate MAC key.
	var macKey []byte
	if !aeadCiphers[algs.Cipher] {
		macKey = make([]byte, macModes[algs.MAC].keySize)
		generateKeyMaterial(macKey, d.macKeyTag, kex)
	}

	return mode.create(key, iv, macKey, algs)
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
//
// The first digest is computed over K, H, tag and the session ID. Whenever
// more bytes are needed, a further digest is computed over K, H and all
// digests produced so far.
func generateKeyMaterial(out, tag []byte, r *kexResult) {
	hasher := r.Hash.New()
	var previous []byte

	for len(out) > 0 {
		hasher.Reset()
		hasher.Write(r.K)
		hasher.Write(r.H)
		if len(previous) > 0 {
			hasher.Write(previous)
		} else {
			hasher.Write(tag)
			hasher.Write(r.SessionID)
		}

		block := hasher.Sum(nil)
		out = out[copy(out, block):]
		if len(out) == 0 {
			break
		}
		// More output is required: remember this digest for the next round.
		previous = append(previous, block...)
	}
}
// packageVersion is an SSH version line in the "SSH-2.0-" form expected by
// exchangeVersions.
// NOTE(review): presumably the default version line sent when the caller
// configures none — confirm at the use sites.
const packageVersion = "SSH-2.0-Go"
// exchangeVersions sends versionLine followed by "\r\n" on rw and then reads
// the peer's version line, which it returns. versionLine should be US ASCII,
// start with "SSH-2.0-", and must not include a newline; a line containing a
// byte below 32 is rejected before anything is written.
//
// versionLine is never modified, even when its backing array has spare
// capacity.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
	// Contrary to the RFC, we do not ignore lines that don't
	// start with "SSH-2.0-" to make the library usable with
	// nonconforming servers.
	for _, c := range versionLine {
		// The spec disallows non US-ASCII chars, and
		// specifically forbids null chars.
		if c < 32 {
			return nil, errors.New("ssh: junk character in version line")
		}
	}

	// Build the outgoing line in a fresh buffer. Appending to versionLine
	// directly would write "\r\n" into the caller's backing array whenever
	// the slice has spare capacity.
	line := make([]byte, 0, len(versionLine)+2)
	line = append(line, versionLine...)
	line = append(line, '\r', '\n')
	if _, err = rw.Write(line); err != nil {
		return nil, err
	}

	return readVersion(rw)
}

// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars.
const maxVersionStringBytes = 255

// readVersion reads the peer's version string as specified by RFC 4253,
// section 4.2. Lines that do not start with "SSH-" are skipped. At most
// maxVersionStringBytes bytes are read in total, skipped lines included;
// exceeding that limit is an error. The returned line carries neither the
// terminating "\n" nor a preceding "\r".
func readVersion(r io.Reader) ([]byte, error) {
	versionString := make([]byte, 0, 64)
	var buf [1]byte

	for length := 0; length < maxVersionStringBytes; length++ {
		if _, err := io.ReadFull(r, buf[:]); err != nil {
			return nil, err
		}

		// The RFC says that the version should be terminated with \r\n
		// but several SSH servers actually only send a \n.
		if buf[0] != '\n' {
			// non ASCII chars are disallowed, but we are lenient,
			// since Go doesn't use null-terminated strings.
			// The RFC allows a comment after a space, however,
			// all of it (version and comments) goes into the
			// session hash.
			versionString = append(versionString, buf[0])
			continue
		}

		if !bytes.HasPrefix(versionString, []byte("SSH-")) {
			// RFC 4253 says we need to ignore all version string lines
			// except the one containing the SSH version (provided that
			// all the lines do not exceed 255 bytes in total).
			versionString = versionString[:0]
			continue
		}

		// There might be a '\r' on the end which we should remove.
		return bytes.TrimSuffix(versionString, []byte("\r")), nil
	}

	return nil, errors.New("ssh: overflow reading version string")
}
================================================
FILE: gossh/gin/auth.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"crypto/subtle"
"encoding/base64"
"net/http"
"strconv"
"gossh/gin/internal/bytesconv"
)
// AuthUserKey is the context key under which the basic auth middleware stores
// the authenticated user name (via c.Set).
const AuthUserKey = "user"
// AuthProxyUserKey is the context key under which the proxy basic auth
// middleware stores the authenticated proxy user name (via c.Set).
const AuthProxyUserKey = "proxy_user"
// Accounts defines a key/value for user/pass list of authorized logins.
type Accounts map[string]string
// authPair couples a user name with the exact header value that
// authenticates that user.
type authPair struct {
// value is the precomputed header value, "Basic " followed by the base64
// encoding of "user:password" (see authorizationHeader).
value string
// user is the account name the value belongs to.
user string
}
// authPairs is the list of accepted credentials built by processAccounts.
type authPairs []authPair
// searchCredential looks up the user whose precomputed header value equals
// authValue. Each candidate is compared with subtle.ConstantTimeCompare. An
// empty authValue never matches.
func (a authPairs) searchCredential(authValue string) (string, bool) {
	if authValue == "" {
		return "", false
	}

	presented := bytesconv.StringToBytes(authValue)
	for i := range a {
		expected := bytesconv.StringToBytes(a[i].value)
		if subtle.ConstantTimeCompare(expected, presented) == 1 {
			return a[i].user, true
		}
	}
	return "", false
}
// BasicAuthForRealm returns a Basic HTTP Authorization middleware. accounts
// maps user names to passwords and realm is the name of the realm announced
// to the client; if realm is empty, "Authorization Required" is used.
// (see http://tools.ietf.org/html/rfc2617#section-1.2)
//
// Requests without valid credentials are aborted with 401 and a
// WWW-Authenticate header. For accepted requests the user name is stored in
// the context under AuthUserKey and can be read later using
// c.MustGet(gin.AuthUserKey).
func BasicAuthForRealm(accounts Accounts, realm string) HandlerFunc {
	if realm == "" {
		realm = "Authorization Required"
	}
	challenge := "Basic realm=" + strconv.Quote(realm)
	credentials := processAccounts(accounts)

	return func(c *Context) {
		user, ok := credentials.searchCredential(c.requestHeader("Authorization"))
		if ok {
			c.Set(AuthUserKey, user)
			return
		}
		// No matching credentials: challenge the client and stop the
		// handler chain.
		c.Header("WWW-Authenticate", challenge)
		c.AbortWithStatus(http.StatusUnauthorized)
	}
}
// BasicAuth returns a Basic HTTP Authorization middleware. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
// It is BasicAuthForRealm with an empty realm, so the default realm name is used.
func BasicAuth(accounts Accounts) HandlerFunc {
return BasicAuthForRealm(accounts, "")
}
// processAccounts converts accounts into the list of header values accepted
// by the middleware. It asserts (via assert1) that accounts is not empty and
// that no user name is empty.
func processAccounts(accounts Accounts) authPairs {
	assert1(len(accounts) > 0, "Empty list of authorized credentials")

	pairs := make(authPairs, 0, len(accounts))
	for user, password := range accounts {
		assert1(user != "", "User can not be empty")
		pairs = append(pairs, authPair{
			value: authorizationHeader(user, password),
			user:  user,
		})
	}
	return pairs
}
// authorizationHeader returns the header value that authenticates the given
// user: "Basic " followed by the standard base64 encoding of "user:password".
func authorizationHeader(user, password string) string {
	credentials := bytesconv.StringToBytes(user + ":" + password)
	return "Basic " + base64.StdEncoding.EncodeToString(credentials)
}
// BasicAuthForProxy returns a Basic HTTP Proxy-Authorization middleware.
// If the realm is empty, "Proxy Authorization Required" will be used by default.
//
// Requests without valid credentials are aborted with 407 and a
// Proxy-Authenticate header. For accepted requests the proxy user name is
// stored in the context under AuthProxyUserKey and can be read later using
// c.MustGet(gin.AuthProxyUserKey).
func BasicAuthForProxy(accounts Accounts, realm string) HandlerFunc {
	if realm == "" {
		realm = "Proxy Authorization Required"
	}
	challenge := "Basic realm=" + strconv.Quote(realm)
	credentials := processAccounts(accounts)

	return func(c *Context) {
		proxyUser, ok := credentials.searchCredential(c.requestHeader("Proxy-Authorization"))
		if ok {
			c.Set(AuthProxyUserKey, proxyUser)
			return
		}
		// No matching credentials: challenge the client and stop the
		// handler chain.
		c.Header("Proxy-Authenticate", challenge)
		c.AbortWithStatus(http.StatusProxyAuthRequired)
	}
}
================================================
FILE: gossh/gin/binding/binding.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import "net/http"
// Content-Type MIME values of the most common data formats. Default compares
// the request content type against these to choose a Binding.
const (
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
)
// Binding describes the interface which needs to be implemented for binding the
// data present in the request such as JSON request body, query parameters or
// the form POST.
type Binding interface {
// Name returns a short identifier of the binding (for example "form" or "json").
Name() string
// Bind fills the given object from the request and returns any error
// encountered while doing so.
Bind(*http.Request, any) error
}
// BindingBody adds the BindBody method to Binding. BindBody is similar to Bind,
// but it reads the body from the supplied bytes instead of req.Body.
type BindingBody interface {
Binding
BindBody([]byte, any) error
}
// BindingUri is the URI counterpart of Binding. BindUri is similar to Bind,
// but it takes its input from the supplied map of parameters instead of an
// *http.Request. The interface declares Name itself and does not embed
// Binding.
type BindingUri interface {
Name() string
BindUri(map[string][]string, any) error
}
// StructValidator is the minimal interface which needs to be implemented in
// order for it to be used as the validator engine for ensuring the correctness
// of the request. Gin provides a default implementation for this using
// https://github.com/go-playground/validator/tree/v10.6.1.
type StructValidator interface {
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
// If the received type is a slice|array, the validation should be performed on every element.
// If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned.
// If the received type is a struct or pointer to a struct, the validation should be performed.
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
// Otherwise nil must be returned.
ValidateStruct(any) error
// Engine returns the underlying validator engine which powers the
// StructValidator implementation.
Engine() any
}
// Validator is the default validator which implements the StructValidator
// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1
// under the hood. Setting it to nil disables validation: the package-level
// validate helper then returns nil without checking anything.
var Validator StructValidator = &defaultValidator{}
// These are the ready-made binding instances that can be used to bind the
// data present in the request to struct instances. JSON and XML also support
// binding from raw bytes (BindingBody); Uri binds from URI parameters
// (BindingUri); the rest implement Binding.
var (
JSON BindingBody = jsonBinding{}
XML BindingBody = xmlBinding{}
Form Binding = formBinding{}
Query Binding = queryBinding{}
FormPost Binding = formPostBinding{}
FormMultipart Binding = formMultipartBinding{}
Uri BindingUri = uriBinding{}
Header Binding = headerBinding{}
)
// Default returns the appropriate Binding instance based on the HTTP method
// and the content type. GET requests always use Form. Any other method is
// matched on contentType: JSON, XML and multipart each get their own binding,
// and every other content type (including MIMEPOSTForm) falls back to Form.
func Default(method, contentType string) Binding {
	switch {
	case method == http.MethodGet:
		return Form
	case contentType == MIMEJSON:
		return JSON
	case contentType == MIMEXML || contentType == MIMEXML2:
		return XML
	case contentType == MIMEMultipartPOSTForm:
		return FormMultipart
	}
	return Form // MIMEPOSTForm and anything unrecognised
}
// validate runs obj through the package-level Validator. When Validator is
// nil, validation is skipped and nil is returned.
func validate(obj any) error {
	if Validator != nil {
		return Validator.ValidateStruct(obj)
	}
	return nil
}
================================================
FILE: gossh/gin/binding/default_validator.go
================================================
// Copyright 2017 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"fmt"
"reflect"
"strings"
"sync"
"gossh/gin/validator"
)
// defaultValidator is the StructValidator used by default. The underlying
// engine is created lazily, on first use, by lazyinit.
type defaultValidator struct {
// once guards the one-time construction of validate.
once sync.Once
// validate is the underlying engine; it is nil until lazyinit has run.
validate *validator.Validate
}
// SliceValidationError collects the validation errors of the elements of a
// slice or array.
type SliceValidationError []error

// Error concatenates all error elements in SliceValidationError into a single
// string. Every non-nil element is rendered as "[index]: message", and each
// rendered element other than the one at index 0 is preceded by "\n". Nil
// elements are skipped.
func (err SliceValidationError) Error() string {
	var b strings.Builder
	for i, e := range err {
		if e == nil {
			continue
		}
		if i > 0 {
			b.WriteString("\n")
		}
		fmt.Fprintf(&b, "[%d]: %s", i, e.Error())
	}
	return b.String()
}
// Compile-time check that *defaultValidator implements StructValidator.
var _ StructValidator = (*defaultValidator)(nil)
// ValidateStruct validates obj according to its kind:
//
//   - a struct, or a pointer to a struct, is validated by the underlying engine;
//   - a pointer to anything else is dereferenced and validated recursively;
//   - a slice or array has every element validated, and the per-element
//     failures are returned together as a SliceValidationError;
//   - anything else, including nil and typed nil pointers, is skipped and nil
//     is returned.
//
// As the StructValidator contract requires, it does not panic on nil
// pointers.
func (v *defaultValidator) ValidateStruct(obj any) error {
	if obj == nil {
		return nil
	}

	value := reflect.ValueOf(obj)
	switch value.Kind() {
	case reflect.Ptr:
		if value.IsNil() {
			// A typed nil pointer has nothing to validate. Without this
			// guard value.Elem() is the zero Value and calling Interface()
			// on it panics.
			return nil
		}
		if value.Elem().Kind() != reflect.Struct {
			return v.ValidateStruct(value.Elem().Interface())
		}
		return v.validateStruct(obj)
	case reflect.Struct:
		return v.validateStruct(obj)
	case reflect.Slice, reflect.Array:
		count := value.Len()
		var validateRet SliceValidationError
		for i := 0; i < count; i++ {
			if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
				validateRet = append(validateRet, err)
			}
		}
		if len(validateRet) == 0 {
			return nil
		}
		return validateRet
	default:
		return nil
	}
}
// validateStruct validates obj, which must be a struct or a pointer to a
// struct, with the underlying engine, initialising the engine first if needed.
func (v *defaultValidator) validateStruct(obj any) error {
v.lazyinit()
return v.validate.Struct(obj)
}
// Engine returns the underlying validator engine which powers the default
// Validator instance, creating it on first use. This is useful if you want to
// register custom validations or struct level validations. See validator
// GoDoc for more info -
// https://pkg.go.dev/github.com/go-playground/validator/v10
func (v *defaultValidator) Engine() any {
v.lazyinit()
return v.validate
}
// lazyinit creates the underlying engine exactly once (guarded by v.once) and
// configures it to read validation rules from the "binding" struct tag.
func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
v.validate.SetTagName("binding")
})
}
================================================
FILE: gossh/gin/binding/form.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"errors"
"net/http"
)
// defaultMemory (32 MiB) is the maxMemory argument passed to
// http.Request.ParseMultipartForm by the form bindings.
const defaultMemory = 32 << 20
// formBinding binds from req.Form after parsing both URL-encoded and
// multipart data.
type formBinding struct{}
// formPostBinding binds from req.PostForm only.
type formPostBinding struct{}
// formMultipartBinding binds from the parsed multipart form, including
// uploaded files.
type formMultipartBinding struct{}
// Name returns the identifier of this binding, "form".
func (formBinding) Name() string {
return "form"
}
// Bind parses the request's form data, maps req.Form onto obj using the
// "form" struct tag and then validates obj. A request that is not multipart
// (http.ErrNotMultipart) is not treated as an error.
func (formBinding) Bind(req *http.Request, obj any) error {
	err := req.ParseForm()
	if err != nil {
		return err
	}

	err = req.ParseMultipartForm(defaultMemory)
	if err != nil && !errors.Is(err, http.ErrNotMultipart) {
		return err
	}

	err = mapForm(obj, req.Form)
	if err != nil {
		return err
	}
	return validate(obj)
}
// Name returns the identifier of this binding, "form-urlencoded".
func (formPostBinding) Name() string {
return "form-urlencoded"
}
// Bind parses the request's form data, maps req.PostForm onto obj using the
// "form" struct tag and then validates obj.
func (formPostBinding) Bind(req *http.Request, obj any) error {
	err := req.ParseForm()
	if err != nil {
		return err
	}

	err = mapForm(obj, req.PostForm)
	if err != nil {
		return err
	}
	return validate(obj)
}
// Name returns the identifier of this binding, "multipart/form-data".
func (formMultipartBinding) Name() string {
return "multipart/form-data"
}
// Bind parses the request as a multipart form, maps its values and files onto
// obj using the "form" struct tag and then validates obj. Unlike
// formBinding.Bind, any parse error, including http.ErrNotMultipart, is
// returned.
func (formMultipartBinding) Bind(req *http.Request, obj any) error {
	err := req.ParseMultipartForm(defaultMemory)
	if err != nil {
		return err
	}

	err = mappingByPtr(obj, (*multipartRequest)(req), "form")
	if err != nil {
		return err
	}
	return validate(obj)
}
================================================
FILE: gossh/gin/binding/form_mapping.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"errors"
"fmt"
"mime/multipart"
"reflect"
"strconv"
"strings"
"time"
"gossh/gin/internal/bytesconv"
"gossh/gin/internal/json"
)
var (
// errUnknownType is returned by setWithProperType for value kinds it
// cannot populate from a string.
errUnknownType = errors.New("unknown type")
// ErrConvertMapStringSlice is returned when the destination cannot be
// converted to map[string][]string.
ErrConvertMapStringSlice = errors.New("can not convert to map slices of strings")
// ErrConvertToMapString is returned when the destination cannot be
// converted to map[string]string.
ErrConvertToMapString = errors.New("can not convert to map of strings")
)
// mapURI maps the URI parameters in m onto ptr using the "uri" struct tag.
func mapURI(ptr any, m map[string][]string) error {
return mapFormByTag(ptr, m, "uri")
}
// mapForm maps the form values onto ptr using the "form" struct tag.
func mapForm(ptr any, form map[string][]string) error {
return mapFormByTag(ptr, form, "form")
}
// MapFormWithTag maps the form values onto ptr, looking up each field's key
// in the struct tag named by tag. It is the exported entry point to
// mapFormByTag.
func MapFormWithTag(ptr any, form map[string][]string, tag string) error {
return mapFormByTag(ptr, form, tag)
}
// emptyField is the zero StructField passed to mapping for the root value,
// which is not reached through any struct field.
var emptyField = reflect.StructField{}
func mapFormByTag(ptr any, form map[string][]string, tag string) error {
// Check if ptr is a map
ptrVal := reflect.ValueOf(ptr)
var pointed any
if ptrVal.Kind() == reflect.Ptr {
ptrVal = ptrVal.Elem()
pointed = ptrVal.Interface()
}
if ptrVal.Kind() == reflect.Map &&
ptrVal.Type().Key().Kind() == reflect.String {
if pointed != nil {
ptr = pointed
}
return setFormMap(ptr, form)
}
return mappingByPtr(ptr, formSource(form), tag)
}
// setter is a source of values that mapping consults for each field while
// walking a struct.
type setter interface {
// TrySet tries to set value, which was reached through field, from the
// entry stored under key. opt carries the options parsed from the field's
// tag. It reports whether the value was set.
TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSet bool, err error)
}
// formSource adapts a form-like map (map[string][]string) to the setter
// interface.
type formSource map[string][]string
// Compile-time check that formSource implements setter.
var _ setter = formSource(nil)
// TrySet tries to set a value by request's form source (like map[string][]string)
func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSet bool, err error) {
return setByForm(value, field, form, tagValue, opt)
}
// mappingByPtr starts the recursive mapping of ptr from setter. The root is
// mapped with emptyField because it is not reached through a struct field;
// only the error of the walk is reported.
func mappingByPtr(ptr any, setter setter, tag string) error {
_, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag)
return err
}
// mapping populates value from setter, using the struct tag named by tag to
// decide which key each field is looked up under. field describes the struct
// field through which value was reached (emptyField at the root). It reports
// whether at least one value was set.
func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
if field.Tag.Get(tag) == "-" { // just ignoring this field
return false, nil
}
vKind := value.Kind()
if vKind == reflect.Ptr {
// Map into the pointee. For a nil pointer, work on a freshly allocated
// value and store it back only if something was actually set, so that
// untouched pointer fields stay nil.
var isNew bool
vPtr := value
if value.IsNil() {
isNew = true
vPtr = reflect.New(value.Type().Elem())
}
isSet, err := mapping(vPtr.Elem(), field, setter, tag)
if err != nil {
return false, err
}
if isNew && isSet {
value.Set(vPtr)
}
return isSet, nil
}
// Try to set the value as a whole first. Embedded structs are the exception:
// they are never set directly and are always walked field by field below.
if vKind != reflect.Struct || !field.Anonymous {
ok, err := tryToSetValue(value, field, setter, tag)
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
// Nothing was set directly: descend into the struct's fields.
if vKind == reflect.Struct {
tValue := value.Type()
var isSet bool
for i := 0; i < value.NumField(); i++ {
sf := tValue.Field(i)
if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue
}
ok, err := mapping(value.Field(i), sf, setter, tag)
if err != nil {
return false, err
}
isSet = isSet || ok
}
return isSet, nil
}
return false, nil
}
// setOptions carries the per-field options parsed from a struct tag by
// tryToSetValue.
type setOptions struct {
// isDefaultExists reports whether the tag contained a "default" option.
isDefaultExists bool
// defaultValue is the text following "default="; it is used when the
// source has no entry for the field.
defaultValue string
}
// tryToSetValue resolves the key for field from the struct tag named by tag
// and asks setter to set value from it. The first comma-separated element of
// the tag is the key (falling back to the field name when empty); the
// remaining elements are options, of which only "default=..." is recognised.
// When no key can be determined, as for emptyField, nothing is set.
func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
	key, rest := head(field.Tag.Get(tag), ",")
	if key == "" { // default value is FieldName
		key = field.Name
	}
	if key == "" { // when field is "emptyField" variable
		return false, nil
	}

	var options setOptions
	for len(rest) > 0 {
		var item string
		item, rest = head(rest, ",")
		if name, val := head(item, "="); name == "default" {
			options.isDefaultExists = true
			options.defaultValue = val
		}
	}

	return setter.TrySet(value, field, key, options)
}
// setByForm sets value from the entries stored in form under tagValue. When
// the key is absent, the default from opt is used if one exists; otherwise
// nothing is set. Slices receive every entry, arrays require exactly as many
// entries as they have elements, and all other kinds are set from the first
// entry (or the empty string when the key is present with no entries).
func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSet bool, err error) {
	entries, found := form[tagValue]
	if !found {
		if !opt.isDefaultExists {
			return false, nil
		}
		// Key is absent: fall back to the tag's default value.
		entries = []string{opt.defaultValue}
	}

	switch value.Kind() {
	case reflect.Slice:
		return true, setSlice(entries, value, field)
	case reflect.Array:
		if len(entries) != value.Len() {
			return false, fmt.Errorf("%q is not valid value for %s", entries, value.Type().String())
		}
		return true, setArray(entries, value, field)
	default:
		var first string
		if len(entries) > 0 {
			first = entries[0]
		}
		return true, setWithProperType(first, value, field)
	}
}
// setWithProperType converts val to the kind of value and stores it there.
// Integers, unsigned integers, booleans, floats and strings are parsed
// directly; time.Duration and time.Time have dedicated parsers;
// multipart.FileHeader is left untouched; other structs and maps are decoded
// from JSON; pointers are allocated when nil and then set through. Any other
// kind yields errUnknownType.
func setWithProperType(val string, value reflect.Value, field reflect.StructField) error {
	switch value.Kind() {
	case reflect.Int:
		return setIntField(val, 0, value)
	case reflect.Int8:
		return setIntField(val, 8, value)
	case reflect.Int16:
		return setIntField(val, 16, value)
	case reflect.Int32:
		return setIntField(val, 32, value)
	case reflect.Int64:
		// time.Duration has kind Int64 but is parsed from duration syntax.
		if _, isDuration := value.Interface().(time.Duration); isDuration {
			return setTimeDuration(val, value)
		}
		return setIntField(val, 64, value)
	case reflect.Uint:
		return setUintField(val, 0, value)
	case reflect.Uint8:
		return setUintField(val, 8, value)
	case reflect.Uint16:
		return setUintField(val, 16, value)
	case reflect.Uint32:
		return setUintField(val, 32, value)
	case reflect.Uint64:
		return setUintField(val, 64, value)
	case reflect.Bool:
		return setBoolField(val, value)
	case reflect.Float32:
		return setFloatField(val, 32, value)
	case reflect.Float64:
		return setFloatField(val, 64, value)
	case reflect.String:
		value.SetString(val)
		return nil
	case reflect.Struct:
		switch value.Interface().(type) {
		case time.Time:
			return setTimeField(val, field, value)
		case multipart.FileHeader:
			return nil
		}
		return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface())
	case reflect.Map:
		return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface())
	case reflect.Ptr:
		if !value.Elem().IsValid() {
			value.Set(reflect.New(value.Type().Elem()))
		}
		return setWithProperType(val, value.Elem(), field)
	}
	return errUnknownType
}
// setIntField parses val as a base-10 signed integer of the given bit size
// and stores it in field. An empty val is treated as "0". On a parse error
// field is left untouched and the error is returned.
func setIntField(val string, bitSize int, field reflect.Value) error {
	if val == "" {
		val = "0"
	}
	parsed, err := strconv.ParseInt(val, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetInt(parsed)
	return nil
}
// setUintField parses val as a base-10 unsigned integer of the given bit size
// and stores it in field. An empty val is treated as "0". On a parse error
// field is left untouched and the error is returned.
func setUintField(val string, bitSize int, field reflect.Value) error {
	if val == "" {
		val = "0"
	}
	parsed, err := strconv.ParseUint(val, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetUint(parsed)
	return nil
}
// setBoolField parses val with strconv.ParseBool and stores the result in
// field. An empty val is treated as "false". On a parse error field is left
// untouched and the error is returned.
func setBoolField(val string, field reflect.Value) error {
	if val == "" {
		val = "false"
	}
	parsed, err := strconv.ParseBool(val)
	if err != nil {
		return err
	}
	field.SetBool(parsed)
	return nil
}
// setFloatField parses val as a floating point number of the given bit size
// and stores it in field. An empty val is treated as "0.0". On a parse error
// field is left untouched and the error is returned.
func setFloatField(val string, bitSize int, field reflect.Value) error {
	if val == "" {
		val = "0.0"
	}
	parsed, err := strconv.ParseFloat(val, bitSize)
	if err != nil {
		return err
	}
	field.SetFloat(parsed)
	return nil
}
// setTimeField parses val into a time.Time and stores it in value. The layout
// comes from the field's "time_format" tag and defaults to time.RFC3339. The
// special formats "unix" and "unixnano" (case-insensitive) parse val as an
// integer count of seconds or nanoseconds since the Unix epoch. For all other
// layouts an empty val stores the zero time, and the location is time.Local
// unless "time_utc" is true or "time_location" names a location to load.
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
	layout := structField.Tag.Get("time_format")
	if layout == "" {
		layout = time.RFC3339
	}

	if kind := strings.ToLower(layout); kind == "unix" || kind == "unixnano" {
		stamp, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			return err
		}
		// unit is the number of stamp units per second: 1 for seconds,
		// 1e9 for nanoseconds.
		unit := int64(1)
		if kind == "unixnano" {
			unit = int64(time.Second)
		}
		value.Set(reflect.ValueOf(time.Unix(stamp/unit, stamp%unit)))
		return nil
	}

	if val == "" {
		value.Set(reflect.ValueOf(time.Time{}))
		return nil
	}

	loc := time.Local
	if utc, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); utc {
		loc = time.UTC
	}
	if name := structField.Tag.Get("time_location"); name != "" {
		named, err := time.LoadLocation(name)
		if err != nil {
			return err
		}
		loc = named
	}

	parsed, err := time.ParseInLocation(layout, val, loc)
	if err != nil {
		return err
	}
	value.Set(reflect.ValueOf(parsed))
	return nil
}
// setArray stores vals[i] into value.Index(i) for every i, converting each
// string with setWithProperType. It stops at the first conversion error.
// value must have at least len(vals) elements.
func setArray(vals []string, value reflect.Value, field reflect.StructField) error {
	for i := range vals {
		if err := setWithProperType(vals[i], value.Index(i), field); err != nil {
			return err
		}
	}
	return nil
}
// setSlice replaces value with a new slice holding the converted vals. The
// new slice is assigned only if every element converts successfully;
// otherwise value is left untouched and the error is returned.
func setSlice(vals []string, value reflect.Value, field reflect.StructField) error {
	converted := reflect.MakeSlice(value.Type(), len(vals), len(vals))
	if err := setArray(vals, converted, field); err != nil {
		return err
	}
	value.Set(converted)
	return nil
}
// setTimeDuration parses val with time.ParseDuration and stores the result in
// value. On a parse error value is left untouched and the error is returned.
func setTimeDuration(val string, value reflect.Value) error {
	parsed, err := time.ParseDuration(val)
	if err == nil {
		value.Set(reflect.ValueOf(parsed))
	}
	return err
}
// head splits str around the first occurrence of sep. It returns the text
// before sep as head and the text after it as tail. If sep does not occur,
// head is the whole of str and tail is empty.
func head(str, sep string) (head string, tail string) {
	head, tail, _ = strings.Cut(str, sep)
	return head, tail
}
// setFormMap copies form into ptr, which must be a map with string keys.
//
// If the map's element type is a slice, ptr must be a map[string][]string and
// receives every value slice as is; otherwise ErrConvertMapStringSlice is
// returned. For any other element type ptr must be a map[string]string and
// receives the last value of each key; otherwise ErrConvertToMapString is
// returned. Keys whose value slice is empty have no last value and are
// skipped.
func setFormMap(ptr any, form map[string][]string) error {
	el := reflect.TypeOf(ptr).Elem()

	if el.Kind() == reflect.Slice {
		ptrMap, ok := ptr.(map[string][]string)
		if !ok {
			return ErrConvertMapStringSlice
		}
		for k, v := range form {
			ptrMap[k] = v
		}
		return nil
	}

	ptrMap, ok := ptr.(map[string]string)
	if !ok {
		return ErrConvertToMapString
	}
	for k, v := range form {
		if len(v) == 0 {
			// Indexing v[len(v)-1] on an empty slice would panic with an
			// index out of range.
			continue
		}
		ptrMap[k] = v[len(v)-1] // pick last
	}
	return nil
}
================================================
FILE: gossh/gin/binding/header.go
================================================
// Copyright 2022 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"net/http"
"net/textproto"
"reflect"
)
// headerBinding binds request headers onto a struct using the "header" struct
// tag.
type headerBinding struct{}
// Name returns the identifier of this binding, "header".
func (headerBinding) Name() string {
return "header"
}
// Bind maps the request headers onto obj using the "header" struct tag and
// then validates obj.
func (headerBinding) Bind(req *http.Request, obj any) error {
	err := mapHeader(obj, req.Header)
	if err != nil {
		return err
	}
	return validate(obj)
}
// mapHeader maps the headers in h onto ptr using the "header" struct tag.
func mapHeader(ptr any, h map[string][]string) error {
return mappingByPtr(ptr, headerSource(h), "header")
}
// headerSource adapts a header map to the setter interface.
type headerSource map[string][]string
// Compile-time check that headerSource implements setter.
var _ setter = headerSource(nil)
// TrySet sets value from the header named by tagValue. The name is first
// canonicalized with textproto.CanonicalMIMEHeaderKey, so the tag may be
// written in any case.
func (hs headerSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (bool, error) {
return setByForm(value, field, hs, textproto.CanonicalMIMEHeaderKey(tagValue), opt)
}
================================================
FILE: gossh/gin/binding/json.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"bytes"
"errors"
"io"
"net/http"
"gossh/gin/internal/json"
)
// EnableDecoderUseNumber, when true, makes decodeJSON call the UseNumber
// method on the JSON Decoder instance. UseNumber causes the Decoder to
// unmarshal a number into an any as a Number instead of as a float64.
var EnableDecoderUseNumber = false
// EnableDecoderDisallowUnknownFields, when true, makes decodeJSON call the
// DisallowUnknownFields method on the JSON Decoder instance.
// DisallowUnknownFields causes the Decoder to return an error when the
// destination is a struct and the input contains object keys which do not
// match any non-ignored, exported fields in the destination.
var EnableDecoderDisallowUnknownFields = false
// jsonBinding binds JSON request bodies.
type jsonBinding struct{}
// Name returns the identifier of this binding, "json".
func (jsonBinding) Name() string {
return "json"
}
// Bind decodes the JSON body of req into obj and validates it. A nil request
// or a nil body is rejected with an "invalid request" error.
func (jsonBinding) Bind(req *http.Request, obj any) error {
	if req != nil && req.Body != nil {
		return decodeJSON(req.Body, obj)
	}
	return errors.New("invalid request")
}
// BindBody decodes the JSON document in body into obj and validates it.
func (jsonBinding) BindBody(body []byte, obj any) error {
return decodeJSON(bytes.NewReader(body), obj)
}
// decodeJSON decodes one JSON value from r into obj and then validates obj.
// The decoder honours the package-level EnableDecoderUseNumber and
// EnableDecoderDisallowUnknownFields switches. Validation runs only if
// decoding succeeded.
func decodeJSON(r io.Reader, obj any) error {
	dec := json.NewDecoder(r)
	if EnableDecoderUseNumber {
		dec.UseNumber()
	}
	if EnableDecoderDisallowUnknownFields {
		dec.DisallowUnknownFields()
	}

	err := dec.Decode(obj)
	if err != nil {
		return err
	}
	return validate(obj)
}
================================================
FILE: gossh/gin/binding/multipart_form_mapping.go
================================================
// Copyright 2019 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"errors"
"mime/multipart"
"net/http"
"reflect"
)
// multipartRequest is an http.Request viewed as a setter, so that mapping can
// fill struct fields from the request's parsed multipart form (files and
// values).
type multipartRequest http.Request
// Compile-time check that *multipartRequest implements setter.
var _ setter = (*multipartRequest)(nil)
var (
// ErrMultiFileHeader is returned when a field that should receive an
// uploaded file has a type that cannot hold a multipart.FileHeader.
ErrMultiFileHeader = errors.New("unsupported field type for multipart.FileHeader")
// ErrMultiFileHeaderLenInvalid is returned when the number of uploaded
// files does not match the length of the destination array.
ErrMultiFileHeaderLenInvalid = errors.New("unsupported len of array for []*multipart.FileHeader")
)
// TrySet sets value from the multipart form entry named key. Uploaded files
// take precedence: if at least one file was sent under key, value is set from
// the files. Otherwise it is set from the form's ordinary values, as with any
// other form.
func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (bool, error) {
	files := r.MultipartForm.File[key]
	if len(files) == 0 {
		return setByForm(value, field, r.MultipartForm.Value, key, opt)
	}
	return setByMultipartFormFile(value, field, files)
}
// setByMultipartFormFile stores uploaded files into value according to its kind:
//
//   - *multipart.FileHeader: set to files[0]
//   - multipart.FileHeader:  set to a copy of *files[0]
//   - slice: a new slice of len(files) is filled element by element
//   - array: filled in place; its length must equal len(files)
//
// Any other destination yields (false, ErrMultiFileHeader). files must be
// non-empty for the pointer and struct cases.
func setByMultipartFormFile(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) {
	switch value.Kind() {
	case reflect.Ptr:
		if _, ok := value.Interface().(*multipart.FileHeader); ok {
			value.Set(reflect.ValueOf(files[0]))
			return true, nil
		}
	case reflect.Struct:
		if _, ok := value.Interface().(multipart.FileHeader); ok {
			value.Set(reflect.ValueOf(*files[0]))
			return true, nil
		}
	case reflect.Slice:
		n := len(files)
		filled := reflect.MakeSlice(value.Type(), n, n)
		isSet, err = setArrayOfMultipartFormFiles(filled, field, files)
		if err == nil && isSet {
			// Only publish the slice once every element was set.
			value.Set(filled)
			return true, nil
		}
		return isSet, err
	case reflect.Array:
		return setArrayOfMultipartFormFiles(value, field, files)
	}
	return false, ErrMultiFileHeader
}
// setArrayOfMultipartFormFiles assigns files[i] to value.Index(i) for every
// uploaded file. value must be an array or slice whose length equals
// len(files), otherwise ErrMultiFileHeaderLenInvalid is returned.
// It stops at the first element that fails or is not set and returns that
// element's result.
func setArrayOfMultipartFormFiles(value reflect.Value, field reflect.StructField, files []*multipart.FileHeader) (isSet bool, err error) {
	count := len(files)
	if value.Len() != count {
		return false, ErrMultiFileHeaderLenInvalid
	}
	for idx := 0; idx < count; idx++ {
		ok, elemErr := setByMultipartFormFile(value.Index(idx), field, files[idx:idx+1])
		if elemErr != nil || !ok {
			return ok, elemErr
		}
	}
	return true, nil
}
================================================
FILE: gossh/gin/binding/query.go
================================================
// Copyright 2017 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import "net/http"
// queryBinding is the stateless binding engine that fills a struct from the
// URL query string of a request.
type queryBinding struct{}
// Name reports the identifier of this binding engine, "query".
func (queryBinding) Name() string { return "query" }
// Bind maps the parsed URL query parameters of req onto obj with mapForm and
// then validates obj. A mapping error is returned without running validation.
func (queryBinding) Bind(req *http.Request, obj any) error {
	err := mapForm(obj, req.URL.Query())
	if err == nil {
		err = validate(obj)
	}
	return err
}
================================================
FILE: gossh/gin/binding/uri.go
================================================
// Copyright 2018 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
// uriBinding is the stateless binding engine that fills a struct from a map
// of URI parameters.
type uriBinding struct{}
// Name reports the identifier of this binding engine, "uri".
func (uriBinding) Name() string { return "uri" }
// BindUri maps the URI parameters in m onto obj with mapURI and then
// validates obj. A mapping error is returned without running validation.
func (uriBinding) BindUri(m map[string][]string, obj any) error {
	err := mapURI(obj, m)
	if err == nil {
		err = validate(obj)
	}
	return err
}
================================================
FILE: gossh/gin/binding/xml.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"bytes"
"encoding/xml"
"io"
"net/http"
)
// xmlBinding is the stateless binding engine for XML payloads. It decodes
// either a request body (Bind) or a raw byte slice (BindBody) via decodeXML.
type xmlBinding struct{}
// Name reports the identifier of this binding engine, "xml".
func (xmlBinding) Name() string { return "xml" }
// Bind decodes the XML body of req into obj and then validates obj.
// req and req.Body are used without a nil check.
func (xmlBinding) Bind(req *http.Request, obj any) error {
	body := req.Body
	return decodeXML(body, obj)
}
// BindBody decodes the XML document held in body into obj and then
// validates obj. The bytes are wrapped in a reader and handed to decodeXML.
func (xmlBinding) BindBody(body []byte, obj any) error {
	src := bytes.NewReader(body)
	return decodeXML(src, obj)
}
// decodeXML reads one XML element from r into obj and, on success, runs
// validate on obj. A decoding error is returned without running validation.
func decodeXML(r io.Reader, obj any) error {
	err := xml.NewDecoder(r).Decode(obj)
	if err == nil {
		return validate(obj)
	}
	return err
}
================================================
FILE: gossh/gin/context.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"errors"
"io"
"math"
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"gossh/gin/binding"
"gossh/gin/render"
"gossh/gin/sse"
)
// Content-Type MIME of the most common data formats.
// Each constant re-exports the constant of the same name from the binding
// package, so callers do not have to import binding themselves.
const (
MIMEJSON = binding.MIMEJSON
MIMEHTML = binding.MIMEHTML
MIMEXML = binding.MIMEXML
MIMEPlain = binding.MIMEPlain
MIMEPOSTForm = binding.MIMEPOSTForm
MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm
)
// BodyBytesKey indicates a default body bytes key.
// ShouldBindBodyWith stores the request body it has read under this key in
// the context's Keys so that later calls can reuse it.
const BodyBytesKey = "_gin-gonic/gin/bodybyteskey"
// ContextKey is the key that a Context returns itself for.
const ContextKey = "_gin-gonic/gin/contextkey"
// ContextKeyType is the integer type of ContextRequestKey.
type ContextKeyType int
// ContextRequestKey is the zero value of ContextKeyType.
// NOTE(review): presumably the key under which a Context exposes its
// *http.Request to context.Context consumers — confirm against Context.Value.
const ContextRequestKey ContextKeyType = 0
// abortIndex represents a typical value used in abort functions.
// It is math.MaxInt8 >> 1 (63): Abort stores it in Context.index and
// IsAborted reports true once index is at or beyond it.
const abortIndex int8 = math.MaxInt8 >> 1
// Context is the most important part of gin. It allows us to pass variables between middleware,
// manage the flow, validate the JSON of a request and render a JSON response for example.
type Context struct {
// writermem is the response writer owned by the Context; reset points Writer at it.
writermem responseWriter
// Request is the HTTP request being handled.
Request *http.Request
// Writer is the response writer that handlers write headers and body to.
Writer ResponseWriter
// Params holds the URL parameters; see Param and AddParam.
Params Params
// handlers is the chain of handlers executed by Next.
handlers HandlersChain
// index is the position of the current handler in handlers. reset sets it
// to -1 and Abort sets it to abortIndex.
index int8
// fullPath is the matched route path returned by FullPath.
fullPath string
// engine is the Engine this Context belongs to; it supplies settings such
// as MaxMultipartMemory and the trusted-proxy configuration.
engine *Engine
// params and skippedNodes point at slices that reset truncates to length 0.
// NOTE(review): presumably scratch storage reused by the router — confirm.
params *Params
skippedNodes *[]skippedNode
// This mutex protects Keys map.
mu sync.RWMutex
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]any
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
Errors errorMsgs
// Accepted defines a list of manually accepted formats for content negotiation.
Accepted []string
// queryCache caches the query result from c.Request.URL.Query().
queryCache url.Values
// formCache caches c.Request.PostForm, which contains the parsed form data from POST, PATCH,
// or PUT body parameters.
formCache url.Values
// SameSite allows a server to define a cookie attribute making it impossible for
// the browser to send this cookie along with cross-site requests.
sameSite http.SameSite
}
/************************************/
/********** CONTEXT CREATION ********/
/************************************/
// reset returns the Context to its pristine state so it can serve another
// request. Slices are truncated in place (keeping their backing arrays),
// maps and caches are dropped, the handler index goes back to -1 and Writer
// is pointed at the Context's own writermem.
func (c *Context) reset() {
	// Response side.
	c.Writer = &c.writermem

	// Routing state.
	c.handlers = nil
	c.index = -1
	c.fullPath = ""
	c.Params = c.Params[:0]
	*c.params = (*c.params)[:0]
	*c.skippedNodes = (*c.skippedNodes)[:0]

	// Per-request data and caches.
	c.Keys = nil
	c.Errors = c.Errors[:0]
	c.Accepted = nil
	c.queryCache = nil
	c.formCache = nil
	c.sameSite = 0
}
// Copy returns a copy of the current context that can be safely used outside the request's scope.
// This has to be used when the context has to be passed to a goroutine.
//
// The copy shares Request and engine with the original, has no handlers, is
// already in the aborted state (index == abortIndex) and writes to its own
// writermem whose underlying ResponseWriter is cleared. Keys and Params are
// duplicated so later changes to the original are not visible in the copy.
func (c *Context) Copy() *Context {
	cp := Context{
		writermem: c.writermem,
		Request:   c.Request,
		engine:    c.engine,
	}
	cp.writermem.ResponseWriter = nil
	cp.Writer = &cp.writermem
	cp.index = abortIndex
	cp.handlers = nil
	cp.fullPath = c.fullPath

	// Read c.Keys only while holding the read lock: Set may replace or
	// mutate the map concurrently, so both the size and the iteration must
	// observe the same, protected map value.
	c.mu.RLock()
	cp.Keys = make(map[string]any, len(c.Keys))
	for k, v := range c.Keys {
		cp.Keys[k] = v
	}
	c.mu.RUnlock()

	cp.Params = make([]Param, len(c.Params))
	copy(cp.Params, c.Params)
	return &cp
}
// HandlerName returns the name of the main handler, i.e. the last one in the
// chain. For example, for a handler "handleGetUsers()" declared in package
// main it returns "main.handleGetUsers".
func (c *Context) HandlerName() string {
	mainHandler := c.handlers.Last()
	return nameOfFunction(mainHandler)
}
// HandlerNames returns the names of all handlers registered for this context,
// in chain order, following the semantics of HandlerName().
func (c *Context) HandlerNames() []string {
	names := make([]string, len(c.handlers))
	for i := range c.handlers {
		names[i] = nameOfFunction(c.handlers[i])
	}
	return names
}
// Handler returns the main handler, which is the last entry of the chain.
func (c *Context) Handler() HandlerFunc { return c.handlers.Last() }
// FullPath returns the full path of the matched route. It is the empty
// string for routes that were not found.
//
//	router.GET("/user/:id", func(c *gin.Context) {
//	    c.FullPath() == "/user/:id" // true
//	})
func (c *Context) FullPath() string { return c.fullPath }
/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
// Next should be used only inside middleware.
// It runs the handlers that are still pending in the chain from within the
// calling handler. The index is re-read after every handler, so a handler
// that calls Abort stops the loop.
// See example in GitHub.
func (c *Context) Next() {
	for c.index++; c.index < int8(len(c.handlers)); c.index++ {
		c.handlers[c.index](c)
	}
}
// IsAborted reports whether the current context was aborted, i.e. whether
// the handler index has reached abortIndex.
func (c *Context) IsAborted() bool { return c.index >= abortIndex }
// Abort prevents pending handlers from being called. It does not stop the
// handler that is currently running.
// Typical use: an authorization middleware that finds the request is not
// authorized (for instance the password does not match) calls Abort so the
// remaining handlers for this request are skipped.
func (c *Context) Abort() { c.index = abortIndex }
// AbortWithStatus sets the status code, writes the headers immediately and
// then calls `Abort()`.
// For example, a failed attempt to authenticate a request could use:
// context.AbortWithStatus(401).
func (c *Context) AbortWithStatus(code int) {
	c.Status(code)
	w := c.Writer
	w.WriteHeaderNow()
	c.Abort()
}
// AbortWithStatusJSON calls `Abort()` and then `JSON` internally.
// It stops the chain, writes the status code and returns jsonObj as the JSON
// body. It also sets the Content-Type as "application/json".
func (c *Context) AbortWithStatusJSON(code int, jsonObj any) {
	// Abort first so the chain is stopped even if rendering misbehaves.
	c.Abort()
	c.JSON(code, jsonObj)
}
// AbortWithError calls `AbortWithStatus()` and `Error()` internally.
// It stops the chain, writes the status code and pushes err to `c.Errors`,
// returning the recorded *Error.
// See Context.Error() for more details.
func (c *Context) AbortWithError(code int, err error) *Error {
	c.AbortWithStatus(code)
	recorded := c.Error(err)
	return recorded
}
/************************************/
/********* ERROR MANAGEMENT *********/
/************************************/
// Error attaches an error to the current context by appending it to c.Errors.
// It's a good idea to call Error for each error that occurred during the
// resolution of a request; a middleware can then collect them all to store
// them, log them or add them to the HTTP response.
//
// If err already is (or wraps) an *Error, that value is recorded; otherwise
// err is wrapped in a new *Error of type ErrorTypePrivate.
// Error will panic if err is nil.
func (c *Context) Error(err error) *Error {
	if err == nil {
		panic("err is nil")
	}
	var ginErr *Error
	if !errors.As(err, &ginErr) {
		ginErr = &Error{Err: err, Type: ErrorTypePrivate}
	}
	c.Errors = append(c.Errors, ginErr)
	return ginErr
}
/************************************/
/******** METADATA MANAGEMENT********/
/************************************/
// Set stores a key/value pair exclusively for this context, holding the
// write lock while it does so. c.Keys is created lazily on first use.
func (c *Context) Set(key string, value any) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.Keys != nil {
		c.Keys[key] = value
		return
	}
	c.Keys = map[string]any{key: value}
}
// Get looks up key under the read lock and returns (value, true) when it is
// present. If the key does not exist it returns (nil, false).
func (c *Context) Get(key string) (value any, exists bool) {
	c.mu.RLock()
	value, exists = c.Keys[key]
	c.mu.RUnlock()
	return value, exists
}
// MustGet returns the value stored under key. It panics when the key does
// not exist.
func (c *Context) MustGet(key string) any {
	value, exists := c.Get(key)
	if !exists {
		panic("Key \"" + key + "\" does not exist")
	}
	return value
}
// GetString returns the value associated with the key as a string.
// The result is "" when the key is missing, nil, or not a string.
func (c *Context) GetString(key string) (s string) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return s
	}
	s, _ = val.(string)
	return s
}
// GetBool returns the value associated with the key as a boolean.
// The result is false when the key is missing, nil, or not a bool.
func (c *Context) GetBool(key string) (b bool) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return b
	}
	b, _ = val.(bool)
	return b
}
// GetInt returns the value associated with the key as an int.
// The result is 0 when the key is missing, nil, or not an int.
func (c *Context) GetInt(key string) (i int) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return i
	}
	i, _ = val.(int)
	return i
}
// GetInt64 returns the value associated with the key as an int64.
// The result is 0 when the key is missing, nil, or not an int64.
func (c *Context) GetInt64(key string) (i64 int64) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return i64
	}
	i64, _ = val.(int64)
	return i64
}
// GetUint returns the value associated with the key as a uint.
// The result is 0 when the key is missing, nil, or not a uint.
func (c *Context) GetUint(key string) (ui uint) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return ui
	}
	ui, _ = val.(uint)
	return ui
}
// GetUint64 returns the value associated with the key as a uint64.
// The result is 0 when the key is missing, nil, or not a uint64.
func (c *Context) GetUint64(key string) (ui64 uint64) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return ui64
	}
	ui64, _ = val.(uint64)
	return ui64
}
// GetFloat64 returns the value associated with the key as a float64.
// The result is 0 when the key is missing, nil, or not a float64.
func (c *Context) GetFloat64(key string) (f64 float64) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return f64
	}
	f64, _ = val.(float64)
	return f64
}
// GetTime returns the value associated with the key as a time.Time.
// The result is the zero time when the key is missing, nil, or not a time.Time.
func (c *Context) GetTime(key string) (t time.Time) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return t
	}
	t, _ = val.(time.Time)
	return t
}
// GetDuration returns the value associated with the key as a time.Duration.
// The result is 0 when the key is missing, nil, or not a time.Duration.
func (c *Context) GetDuration(key string) (d time.Duration) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return d
	}
	d, _ = val.(time.Duration)
	return d
}
// GetStringSlice returns the value associated with the key as a slice of strings.
// The result is nil when the key is missing, nil, or not a []string.
func (c *Context) GetStringSlice(key string) (ss []string) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return ss
	}
	ss, _ = val.([]string)
	return ss
}
// GetStringMap returns the value associated with the key as a map[string]any.
// The result is nil when the key is missing, nil, or not a map[string]any.
func (c *Context) GetStringMap(key string) (sm map[string]any) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return sm
	}
	sm, _ = val.(map[string]any)
	return sm
}
// GetStringMapString returns the value associated with the key as a map of strings.
// The result is nil when the key is missing, nil, or not a map[string]string.
func (c *Context) GetStringMapString(key string) (sms map[string]string) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return sms
	}
	sms, _ = val.(map[string]string)
	return sms
}
// GetStringMapStringSlice returns the value associated with the key as a map
// to a slice of strings. The result is nil when the key is missing, nil, or
// not a map[string][]string.
func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) {
	val, ok := c.Get(key)
	if !ok || val == nil {
		return smss
	}
	smss, _ = val.(map[string][]string)
	return smss
}
/************************************/
/************ INPUT DATA ************/
/************************************/
// Param returns the value of the URL param named key.
// It is a shortcut for c.Params.ByName(key)
//
//	router.GET("/user/:id", func(c *gin.Context) {
//	    // a GET request to /user/john
//	    id := c.Param("id") // id == "john"
//	    // a GET request to /user/john/
//	    id := c.Param("id") // id == "/john/"
//	})
func (c *Context) Param(key string) string {
	params := c.Params
	return params.ByName(key)
}
// AddParam appends a key/value pair to c.Params, which stands in for a path
// parameter in end-to-end tests.
// Example Route: "/user/:id"
// AddParam("id", 1)
// Result: "/user/1"
func (c *Context) AddParam(key, value string) {
	extra := Param{Key: key, Value: value}
	c.Params = append(c.Params, extra)
}
// Query returns the first url query value stored under key, or the empty
// string `("")` when the key is absent.
// It is shortcut for `c.Request.URL.Query().Get(key)`
//
//	    GET /path?id=1234&name=Manu&value=
//		   c.Query("id") == "1234"
//		   c.Query("name") == "Manu"
//		   c.Query("value") == ""
//		   c.Query("wtf") == ""
func (c *Context) Query(key string) (value string) {
	found, _ := c.GetQuery(key)
	return found
}
// DefaultQuery returns the first url query value stored under key, or
// defaultValue when the key is absent. A key that is present with an empty
// value yields "".
// See: Query() and GetQuery() for further information.
//
//	GET /?name=Manu&lastname=
//	c.DefaultQuery("name", "unknown") == "Manu"
//	c.DefaultQuery("id", "none") == "none"
//	c.DefaultQuery("lastname", "none") == ""
func (c *Context) DefaultQuery(key, defaultValue string) string {
	value, ok := c.GetQuery(key)
	if !ok {
		return defaultValue
	}
	return value
}
// GetQuery is like Query(): it returns `(value, true)` with the first url
// query value stored under key when the key exists (even when the value is
// an empty string), otherwise it returns `("", false)`.
// It is shortcut for `c.Request.URL.Query().Get(key)`
//
//	GET /?name=Manu&lastname=
//	("Manu", true) == c.GetQuery("name")
//	("", false) == c.GetQuery("id")
//	("", true) == c.GetQuery("lastname")
func (c *Context) GetQuery(key string) (string, bool) {
	values, ok := c.GetQueryArray(key)
	if !ok {
		return "", false
	}
	return values[0], true
}
// QueryArray returns every query value stored under key.
// The length of the slice depends on the number of params with the given key.
func (c *Context) QueryArray(key string) (values []string) {
	all, _ := c.GetQueryArray(key)
	return all
}
// initQueryCache fills c.queryCache on first use. It parses the request URL
// query when a request is attached, and uses an empty url.Values otherwise.
// Later calls are no-ops.
func (c *Context) initQueryCache() {
	if c.queryCache != nil {
		return
	}
	if c.Request == nil {
		c.queryCache = url.Values{}
		return
	}
	c.queryCache = c.Request.URL.Query()
}
// GetQueryArray returns every query value stored under key, plus a boolean
// telling whether the key is present in the (cached) query at all.
func (c *Context) GetQueryArray(key string) (values []string, ok bool) {
	c.initQueryCache()
	cached, present := c.queryCache[key]
	return cached, present
}
// QueryMap returns the map built from query parameters of the form
// key[name]=value; see GetQueryMap.
func (c *Context) QueryMap(key string) (dicts map[string]string) {
	m, _ := c.GetQueryMap(key)
	return m
}
// GetQueryMap returns the map built from query parameters of the form
// key[name]=value, plus a boolean telling whether at least one such
// parameter exists for key.
func (c *Context) GetQueryMap(key string) (map[string]string, bool) {
	c.initQueryCache()
	query := c.queryCache
	return c.get(query, key)
}
// PostForm returns the first value stored under key in a POST urlencoded
// form or multipart form, or the empty string `("")` when the key is absent.
func (c *Context) PostForm(key string) (value string) {
	found, _ := c.GetPostForm(key)
	return found
}
// DefaultPostForm returns the first value stored under key in a POST
// urlencoded form or multipart form, or defaultValue when the key is absent.
// See: PostForm() and GetPostForm() for further information.
func (c *Context) DefaultPostForm(key, defaultValue string) string {
	value, ok := c.GetPostForm(key)
	if !ok {
		return defaultValue
	}
	return value
}
// GetPostForm is like PostForm(key). It returns `(value, true)` with the
// first value stored under key in a POST urlencoded form or multipart form
// when the key exists (even when the value is an empty string), otherwise it
// returns ("", false).
// For example, during a PATCH request to update the user's email:
//
//	    email=mail@example.com  -->  ("mail@example.com", true) := GetPostForm("email") // set email to "mail@example.com"
//		   email=                  -->  ("", true) := GetPostForm("email") // set email to ""
//	                            -->  ("", false) := GetPostForm("email") // do nothing with email
func (c *Context) GetPostForm(key string) (string, bool) {
	values, ok := c.GetPostFormArray(key)
	if !ok {
		return "", false
	}
	return values[0], true
}
// PostFormArray returns every form value stored under key.
// The length of the slice depends on the number of params with the given key.
func (c *Context) PostFormArray(key string) (values []string) {
	all, _ := c.GetPostFormArray(key)
	return all
}
// initFormCache fills c.formCache on first use by parsing the request body as
// a (multipart) form, limited to c.engine.MaxMultipartMemory bytes held in
// memory. Parse errors other than http.ErrNotMultipart are reported through
// debugPrint and otherwise ignored; lookups then simply find no values.
//
// The cache is always left non-nil afterwards. When parsing fails before
// Request.PostForm is populated, an empty url.Values is cached instead, so
// the body is not re-parsed (and the error not re-logged) on every lookup.
func (c *Context) initFormCache() {
	if c.formCache != nil {
		return
	}
	req := c.Request
	if err := req.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil {
		if !errors.Is(err, http.ErrNotMultipart) {
			debugPrint("error on parse multipart form array: %v", err)
		}
	}
	if req.PostForm != nil {
		c.formCache = req.PostForm
		return
	}
	c.formCache = make(url.Values)
}
// GetPostFormArray returns every form value stored under key, plus a boolean
// telling whether the key is present in the (cached) form at all.
func (c *Context) GetPostFormArray(key string) (values []string, ok bool) {
	c.initFormCache()
	cached, present := c.formCache[key]
	return cached, present
}
// PostFormMap returns the map built from form fields of the form
// key[name]=value; see GetPostFormMap.
func (c *Context) PostFormMap(key string) (dicts map[string]string) {
	m, _ := c.GetPostFormMap(key)
	return m
}
// GetPostFormMap returns the map built from form fields of the form
// key[name]=value, plus a boolean telling whether at least one such field
// exists for key.
func (c *Context) GetPostFormMap(key string) (map[string]string, bool) {
	c.initFormCache()
	form := c.formCache
	return c.get(form, key)
}
// get is an internal helper that collects entries of m whose name has the
// shape key[name] (with a non-empty name) into a map from name to the first
// value of the entry. The boolean reports whether any entry matched; the
// returned map is never nil.
func (c *Context) get(m map[string][]string, key string) (map[string]string, bool) {
	result := make(map[string]string)
	found := false
	for name, vals := range m {
		open := strings.IndexByte(name, '[')
		if open < 1 || name[:open] != key {
			continue
		}
		inner := name[open+1:]
		end := strings.IndexByte(inner, ']')
		if end < 1 {
			// No closing bracket, or an empty name such as key[].
			continue
		}
		found = true
		result[inner[:end]] = vals[0]
	}
	return result, found
}
// FormFile returns the header of the first file uploaded under the form key
// name. The multipart form is parsed on demand (bounded by
// c.engine.MaxMultipartMemory); parse and lookup errors are returned as is.
// The file handle opened during the lookup is closed before returning.
func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
	req := c.Request
	if req.MultipartForm == nil {
		if err := req.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil {
			return nil, err
		}
	}
	file, header, err := req.FormFile(name)
	if err != nil {
		return nil, err
	}
	file.Close()
	return header, nil
}
// MultipartForm parses the request as a multipart form (bounded by
// c.engine.MaxMultipartMemory) and returns the parsed form, including file
// uploads, together with the parse error, if any.
func (c *Context) MultipartForm() (*multipart.Form, error) {
	req := c.Request
	parseErr := req.ParseMultipartForm(c.engine.MaxMultipartMemory)
	return req.MultipartForm, parseErr
}
// SaveUploadedFile copies the uploaded form file to the path dst.
// Missing parent directories of dst are created with mode 0750, and dst is
// created or truncated. It returns the first error from opening the upload,
// creating the directories or file, copying the data, or closing dst.
func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dst string) error {
	src, err := file.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	if err = os.MkdirAll(filepath.Dir(dst), 0750); err != nil {
		return err
	}

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err = io.Copy(out, src); err != nil {
		// The copy error is the one worth reporting; closing is best effort.
		out.Close()
		return err
	}
	// A failed Close on a written file can mean the data never reached the
	// disk, so its error must not be dropped.
	return out.Close()
}
// Bind selects a binding engine automatically from the request Method and
// Content-Type and binds the request into obj, which must be a pointer.
// Depending on the "Content-Type" header different bindings are used, for example:
//
//	"application/json" --> JSON binding
//	"application/xml"  --> XML binding
//
// If the input is not valid, the request is aborted through MustBindWith:
// a 400 status is written and the error is recorded on the context.
func (c *Context) Bind(obj any) error {
	engine := binding.Default(c.Request.Method, c.ContentType())
	return c.MustBindWith(obj, engine)
}
// BindJSON binds the request into obj with the JSON engine.
// It is a shortcut for c.MustBindWith(obj, binding.JSON).
func (c *Context) BindJSON(obj any) error { return c.MustBindWith(obj, binding.JSON) }
// BindXML binds the request into obj with the XML engine.
// It is a shortcut for c.MustBindWith(obj, binding.XML).
func (c *Context) BindXML(obj any) error { return c.MustBindWith(obj, binding.XML) }
// BindQuery binds the request into obj with the query engine.
// It is a shortcut for c.MustBindWith(obj, binding.Query).
func (c *Context) BindQuery(obj any) error { return c.MustBindWith(obj, binding.Query) }
// BindHeader binds the request into obj with the header engine.
// It is a shortcut for c.MustBindWith(obj, binding.Header).
func (c *Context) BindHeader(obj any) error { return c.MustBindWith(obj, binding.Header) }
// BindUri binds the passed struct pointer using binding.Uri.
// On any error it aborts the request with HTTP 400, records the error on
// the context with type ErrorTypeBind and returns it.
func (c *Context) BindUri(obj any) error {
	err := c.ShouldBindUri(obj)
	if err == nil {
		return nil
	}
	c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) //nolint: errcheck
	return err
}
// MustBindWith binds the passed struct pointer using the specified binding engine.
// On any error it aborts the request with HTTP 400, records the error on
// the context with type ErrorTypeBind and returns it.
// See the binding package.
func (c *Context) MustBindWith(obj any, b binding.Binding) error {
	err := c.ShouldBindWith(obj, b)
	if err == nil {
		return nil
	}
	c.AbortWithError(http.StatusBadRequest, err).SetType(ErrorTypeBind) //nolint: errcheck
	return err
}
// ShouldBind selects a binding engine automatically from the request Method
// and Content-Type and binds the request into obj, which must be a pointer.
// Depending on the "Content-Type" header different bindings are used, for example:
//
//	"application/json" --> JSON binding
//	"application/xml"  --> XML binding
//
// Like c.Bind() but this method does not set the response status code to 400
// or abort if input is not valid; the error is only returned.
func (c *Context) ShouldBind(obj any) error {
	engine := binding.Default(c.Request.Method, c.ContentType())
	return c.ShouldBindWith(obj, engine)
}
// ShouldBindJSON binds the request into obj with the JSON engine.
// It is a shortcut for c.ShouldBindWith(obj, binding.JSON).
func (c *Context) ShouldBindJSON(obj any) error { return c.ShouldBindWith(obj, binding.JSON) }
// ShouldBindXML binds the request into obj with the XML engine.
// It is a shortcut for c.ShouldBindWith(obj, binding.XML).
func (c *Context) ShouldBindXML(obj any) error { return c.ShouldBindWith(obj, binding.XML) }
// ShouldBindQuery binds the request into obj with the query engine.
// It is a shortcut for c.ShouldBindWith(obj, binding.Query).
func (c *Context) ShouldBindQuery(obj any) error { return c.ShouldBindWith(obj, binding.Query) }
// ShouldBindHeader binds the request into obj with the header engine.
// It is a shortcut for c.ShouldBindWith(obj, binding.Header).
func (c *Context) ShouldBindHeader(obj any) error { return c.ShouldBindWith(obj, binding.Header) }
// ShouldBindUri binds the passed struct pointer from the URL parameters in
// c.Params using binding.Uri. Each parameter becomes a one-element value
// list; when a key repeats, the last occurrence wins.
func (c *Context) ShouldBindUri(obj any) error {
	values := make(map[string][]string)
	for i := range c.Params {
		p := c.Params[i]
		values[p.Key] = []string{p.Value}
	}
	return binding.Uri.BindUri(values, obj)
}
// ShouldBindWith binds the passed struct pointer from c.Request using the
// specified binding engine and returns the engine's error unchanged.
// See the binding package.
func (c *Context) ShouldBindWith(obj any, b binding.Binding) error {
	req := c.Request
	return b.Bind(req, obj)
}
// ShouldBindBodyWith is similar with ShouldBindWith, but it stores the request
// body into the context under BodyBytesKey and reuses it when called again.
//
// NOTE: This method reads the body before binding. So you should use
// ShouldBindWith for better performance if you need to call only once.
func (c *Context) ShouldBindBodyWith(obj any, bb binding.BindingBody) (err error) {
	var body []byte
	if cached, found := c.Get(BodyBytesKey); found {
		// A value of any other type leaves body nil and triggers a read.
		body, _ = cached.([]byte)
	}
	if body == nil {
		if body, err = io.ReadAll(c.Request.Body); err != nil {
			return err
		}
		c.Set(BodyBytesKey, body)
	}
	return bb.BindBody(body, obj)
}
// ShouldBindBodyWithJSON binds the cached request body into obj as JSON.
// It is a shortcut for c.ShouldBindBodyWith(obj, binding.JSON).
func (c *Context) ShouldBindBodyWithJSON(obj any) error {
	var engine binding.BindingBody = binding.JSON
	return c.ShouldBindBodyWith(obj, engine)
}
// ShouldBindBodyWithXML binds the cached request body into obj as XML.
// It is a shortcut for c.ShouldBindBodyWith(obj, binding.XML).
func (c *Context) ShouldBindBodyWithXML(obj any) error {
	var engine binding.BindingBody = binding.XML
	return c.ShouldBindBodyWith(obj, engine)
}
// ClientIP implements one best effort algorithm to return the real client IP.
//
// Resolution order:
//  1. If Engine.TrustedPlatform names a header and that header is non-empty,
//     its value is returned as is.
//  2. The remote IP (from Request.RemoteAddr, via c.RemoteIP()) is parsed;
//     if it is not a valid IP, "" is returned.
//  3. If the remote IP is a trusted proxy, ForwardedByClientIP is enabled and
//     Engine.RemoteIPHeaders is set, the first of those headers that
//     validates supplies the client IP.
//  4. Otherwise the remote IP itself is returned.
func (c *Context) ClientIP() string {
	// Developers can define their own header of Trusted Platform or use predefined constants.
	if platform := c.engine.TrustedPlatform; platform != "" {
		if addr := c.requestHeader(platform); addr != "" {
			return addr
		}
	}

	remoteIP := net.ParseIP(c.RemoteIP())
	if remoteIP == nil {
		return ""
	}

	// Forwarding headers are only honoured when the direct peer is contained
	// in one of the CIDR blocks defined by Engine.SetTrustedProxies().
	if !c.engine.isTrustedProxy(remoteIP) || !c.engine.ForwardedByClientIP || c.engine.RemoteIPHeaders == nil {
		return remoteIP.String()
	}
	for _, name := range c.engine.RemoteIPHeaders {
		if ip, valid := c.engine.validateHeader(c.requestHeader(name)); valid {
			return ip
		}
	}
	return remoteIP.String()
}
// RemoteIP extracts the host part (without the port) from the
// whitespace-trimmed Request.RemoteAddr. It returns "" when RemoteAddr is
// not a valid host:port pair.
func (c *Context) RemoteIP() string {
	host, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
	if err == nil {
		return host
	}
	return ""
}
// ContentType returns the Content-Type header of the request after it has
// been passed through filterFlags.
func (c *Context) ContentType() string {
	raw := c.requestHeader("Content-Type")
	return filterFlags(raw)
}
// IsWebsocket reports whether the request headers indicate that the client
// is initiating a websocket handshake: the Connection header contains
// "upgrade" (case-insensitively) and the Upgrade header equals "websocket"
// (case-insensitively).
func (c *Context) IsWebsocket() bool {
	connection := strings.ToLower(c.requestHeader("Connection"))
	if !strings.Contains(connection, "upgrade") {
		return false
	}
	return strings.EqualFold(c.requestHeader("Upgrade"), "websocket")
}
// requestHeader returns the first value of the request header named key,
// or "" when the header is not present.
func (c *Context) requestHeader(key string) string {
	headers := c.Request.Header
	return headers.Get(key)
}
/************************************/
/******** RESPONSE RENDERING ********/
/************************************/
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
// It reports whether a response with the given status code may carry a body:
// false for 1xx, 204 (No Content) and 304 (Not Modified), true otherwise.
func bodyAllowedForStatus(status int) bool {
	if status >= 100 && status <= 199 {
		return false
	}
	switch status {
	case http.StatusNoContent, http.StatusNotModified:
		return false
	}
	return true
}
// Status sets the HTTP response code by passing it to c.Writer.WriteHeader.
func (c *Context) Status(code int) {
	w := c.Writer
	w.WriteHeader(code)
}
// Header is an intelligent shortcut for c.Writer.Header().Set(key, value).
// It writes a header in the response.
// If value == "", the header is removed instead: `c.Writer.Header().Del(key)`
func (c *Context) Header(key, value string) {
	headers := c.Writer.Header()
	if value != "" {
		headers.Set(key, value)
		return
	}
	headers.Del(key)
}
// GetHeader returns the value of the request header named key.
func (c *Context) GetHeader(key string) string { return c.requestHeader(key) }
// GetRawData reads the request body to the end and returns its bytes.
// It returns the error "cannot read nil body" when the request has no body.
func (c *Context) GetRawData() ([]byte, error) {
	body := c.Request.Body
	if body == nil {
		return nil, errors.New("cannot read nil body")
	}
	return io.ReadAll(body)
}
// SetSameSite records the SameSite attribute that SetCookie applies to the
// cookies it writes afterwards.
func (c *Context) SetSameSite(samesite http.SameSite) { c.sameSite = samesite }
// SetCookie adds a Set-Cookie header to the ResponseWriter's headers.
// The value is query-escaped, an empty path defaults to "/", and the
// SameSite attribute is the one configured through SetSameSite.
// The provided cookie must have a valid Name. Invalid cookies may be
// silently dropped.
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) {
	if path == "" {
		path = "/"
	}
	cookie := http.Cookie{
		Name:     name,
		Value:    url.QueryEscape(value),
		MaxAge:   maxAge,
		Path:     path,
		Domain:   domain,
		SameSite: c.sameSite,
		Secure:   secure,
		HttpOnly: httpOnly,
	}
	http.SetCookie(c.Writer, &cookie)
}
// Cookie returns the query-unescaped value of the named cookie provided in
// the request, or the lookup error (ErrNoCookie if not found).
// If multiple cookies match the given name, only one cookie will be returned.
// An error from unescaping the value is discarded; the result of
// url.QueryUnescape is returned with a nil error in that case.
func (c *Context) Cookie(name string) (string, error) {
	ck, err := c.Request.Cookie(name)
	if err != nil {
		return "", err
	}
	decoded, _ := url.QueryUnescape(ck.Value)
	return decoded, nil
}
// Render sets the status code and writes the response through r.
// For status codes that must not carry a body (see bodyAllowedForStatus)
// only the content type and the headers are written. Otherwise r.Render is
// invoked; a rendering error is pushed to c.Errors and the context aborted.
func (c *Context) Render(code int, r render.Render) {
	c.Status(code)

	if bodyAllowedForStatus(code) {
		if err := r.Render(c.Writer); err != nil {
			// Pushing error to c.Errors
			_ = c.Error(err)
			c.Abort()
		}
		return
	}

	r.WriteContentType(c.Writer)
	c.Writer.WriteHeaderNow()
}
// HTML renders the HTTP template specified by its file name, using the
// engine's HTMLRender with obj as template data.
// It also updates the HTTP code and sets the Content-Type as "text/html".
// See http://golang.org/doc/articles/wiki/
func (c *Context) HTML(code int, name string, obj any) {
	c.Render(code, c.engine.HTMLRender.Instance(name, obj))
}
// IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body.
// It also sets the Content-Type as "application/json".
// WARNING: use this only for development purposes, since printing pretty
// JSON consumes more CPU and bandwidth. Use Context.JSON() instead.
func (c *Context) IndentedJSON(code int, obj any) {
	renderer := render.IndentedJSON{Data: obj}
	c.Render(code, renderer)
}
// SecureJSON serializes the given struct as Secure JSON into the response body,
// using the engine's configured secure JSON prefix.
// Default prepends "while(1)," to response body if the given struct is array values.
// It also sets the Content-Type as "application/json".
func (c *Context) SecureJSON(code int, obj any) {
	renderer := render.SecureJSON{Prefix: c.engine.secureJSONPrefix, Data: obj}
	c.Render(code, renderer)
}
// JSONP serializes the given struct as JSON into the response body.
// When the request carries a non-empty "callback" query parameter, the JSON
// is wrapped in a call to that callback (padding), which allows requesting
// data from a server residing in a different domain than the client, and the
// Content-Type is set as "application/javascript". Without a callback it
// renders plain JSON.
func (c *Context) JSONP(code int, obj any) {
	callback := c.DefaultQuery("callback", "")
	if callback != "" {
		c.Render(code, render.JsonpJSON{Callback: callback, Data: obj})
		return
	}
	c.Render(code, render.JSON{Data: obj})
}
// JSON renders obj through render.JSON with the given HTTP status code.
func (c *Context) JSON(code int, obj any) {
	payload := render.JSON{Data: obj}
	c.Render(code, payload)
}
// AsciiJSON renders obj through render.AsciiJSON (JSON with unicode converted
// to ASCII) with the given HTTP status code.
func (c *Context) AsciiJSON(code int, obj any) {
	payload := render.AsciiJSON{Data: obj}
	c.Render(code, payload)
}
// PureJSON renders obj through render.PureJSON with the given HTTP status code.
// Unlike JSON, PureJSON does not replace special html characters with their
// unicode entities.
func (c *Context) PureJSON(code int, obj any) {
	payload := render.PureJSON{Data: obj}
	c.Render(code, payload)
}
// XML renders obj through render.XML with the given HTTP status code.
func (c *Context) XML(code int, obj any) {
	payload := render.XML{Data: obj}
	c.Render(code, payload)
}
// String renders format, together with the optional values, through
// render.String with the given HTTP status code.
func (c *Context) String(code int, format string, values ...any) {
	text := render.String{Format: format, Data: values}
	c.Render(code, text)
}
// Redirect issues an HTTP redirect to location using status code.
// The status is carried by render.Redirect itself, so Render is called with -1.
func (c *Context) Redirect(code int, location string) {
	target := render.Redirect{Code: code, Location: location, Request: c.Request}
	c.Render(-1, target)
}
// Data renders the raw bytes in data with the supplied content type and HTTP
// status code (render.Data).
func (c *Context) Data(code int, contentType string, data []byte) {
	raw := render.Data{ContentType: contentType, Data: data}
	c.Render(code, raw)
}
// DataFromReader renders the content of reader with the given HTTP status
// code through render.Reader, forwarding the content length, content type and
// any extra headers.
func (c *Context) DataFromReader(code int, contentLength int64, contentType string, reader io.Reader, extraHeaders map[string]string) {
	stream := render.Reader{
		Reader:        reader,
		ContentType:   contentType,
		ContentLength: contentLength,
		Headers:       extraHeaders,
	}
	c.Render(code, stream)
}
// File serves the file at filepath for the current request by delegating to
// http.ServeFile.
func (c *Context) File(filepath string) {
	w, req := c.Writer, c.Request
	http.ServeFile(w, req, filepath)
}
// FileFromFS serves filepath out of fs using http.FileServer.
// The request's URL path is temporarily replaced by filepath for the duration
// of the call and restored afterwards.
func (c *Context) FileFromFS(filepath string, fs http.FileSystem) {
	originalPath := c.Request.URL.Path
	defer func() {
		c.Request.URL.Path = originalPath
	}()

	c.Request.URL.Path = filepath
	server := http.FileServer(fs)
	server.ServeHTTP(c.Writer, c.Request)
}
// quoteEscaper prefixes every backslash and every double quote with a backslash.
var quoteEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
)

// escapeQuotes returns s with backslashes and double quotes backslash-escaped,
// which makes it safe to embed between double quotes.
func escapeQuotes(s string) string {
	escaped := quoteEscaper.Replace(s)
	return escaped
}
// FileAttachment serves the file at filepath via http.ServeFile and sets a
// Content-Disposition header so that clients will typically download it under
// the given filename.
//
// An ASCII filename is sent as a quoted `filename="..."` parameter with
// backslashes and double quotes escaped. Any other filename is sent as the
// extended `filename*=UTF-8''...` parameter with percent-encoding.
func (c *Context) FileAttachment(filepath, filename string) {
	var disposition string
	if isASCII(filename) {
		disposition = `attachment; filename="` + escapeQuotes(filename) + `"`
	} else {
		// url.QueryEscape applies form encoding, which writes a space as "+".
		// The extended parameter syntax only understands percent-encoding, so
		// a "+" would show up literally in the saved file name. A literal "+"
		// in the input has already become "%2B" at this point, which means
		// every remaining "+" stands for a space.
		encoded := strings.ReplaceAll(url.QueryEscape(filename), "+", "%20")
		disposition = `attachment; filename*=UTF-8''` + encoded
	}
	c.Writer.Header().Set("Content-Disposition", disposition)
	http.ServeFile(c.Writer, c.Request, filepath)
}
// SSEvent renders a Server-Sent Event called name that carries message.
// Render is called with status -1.
func (c *Context) SSEvent(name string, message any) {
	event := sse.Event{Event: name, Data: message}
	c.Render(-1, event)
}
// Stream repeatedly calls step with the response writer, flushing after each
// call, until step returns false or the writer's CloseNotify channel fires.
// It returns true when the loop ended because the client went away and false
// when step asked to stop.
func (c *Context) Stream(step func(w io.Writer) bool) bool {
	writer := c.Writer
	gone := writer.CloseNotify()
	for {
		// Non-blocking check for a disconnected client before each step.
		select {
		case <-gone:
			return true
		default:
		}

		more := step(writer)
		writer.Flush()
		if !more {
			return false
		}
	}
}
/************************************/
/******** CONTENT NEGOTIATION *******/
/************************************/
// Negotiate contains all negotiations data consumed by Context.Negotiate.
type Negotiate struct {
// Offered lists the formats handed to Context.NegotiateFormat.
Offered []string
// HTMLName is the template name passed to Context.HTML.
HTMLName string
// HTMLData, JSONData and XMLData are the format-specific payloads; each is
// passed to chooseData together with Data for the matching format.
HTMLData any
JSONData any
XMLData any
// Data is the payload passed as chooseData's second argument for every format.
Data any
}
// Negotiate renders the response in the format selected by NegotiateFormat
// from config.Offered. JSON, HTML and XML are supported; for each, the
// payload comes from chooseData applied to the format-specific field and
// config.Data. Any other result aborts the request with 406 Not Acceptable.
func (c *Context) Negotiate(code int, config Negotiate) {
	format := c.NegotiateFormat(config.Offered...)
	switch format {
	case binding.MIMEHTML:
		c.HTML(code, config.HTMLName, chooseData(config.HTMLData, config.Data))
	case binding.MIMEJSON:
		c.JSON(code, chooseData(config.JSONData, config.Data))
	case binding.MIMEXML:
		c.XML(code, chooseData(config.XMLData, config.Data))
	default:
		notOffered := errors.New("the accepted formats are not offered by the server")
		c.AbortWithError(http.StatusNotAcceptable, notOffered) //nolint: errcheck
	}
}
// NegotiateFormat returns an acceptable Accept format.
// At least one offer is required (checked through assert1). c.Accepted is
// filled lazily from the request's Accept header via parseAccept. When the
// accepted list is empty the first offer wins; when nothing matches the
// result is "".
func (c *Context) NegotiateFormat(offered ...string) string {
assert1(len(offered) > 0, "you must provide at least one offer")
if c.Accepted == nil {
c.Accepted = parseAccept(c.requestHeader("Accept"))
}
if len(c.Accepted) == 0 {
return offered[0]
}
// Accepted entries are tried in order; for each one the offers are tried in order.
for _, accepted := range c.Accepted {
for _, offer := range offered {
// According to RFC 2616 and RFC 2396, non-ASCII characters are not allowed in headers,
// therefore we can just iterate over the string without casting it into []rune
i := 0
for ; i < len(accepted) && i < len(offer); i++ {
// A '*' on either side, reached before any mismatch, matches the offer.
if accepted[i] == '*' || offer[i] == '*' {
return offer
}
if accepted[i] != offer[i] {
break
}
}
// All of accepted was consumed without a mismatch, so accepted equals
// the offer or is a prefix of it.
if i == len(accepted) {
return offer
}
}
}
return ""
}
// SetAccepted sets Accept header data.
// It overwrites c.Accepted with formats; NegotiateFormat only parses the
// request's Accept header while c.Accepted is nil.
func (c *Context) SetAccepted(formats ...string) {
c.Accepted = formats
}
/************************************/
/***** GOLANG.ORG/X/NET/CONTEXT *****/
/************************************/
// hasRequestContext reports whether the engine has ContextWithFallback
// enabled and c.Request carries a non-nil Context.
func (c *Context) hasRequestContext() bool {
	if c.engine == nil || !c.engine.ContextWithFallback {
		return false
	}
	return c.Request != nil && c.Request.Context() != nil
}
// Deadline forwards to c.Request.Context().Deadline() when hasRequestContext
// reports true; otherwise it returns the zero time and ok == false.
func (c *Context) Deadline() (deadline time.Time, ok bool) {
	if c.hasRequestContext() {
		return c.Request.Context().Deadline()
	}
	return deadline, ok
}
// Done forwards to c.Request.Context().Done() when hasRequestContext reports
// true; otherwise it returns nil, a channel that blocks forever.
func (c *Context) Done() <-chan struct{} {
	if c.hasRequestContext() {
		return c.Request.Context().Done()
	}
	return nil
}
// Err forwards to c.Request.Context().Err() when hasRequestContext reports
// true; otherwise it returns nil.
func (c *Context) Err() error {
	if c.hasRequestContext() {
		return c.Request.Context().Err()
	}
	return nil
}
// Value resolves key in this order:
//  1. ContextRequestKey yields c.Request and ContextKey yields c itself;
//  2. a string key is looked up with c.Get;
//  3. otherwise the request's Context is consulted, provided
//     hasRequestContext reports true.
//
// It returns nil when none of these produce a value.
func (c *Context) Value(key any) any {
	if key == ContextRequestKey {
		return c.Request
	}
	if key == ContextKey {
		return c
	}

	if name, isString := key.(string); isString {
		if stored, found := c.Get(name); found {
			return stored
		}
	}

	if c.hasRequestContext() {
		return c.Request.Context().Value(key)
	}
	return nil
}
================================================
FILE: gossh/gin/debug.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"fmt"
"html/template"
"runtime"
"strconv"
"strings"
)
// ginSupportMinGoVer is the lowest Go minor version (the "18" in go1.18) that
// does not trigger the version warning in debugPrintWARNINGDefault.
const ginSupportMinGoVer = 18
// IsDebugging returns true if the framework is running in debug mode,
// i.e. when ginMode equals debugCode.
// Use SetMode(gin.ReleaseMode) to disable debug mode.
func IsDebugging() bool {
return ginMode == debugCode
}
// DebugPrintRouteFunc indicates debug log output format.
// When non-nil, debugPrintRoute calls it instead of printing its default line.
var DebugPrintRouteFunc func(httpMethod, absolutePath, handlerName string, nuHandlers int)
// DebugPrintFunc indicates debug log output format.
// When non-nil, debugPrint hands it the format and values instead of writing
// to DefaultWriter.
var DebugPrintFunc func(format string, values ...any)
// debugPrintRoute logs a registered route while in debug mode: the method,
// the absolute path, the name of the last handler in the chain and the chain
// length. DebugPrintRouteFunc, when set, receives those values instead.
func debugPrintRoute(httpMethod, absolutePath string, handlers HandlersChain) {
	if !IsDebugging() {
		return
	}

	nuHandlers := len(handlers)
	handlerName := nameOfFunction(handlers.Last())
	if DebugPrintRouteFunc != nil {
		DebugPrintRouteFunc(httpMethod, absolutePath, handlerName, nuHandlers)
		return
	}
	debugPrint("%-6s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers)
}
// debugPrintLoadTemplate logs, in debug mode only, the number and the names
// of the templates associated with tmpl.
func debugPrintLoadTemplate(tmpl *template.Template) {
	if !IsDebugging() {
		return
	}

	loaded := tmpl.Templates()
	var names strings.Builder
	for _, t := range loaded {
		names.WriteString("\t- " + t.Name() + "\n")
	}
	debugPrint("Loaded HTML Templates (%d): \n%s\n", len(loaded), names.String())
}
// debugPrint emits a debug message and is a no-op outside debug mode.
// A configured DebugPrintFunc receives format and values untouched; otherwise
// the message is written to DefaultWriter with a "[GIN-debug] " prefix and a
// trailing newline is appended when format lacks one.
func debugPrint(format string, values ...any) {
	if !IsDebugging() {
		return
	}

	if custom := DebugPrintFunc; custom != nil {
		custom(format, values...)
		return
	}

	line := format
	if !strings.HasSuffix(line, "\n") {
		line += "\n"
	}
	fmt.Fprintf(DefaultWriter, "[GIN-debug] "+line, values...)
}
// getMinVer parses the minor component out of a Go version string such as
// "go1.18" or "go1.18.3": the text after the first '.' and, when there is
// more than one '.', before the last one. The strconv.ParseUint error is
// returned unchanged when that text is not a decimal number.
func getMinVer(v string) (uint64, error) {
	start := strings.IndexByte(v, '.') + 1
	end := strings.LastIndexByte(v, '.')
	if end < start {
		// Zero or one dot: the minor part runs to the end of the string.
		end = len(v)
	}
	return strconv.ParseUint(v[start:end], 10, 64)
}
// debugPrintWARNINGDefault emits up to two debug warnings: one when the minor
// version parsed from runtime.Version() is below ginSupportMinGoVer, and one
// stating that the engine is created with Logger and Recovery attached.
func debugPrintWARNINGDefault() {
// A version string that cannot be parsed is silently ignored.
if v, e := getMinVer(runtime.Version()); e == nil && v < ginSupportMinGoVer {
debugPrint(`[WARNING] Now Gin requires Go 1.18+.
`)
}
debugPrint(`[WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.
`)
}
// debugPrintWARNINGNew emits the debug warning that the framework is running
// in "debug" mode, together with the two ways of switching to release mode.
func debugPrintWARNINGNew() {
debugPrint(`[WARNING] Running in "debug" mode. Switch to "release" mode in production.
- using env: export GIN_MODE=release
- using code: gin.SetMode(gin.ReleaseMode)
`)
}
// debugPrintWARNINGSetHTMLTemplate emits the debug warning that
// SetHTMLTemplate() is not thread-safe and belongs at initialization time.
func debugPrintWARNINGSetHTMLTemplate() {
debugPrint(`[WARNING] Since SetHTMLTemplate() is NOT thread-safe. It should only be called
at initialization. ie. before any route is registered or the router is listening in a socket:
router := gin.Default()
router.SetHTMLTemplate(template) // << good place
`)
}
// debugPrintError writes err to DefaultErrorWriter with a
// "[GIN-debug] [ERROR]" prefix. Nothing is written when err is nil or the
// framework is not in debug mode.
func debugPrintError(err error) {
	if err == nil || !IsDebugging() {
		return
	}
	fmt.Fprintf(DefaultErrorWriter, "[GIN-debug] [ERROR] %v\n", err)
}
================================================
FILE: gossh/gin/errors.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"fmt"
"reflect"
"strings"
"gossh/gin/internal/json"
)
// ErrorType is an unsigned 64-bit error code as defined in the gin spec.
// The typed constants below are bit flags, tested with Error.IsType.
type ErrorType uint64
const (
// ErrorTypeBind is used when Context.Bind() fails.
ErrorTypeBind ErrorType = 1 << 63
// ErrorTypeRender is used when Context.Render() fails.
ErrorTypeRender ErrorType = 1 << 62
// ErrorTypePrivate indicates a private error.
ErrorTypePrivate ErrorType = 1 << 0
// ErrorTypePublic indicates a public error.
ErrorTypePublic ErrorType = 1 << 1
// ErrorTypeAny indicates any other error. All 64 bits are set, so it
// overlaps with every other flag.
ErrorTypeAny ErrorType = 1<<64 - 1
// ErrorTypeNu is the untyped constant 2; unlike the constants above it is
// not declared as an ErrorType.
// NOTE(review): its purpose is not evident from this file — confirm usage.
ErrorTypeNu = 2
)
// Error represents a error's specification.
type Error struct {
// Err is the wrapped error; Error() and Unwrap() delegate to it.
Err error
// Type holds the ErrorType flags tested by IsType.
Type ErrorType
// Meta is optional extra data that JSON() merges into its output.
Meta any
}
// errorMsgs is a list of *Error values with filtering and formatting helpers.
type errorMsgs []*Error
// Compile-time check that *Error satisfies the error interface.
var _ error = (*Error)(nil)
// SetType sets the error's type.
// It overwrites msg.Type with flags and returns msg so calls can be chained.
func (msg *Error) SetType(flags ErrorType) *Error {
msg.Type = flags
return msg
}
// SetMeta sets the error's meta data.
// It overwrites msg.Meta with data and returns msg so calls can be chained.
func (msg *Error) SetMeta(data any) *Error {
msg.Meta = data
return msg
}
// JSON builds a JSON-friendly representation of the error.
//
// The result depends on the kind of msg.Meta:
//   - struct: Meta itself is returned unchanged;
//   - map: every entry is copied into the result;
//   - anything else non-nil: stored under the "meta" key.
//
// Unless Meta was a struct, the result is an H that also carries the error
// text under "error", except when a map Meta already supplied that key.
func (msg *Error) JSON() any {
	jsonData := H{}
	if msg.Meta != nil {
		value := reflect.ValueOf(msg.Meta)
		switch value.Kind() {
		case reflect.Struct:
			return msg.Meta
		case reflect.Map:
			for _, key := range value.MapKeys() {
				// reflect.Value.String only yields the real value for string
				// keys; for every other kind it returns a "<T Value>"
				// placeholder, which would make all such keys collide.
				name := key.String()
				if key.Kind() != reflect.String {
					name = fmt.Sprint(key.Interface())
				}
				jsonData[name] = value.MapIndex(key).Interface()
			}
		default:
			jsonData["meta"] = msg.Meta
		}
	}
	if _, ok := jsonData["error"]; !ok {
		jsonData["error"] = msg.Error()
	}
	return jsonData
}
// MarshalJSON implements the json.Marshaller interface.
// It marshals the value produced by msg.JSON().
func (msg *Error) MarshalJSON() ([]byte, error) {
return json.Marshal(msg.JSON())
}
// Error implements the error interface.
// It returns the message of the wrapped Err and therefore panics when Err is nil.
func (msg Error) Error() string {
return msg.Err.Error()
}
// IsType reports whether msg.Type shares at least one bit with flags.
func (msg *Error) IsType(flags ErrorType) bool {
	return msg.Type&flags != 0
}
// Unwrap returns the wrapped error (msg.Err), to allow interoperability with
// errors.Is(), errors.As() and errors.Unwrap().
func (msg *Error) Unwrap() error {
return msg.Err
}
// ByType returns the errors whose Type shares a bit with typ.
// ie ByType(gin.ErrorTypePublic) returns a slice of errors with type=ErrorTypePublic.
// An empty receiver yields nil, and ErrorTypeAny yields the receiver itself
// rather than a copy.
func (a errorMsgs) ByType(typ ErrorType) errorMsgs {
	switch {
	case len(a) == 0:
		return nil
	case typ == ErrorTypeAny:
		return a
	}

	var matched errorMsgs
	for i := range a {
		if candidate := a[i]; candidate.IsType(typ) {
			matched = append(matched, candidate)
		}
	}
	return matched
}
// Last returns the final error of the list, or nil when the list is empty.
// Shortcut for errors[len(errors)-1].
func (a errorMsgs) Last() *Error {
	if len(a) == 0 {
		return nil
	}
	return a[len(a)-1]
}
// Errors collects the message of every error in the list, in order.
// It returns nil for an empty list.
// Example:
//
//	c.Error(errors.New("first"))
//	c.Error(errors.New("second"))
//	c.Error(errors.New("third"))
//	c.Errors.Errors() // == []string{"first", "second", "third"}
func (a errorMsgs) Errors() []string {
	if len(a) == 0 {
		return nil
	}
	messages := make([]string, 0, len(a))
	for _, e := range a {
		messages = append(messages, e.Error())
	}
	return messages
}
// JSON returns the JSON-friendly form of the list: nil when it is empty, the
// single error's JSON() when it holds one element, and otherwise a slice with
// one JSON() value per error.
func (a errorMsgs) JSON() any {
	count := len(a)
	if count == 0 {
		return nil
	}
	if count == 1 {
		return a.Last().JSON()
	}

	items := make([]any, 0, count)
	for _, e := range a {
		items = append(items, e.JSON())
	}
	return items
}
// MarshalJSON implements the json.Marshaller interface.
// It marshals the value produced by a.JSON().
func (a errorMsgs) MarshalJSON() ([]byte, error) {
return json.Marshal(a.JSON())
}
// String formats the list as one "Error #NN: ..." line per error, numbered
// from 1, each followed by a "Meta: ..." line when the error has Meta set.
// An empty list yields "".
func (a errorMsgs) String() string {
	var out strings.Builder
	for idx, msg := range a {
		position := idx + 1
		fmt.Fprintf(&out, "Error #%02d: %s\n", position, msg.Err)
		if msg.Meta == nil {
			continue
		}
		fmt.Fprintf(&out, " Meta: %v\n", msg.Meta)
	}
	return out.String()
}
================================================
FILE: gossh/gin/fs.go
================================================
// Copyright 2017 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"net/http"
"os"
)
// onlyFilesFS wraps an http.FileSystem so that every file it opens is handed
// out as a neuteredReaddirFile.
type onlyFilesFS struct {
	fs http.FileSystem
}

// neuteredReaddirFile embeds an http.File and replaces its Readdir with one
// that never returns any entries.
type neuteredReaddirFile struct {
	http.File
}

// Dir returns a http.FileSystem that can be used by http.FileServer(). It is
// used internally in router.Static().
// With listDirectory == true the result is a plain http.Dir(root); otherwise
// it is a wrapper that prevents http.FileServer() from listing directory
// files.
func Dir(root string, listDirectory bool) http.FileSystem {
	base := http.Dir(root)
	if !listDirectory {
		return &onlyFilesFS{fs: base}
	}
	return base
}

// Open conforms to http.Filesystem. It opens name on the wrapped file system
// and wraps a successfully opened file in neuteredReaddirFile.
func (fs onlyFilesFS) Open(name string) (http.File, error) {
	opened, err := fs.fs.Open(name)
	if err == nil {
		return neuteredReaddirFile{File: opened}, nil
	}
	return nil, err
}

// Readdir overrides the http.File default implementation.
// It always returns (nil, nil), which disables directory listing.
func (f neuteredReaddirFile) Readdir(_ int) ([]os.FileInfo, error) {
	return nil, nil
}
================================================
FILE: gossh/gin/gin.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"fmt"
"html/template"
"net"
"net/http"
"os"
"path"
"regexp"
"strings"
"sync"
"gossh/gin/internal/bytesconv"
"gossh/gin/render"
)
// defaultMultipartMemory is the value New assigns to Engine.MaxMultipartMemory.
const defaultMultipartMemory = 32 << 20 // 32 MB
// Plain-text bodies for 404 and 405 responses.
var (
default404Body = []byte("404 page not found")
default405Body = []byte("405 method not allowed")
)
// defaultPlatform is the value New assigns to Engine.TrustedPlatform.
var defaultPlatform string
// defaultTrustedCIDRs is the value New assigns to Engine.trustedCIDRs: one
// all-zero network with an all-zero mask for IPv4 and one for IPv6.
var defaultTrustedCIDRs = []*net.IPNet{
{ // 0.0.0.0/0 (IPv4)
IP: net.IP{0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0},
},
{ // ::/0 (IPv6)
IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
}
// regSafePrefix matches runs of characters other than ASCII letters, digits, '/' and '-'.
var regSafePrefix = regexp.MustCompile("[^a-zA-Z0-9/-]+")
// regRemoveRepeatedChar matches runs of two or more consecutive slashes.
var regRemoveRepeatedChar = regexp.MustCompile("/{2,}")
// HandlerFunc defines the handler used by gin middleware as return value.
// It receives the *Context of the request being served.
type HandlerFunc func(*Context)
// OptionFunc defines the function to change the default configuration.
// Engine.With calls each option with the engine being configured.
type OptionFunc func(*Engine)
// HandlersChain defines a HandlerFunc slice.
type HandlersChain []HandlerFunc
// Last returns the final handler of the chain, i.e. the main one, or nil when
// the chain is empty.
func (c HandlersChain) Last() HandlerFunc {
	if len(c) == 0 {
		return nil
	}
	return c[len(c)-1]
}
// RouteInfo represents a request route's specification which contains method and path and its handler.
type RouteInfo struct {
// Method is the HTTP method the route is registered for.
Method string
// Path is the full path of the route.
Path string
// Handler is the name of the route's last handler, as reported by nameOfFunction.
Handler string
// HandlerFunc is the route's last handler itself.
HandlerFunc HandlerFunc
}
// RoutesInfo defines a RouteInfo slice.
type RoutesInfo []RouteInfo
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
// Create an instance of Engine, by using New() or Default()
type Engine struct {
RouterGroup
// RedirectTrailingSlash enables automatic redirection if the current route can't be matched but a
// handler for the path with (without) the trailing slash exists.
// For example if /foo/ is requested but a route only exists for /foo, the
// client is redirected to /foo with http status code 301 for GET requests
// and 307 for all other request methods.
RedirectTrailingSlash bool
// RedirectFixedPath if enabled, the router tries to fix the current request path, if no
// handle is registered for it.
// First superfluous path elements like ../ or // are removed.
// Afterwards the router does a case-insensitive lookup of the cleaned path.
// If a handle can be found for this route, the router makes a redirection
// to the corrected path with status code 301 for GET requests and 307 for
// all other request methods.
// For example /FOO and /..//Foo could be redirected to /foo.
// RedirectTrailingSlash is independent of this option.
RedirectFixedPath bool
// HandleMethodNotAllowed if enabled, the router checks if another method is allowed for the
// current route, if the current request can not be routed.
// If this is the case, the request is answered with 'Method Not Allowed'
// and HTTP status code 405.
// If no other Method is allowed, the request is delegated to the NotFound
// handler.
HandleMethodNotAllowed bool
// ForwardedByClientIP if enabled, client IP will be parsed from the request's headers that
// match those stored at `(*gin.Engine).RemoteIPHeaders`. If no IP was
// fetched, it falls back to the IP obtained from
// `(*gin.Context).Request.RemoteAddr`.
ForwardedByClientIP bool
// UseRawPath if enabled, the url.RawPath will be used to find parameters.
UseRawPath bool
// UnescapePathValues if true, the path value will be unescaped.
// If UseRawPath is false (by default), the UnescapePathValues effectively is true,
// as url.Path gonna be used, which is already unescaped.
UnescapePathValues bool
// RemoveExtraSlash a parameter can be parsed from the URL even with extra slashes.
// See the PR #1817 and issue #1644
RemoveExtraSlash bool
// RemoteIPHeaders list of headers used to obtain the client IP when
// `(*gin.Engine).ForwardedByClientIP` is `true` and
// `(*gin.Context).Request.RemoteAddr` is matched by at least one of the
// network origins of list defined by `(*gin.Engine).SetTrustedProxies()`.
RemoteIPHeaders []string
// TrustedPlatform if set to a constant of value gin.Platform*, trusts the headers set by
// that platform, for example to determine the client IP
TrustedPlatform string
// MaxMultipartMemory value of 'maxMemory' param that is given to http.Request's ParseMultipartForm
// method call.
MaxMultipartMemory int64
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool
// delims holds the template delimiters set through Delims().
delims render.Delims
// secureJSONPrefix is the prefix Context.SecureJSON passes to render.SecureJSON.
secureJSONPrefix string
// HTMLRender is the renderer Context.HTML asks for template instances.
HTMLRender render.HTMLRender
// FuncMap is the template function map applied when templates are loaded.
FuncMap template.FuncMap
// allNoRoute and allNoMethod are the results of combineHandlers applied to
// noRoute and noMethod; they are rebuilt by rebuild404Handlers and
// rebuild405Handlers.
allNoRoute HandlersChain
allNoMethod HandlersChain
// noRoute and noMethod are the handlers registered via NoRoute and NoMethod.
noRoute HandlersChain
noMethod HandlersChain
// pool recycles *Context values between requests (see ServeHTTP).
pool sync.Pool
// trees holds one routing tree per HTTP method.
trees methodTrees
// maxParams and maxSections track the largest parameter and section counts
// among registered paths; allocateContext uses them as slice capacities.
maxParams uint16
maxSections uint16
// trustedProxies is the raw list given to SetTrustedProxies; trustedCIDRs is
// its parsed form, consulted by isTrustedProxy.
trustedProxies []string
trustedCIDRs []*net.IPNet
}
// Compile-time check that *Engine satisfies IRouter.
var _ IRouter = (*Engine)(nil)
// New returns a new blank Engine instance without any middleware attached,
// after applying the given options.
// By default, the configuration is:
// - RedirectTrailingSlash: true
// - RedirectFixedPath: false
// - HandleMethodNotAllowed: false
// - ForwardedByClientIP: true
// - UseRawPath: false
// - UnescapePathValues: true
func New(opts ...OptionFunc) *Engine {
	debugPrintWARNINGNew()

	engine := &Engine{
		RouterGroup: RouterGroup{basePath: "/", root: true},

		// Routing behaviour.
		RedirectTrailingSlash:  true,
		RedirectFixedPath:      false,
		HandleMethodNotAllowed: false,
		UseRawPath:             false,
		RemoveExtraSlash:       false,
		UnescapePathValues:     true,
		trees:                  make(methodTrees, 0, 9),

		// Client IP resolution: every proxy is trusted by default.
		ForwardedByClientIP: true,
		RemoteIPHeaders:     []string{"X-Forwarded-For", "X-Real-IP"},
		TrustedPlatform:     defaultPlatform,
		trustedProxies:      []string{"0.0.0.0/0", "::/0"},
		trustedCIDRs:        defaultTrustedCIDRs,

		// Rendering and request parsing.
		FuncMap:            template.FuncMap{},
		delims:             render.Delims{Left: "{{", Right: "}}"},
		secureJSONPrefix:   "while(1);",
		MaxMultipartMemory: defaultMultipartMemory,
	}
	engine.RouterGroup.engine = engine

	// Contexts are pooled; new ones are sized for the current maxParams.
	engine.pool.New = func() any {
		return engine.allocateContext(engine.maxParams)
	}

	return engine.With(opts...)
}
// Default returns an Engine created by New() with the Logger and Recovery
// middleware attached, after applying the given options.
func Default(opts ...OptionFunc) *Engine {
	debugPrintWARNINGDefault()

	engine := New()
	middleware := []HandlerFunc{Logger(), Recovery()}
	engine.Use(middleware...)
	return engine.With(opts...)
}
// Handler returns the engine itself as an http.Handler.
func (engine *Engine) Handler() http.Handler {
return engine
}
// allocateContext creates a Context bound to engine, with an empty Params
// slice of capacity maxParams and an empty skippedNode slice of capacity
// engine.maxSections.
func (engine *Engine) allocateContext(maxParams uint16) *Context {
	params := make(Params, 0, maxParams)
	skipped := make([]skippedNode, 0, engine.maxSections)

	ctx := &Context{engine: engine}
	ctx.params = &params
	ctx.skippedNodes = &skipped
	return ctx
}
// Delims sets template left and right delims and returns an Engine instance.
// The returned engine is the receiver, so calls can be chained.
func (engine *Engine) Delims(left, right string) *Engine {
engine.delims = render.Delims{Left: left, Right: right}
return engine
}
// SecureJsonPrefix sets the secureJSONPrefix used in Context.SecureJSON.
// It returns the receiver so calls can be chained.
func (engine *Engine) SecureJsonPrefix(prefix string) *Engine {
engine.secureJSONPrefix = prefix
return engine
}
// LoadHTMLGlob parses the HTML files matching pattern, using the engine's
// delimiters and FuncMap, and installs an HTML renderer for them. It panics
// (template.Must) when parsing fails.
// In debug mode the loaded templates are logged and a render.HTMLDebug
// renderer configured with the pattern is installed; otherwise the parsed
// template set is installed through SetHTMLTemplate.
func (engine *Engine) LoadHTMLGlob(pattern string) {
	d := engine.delims
	templ := template.Must(template.New("").Delims(d.Left, d.Right).Funcs(engine.FuncMap).ParseGlob(pattern))

	if !IsDebugging() {
		engine.SetHTMLTemplate(templ)
		return
	}

	debugPrintLoadTemplate(templ)
	engine.HTMLRender = render.HTMLDebug{Glob: pattern, FuncMap: engine.FuncMap, Delims: engine.delims}
}
// LoadHTMLFiles installs an HTML renderer for the given files.
// In debug mode a render.HTMLDebug renderer configured with the file list is
// installed without parsing here. Otherwise the files are parsed with the
// engine's delimiters and FuncMap, panicking (template.Must) on failure, and
// installed through SetHTMLTemplate.
func (engine *Engine) LoadHTMLFiles(files ...string) {
	if !IsDebugging() {
		d := engine.delims
		parsed := template.Must(template.New("").Delims(d.Left, d.Right).Funcs(engine.FuncMap).ParseFiles(files...))
		engine.SetHTMLTemplate(parsed)
		return
	}
	engine.HTMLRender = render.HTMLDebug{Files: files, FuncMap: engine.FuncMap, Delims: engine.delims}
}
// SetHTMLTemplate installs templ, extended with the engine's FuncMap, as the
// production HTML renderer. A debug warning is emitted when routing trees
// already exist at the time of the call.
func (engine *Engine) SetHTMLTemplate(templ *template.Template) {
	if len(engine.trees) != 0 {
		debugPrintWARNINGSetHTMLTemplate()
	}

	withFuncs := templ.Funcs(engine.FuncMap)
	engine.HTMLRender = render.HTMLProduction{Template: withFuncs}
}
// SetFuncMap sets the FuncMap used for template.FuncMap.
// It replaces engine.FuncMap, which the LoadHTML* methods and SetHTMLTemplate
// read when they are called.
func (engine *Engine) SetFuncMap(funcMap template.FuncMap) {
engine.FuncMap = funcMap
}
// NoRoute adds handlers for NoRoute. It returns a 404 code by default.
// The handlers replace any previously registered ones, and the combined
// chain is rebuilt immediately.
func (engine *Engine) NoRoute(handlers ...HandlerFunc) {
engine.noRoute = handlers
engine.rebuild404Handlers()
}
// NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true.
// The handlers replace any previously registered ones, and the combined
// chain is rebuilt immediately.
func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.noMethod = handlers
engine.rebuild405Handlers()
}
// Use attaches a global middleware to the router. i.e. the middleware attached through Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
// It returns the engine itself.
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...)
// Rebuild the 404/405 chains so they include the new middleware as well.
engine.rebuild404Handlers()
engine.rebuild405Handlers()
return engine
}
// With applies every option to engine, in the order given, and returns that
// same engine (no copy is made).
func (engine *Engine) With(opts ...OptionFunc) *Engine {
	for i := range opts {
		opts[i](engine)
	}
	return engine
}
// rebuild404Handlers recomputes allNoRoute by passing noRoute through
// combineHandlers.
func (engine *Engine) rebuild404Handlers() {
engine.allNoRoute = engine.combineHandlers(engine.noRoute)
}
// rebuild405Handlers recomputes allNoMethod by passing noMethod through
// combineHandlers.
func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod)
}
// addRoute registers handlers for method and path.
// It asserts that path starts with '/', that method is non-empty and that at
// least one handler is given, logs the route in debug mode, creates the
// method's routing tree on first use, and raises maxParams / maxSections when
// path exceeds the current maxima.
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
	assert1(path[0] == '/', "path must begin with '/'")
	assert1(method != "", "HTTP method can not be empty")
	assert1(len(handlers) > 0, "there must be at least one handler")

	debugPrintRoute(method, path, handlers)

	tree := engine.trees.get(method)
	if tree == nil {
		// First route for this method: start a new tree.
		tree = &node{fullPath: "/"}
		engine.trees = append(engine.trees, methodTree{method: method, root: tree})
	}
	tree.addRoute(path, handlers)

	if n := countParams(path); n > engine.maxParams {
		engine.maxParams = n
	}
	if n := countSections(path); n > engine.maxSections {
		engine.maxSections = n
	}
}
// Routes returns a slice of registered routes, gathered by walking every
// method tree. Each entry carries the http method, the path, the handler name
// and the handler itself.
func (engine *Engine) Routes() (routes RoutesInfo) {
	for i := range engine.trees {
		tree := engine.trees[i]
		routes = iterate("", tree.method, routes, tree.root)
	}
	return routes
}
// iterate walks the routing tree rooted at root depth-first, extending path
// with each node's segment. For every node that has handlers it appends a
// RouteInfo describing the node's last handler, and it returns the extended
// slice.
func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
	path += root.path

	if len(root.handlers) != 0 {
		final := root.handlers.Last()
		info := RouteInfo{Method: method, Path: path, HandlerFunc: final}
		info.Handler = nameOfFunction(final)
		routes = append(routes, info)
	}

	for i := range root.children {
		routes = iterate(path, method, routes, root.children[i])
	}
	return routes
}
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// The listen address is derived from addr by resolveAddress.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) Run(addr ...string) (err error) {
// The closure reads the named result, so it logs the final error on return.
defer func() { debugPrintError(err) }()
if engine.isUnsafeTrustedProxies() {
debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" +
"Please check https://pkg.go.dev/gossh/gin#readme-don-t-trust-all-proxies for details.")
}
address := resolveAddress(addr)
debugPrint("Listening and serving HTTP on %s\n", address)
err = http.ListenAndServe(address, engine.Handler())
return
}
// prepareTrustedCIDRs converts engine.trustedProxies into networks.
// A nil list yields (nil, nil). An entry without '/' must be a valid IP
// address and is treated as a single-host network (/32 for IPv4, /128 for
// IPv6). On the first invalid entry, the networks parsed so far are returned
// together with the error.
func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
	proxies := engine.trustedProxies
	if proxies == nil {
		return nil, nil
	}

	networks := make([]*net.IPNet, 0, len(proxies))
	for _, entry := range proxies {
		if !strings.Contains(entry, "/") {
			ip := parseIP(entry)
			if ip == nil {
				return networks, &net.ParseError{Type: "IP address", Text: entry}
			}
			// Turn the bare address into a single-host CIDR.
			if size := len(ip); size == net.IPv4len {
				entry += "/32"
			} else if size == net.IPv6len {
				entry += "/128"
			}
		}

		_, network, err := net.ParseCIDR(entry)
		if err != nil {
			return networks, err
		}
		networks = append(networks, network)
	}
	return networks, nil
}
// SetTrustedProxies set a list of network origins (IPv4 addresses,
// IPv4 CIDRs, IPv6 addresses or IPv6 CIDRs) from which to trust
// request's headers that contain alternative client IP when
// `(*gin.Engine).ForwardedByClientIP` is `true`. `TrustedProxies`
// feature is enabled by default, and it also trusts all proxies
// by default. If you want to disable this feature, use
// Engine.SetTrustedProxies(nil), then Context.ClientIP() will
// return the remote address directly.
// The list is stored before it is parsed; the returned error is the one
// reported by parseTrustedProxies.
func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
engine.trustedProxies = trustedProxies
return engine.parseTrustedProxies()
}
// isUnsafeTrustedProxies reports whether the all-zero IPv4 address ("0.0.0.0")
// or the all-zero IPv6 address ("::") is matched by Engine.trustedCIDRs.
// A true result means the trusted list is considered unsafe.
func (engine *Engine) isUnsafeTrustedProxies() bool {
	for _, unspecified := range [...]string{"0.0.0.0", "::"} {
		if engine.isTrustedProxy(net.ParseIP(unspecified)) {
			return true
		}
	}
	return false
}
// parseTrustedProxies parses Engine.trustedProxies into Engine.trustedCIDRs.
// The result of prepareTrustedCIDRs is stored even when it reports an error,
// and that error is returned.
func (engine *Engine) parseTrustedProxies() error {
	var err error
	engine.trustedCIDRs, err = engine.prepareTrustedCIDRs()
	return err
}
// isTrustedProxy reports whether ip falls inside any network of
// Engine.trustedCIDRs. A nil or empty list trusts nothing.
func (engine *Engine) isTrustedProxy(ip net.IP) bool {
	// Ranging over a nil slice performs no iterations, so no separate nil
	// check is needed.
	for i := range engine.trustedCIDRs {
		if engine.trustedCIDRs[i].Contains(ip) {
			return true
		}
	}
	return false
}
// validateHeader extracts the client IP from a comma-separated header value
// such as X-Forwarded-For. Entries are examined from right to left, because
// each proxy appends to the list. The first entry that is not a trusted proxy
// is returned, or the left-most entry when every other one is trusted.
// valid is false for an empty header or as soon as an entry fails to parse
// as an IP address.
func (engine *Engine) validateHeader(header string) (clientIP string, valid bool) {
	if header == "" {
		return "", false
	}

	hops := strings.Split(header, ",")
	for i := len(hops) - 1; i >= 0; i-- {
		candidate := strings.TrimSpace(hops[i])
		parsed := net.ParseIP(candidate)
		if parsed == nil {
			return "", false
		}
		if i == 0 || !engine.isTrustedProxy(parsed) {
			return candidate, true
		}
	}
	return "", false
}
// parseIP parses the textual IP address in ip and returns it in its shortest
// form: 4 bytes when the address has an IPv4 representation, 16 bytes
// otherwise. It returns nil when ip is not a valid address.
func parseIP(ip string) net.IP {
	full := net.ParseIP(ip)
	short := full.To4()
	if short == nil {
		// Either a genuine IPv6 address or nil for invalid input.
		return full
	}
	return short
}
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
debugPrint("Listening and serving HTTPS on %s\n", addr)
// The closure reads the named result, so it logs the final error on return.
defer func() { debugPrintError(err) }()
if engine.isUnsafeTrustedProxies() {
debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" +
"Please check https://pkg.go.dev/gossh/gin#readme-don-t-trust-all-proxies for details.")
}
err = http.ListenAndServeTLS(addr, certFile, keyFile, engine.Handler())
return
}
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
// through the specified unix socket (i.e. a file).
// When serving stops, the listener is closed and the socket file is removed.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunUnix(file string) (err error) {
debugPrint("Listening and serving HTTP on unix:/%s", file)
// The closure reads the named result, so it logs the final error on return.
defer func() { debugPrintError(err) }()
if engine.isUnsafeTrustedProxies() {
debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" +
"Please check https://gossh/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details.")
}
listener, err := net.Listen("unix", file)
if err != nil {
return
}
// Deferred calls run in reverse order: the file is removed first, then the
// listener is closed.
defer listener.Close()
defer os.Remove(file)
err = http.Serve(listener, engine.Handler())
return
}
// RunFd attaches the router to a http.Server and starts listening and serving HTTP requests
// through the specified file descriptor.
// The descriptor is wrapped in an *os.File, a listener is created from it and
// serving is delegated to RunListener. Both the listener and the wrapping
// file are closed when serving stops.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunFd(fd int) (err error) {
	debugPrint("Listening and serving HTTP on fd@%d", fd)
	// The closure reads the named result, so it logs the final error on return.
	defer func() { debugPrintError(err) }()

	if engine.isUnsafeTrustedProxies() {
		debugPrint("[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" +
			"Please check https://gossh/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details.")
	}

	f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd))
	// net.FileListener returns a listener backed by its own copy of the
	// descriptor, and closing the listener does not close f. Release f
	// explicitly instead of leaving it to the garbage collector. Close is
	// safe to call even when os.NewFile returned nil for an invalid fd.
	defer f.Close()

	listener, err := net.FileListener(f)
	if err != nil {
		return
	}
	defer listener.Close()

	err = engine.RunListener(listener)
	return
}
// RunListener serves the engine's handler over HTTP on the supplied
// net.Listener and returns the error reported by http.Serve.
func (engine *Engine) RunListener(listener net.Listener) (err error) {
	const proxyWarning = "[WARNING] You trusted all proxies, this is NOT safe. We recommend you to set a value.\n" +
		"Please check https://gossh/gin/blob/master/docs/doc.md#dont-trust-all-proxies for details."

	debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr())
	// The closure reads the named result, so it reports the error that is
	// finally returned.
	defer func() { debugPrintError(err) }()

	if engine.isUnsafeTrustedProxies() {
		debugPrint(proxyWarning)
	}

	return http.Serve(listener, engine.Handler())
}
// ServeHTTP implements http.Handler. It takes a *Context from the engine's
// pool, binds it to w and req, dispatches the request and then puts the
// context back into the pool.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	ctx := engine.pool.Get().(*Context)

	// Re-initialise the pooled context for this request.
	ctx.writermem.reset(w)
	ctx.Request = req
	ctx.reset()

	engine.handleHTTPRequest(ctx)

	engine.pool.Put(ctx)
}
// HandleContext dispatches c through the router again after it has been
// rewritten, for example by setting c.Request.URL.Path to a new target.
// The context's handler index is saved beforehand and restored afterwards.
// Disclaimer: You can loop yourself to deal with this, use wisely.
func (engine *Engine) HandleContext(c *Context) {
	savedIndex := c.index

	c.reset()
	engine.handleHTTPRequest(c)

	c.index = savedIndex
}
// handleHTTPRequest routes the request held by c.
//
// It looks up the routing tree registered for the request method and, when a
// route matches, runs its handler chain. If nothing matches it may redirect
// (trailing slash / fixed path), answer 405 with an Allow header when other
// methods match the path and HandleMethodNotAllowed is set, and otherwise
// serves the 404 handlers.
func (engine *Engine) handleHTTPRequest(c *Context) {
	httpMethod := c.Request.Method
	rPath := c.Request.URL.Path
	unescape := false
	if engine.UseRawPath && len(c.Request.URL.RawPath) > 0 {
		rPath = c.Request.URL.RawPath
		unescape = engine.UnescapePathValues
	}

	if engine.RemoveExtraSlash {
		rPath = cleanPath(rPath)
	}

	// Find root of the tree for the given HTTP method
	t := engine.trees
	for i, tl := 0, len(t); i < tl; i++ {
		if t[i].method != httpMethod {
			continue
		}
		root := t[i].root
		// Find route in tree
		value := root.getValue(rPath, c.params, c.skippedNodes, unescape)
		if value.params != nil {
			c.Params = *value.params
		}
		if value.handlers != nil {
			c.handlers = value.handlers
			c.fullPath = value.fullPath
			c.Next()
			c.writermem.WriteHeaderNow()
			return
		}
		if httpMethod != http.MethodConnect && rPath != "/" {
			if value.tsr && engine.RedirectTrailingSlash {
				redirectTrailingSlash(c)
				return
			}
			if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) {
				return
			}
		}
		// Only one tree exists per method; stop after inspecting it.
		break
	}

	// The len(t) > 0 guard matters: with no routes registered, len(t)-1 is
	// negative and make would panic with "cap out of range".
	if engine.HandleMethodNotAllowed && len(t) > 0 {
		// According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response
		// containing a list of the target resource's currently supported methods.
		allowed := make([]string, 0, len(t)-1)
		for _, tree := range engine.trees {
			if tree.method == httpMethod {
				continue
			}
			if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil {
				allowed = append(allowed, tree.method)
			}
		}
		if len(allowed) > 0 {
			c.handlers = engine.allNoMethod
			c.writermem.Header().Set("Allow", strings.Join(allowed, ", "))
			serveError(c, http.StatusMethodNotAllowed, default405Body)
			return
		}
	}

	c.handlers = engine.allNoRoute
	serveError(c, http.StatusNotFound, default404Body)
}
// mimePlain is the Content-Type header value that serveError installs before
// writing a default error body.
var mimePlain = []string{MIMEPlain}
// serveError records code as the response status, runs the handler chain on c
// and, if no handler wrote a response, finishes it: when the status is still
// code, defaultMessage is written with the mimePlain content type; otherwise
// only the header is flushed.
func serveError(c *Context, code int, defaultMessage []byte) {
	c.writermem.status = code
	c.Next()

	if c.writermem.Written() {
		return
	}

	if c.writermem.Status() != code {
		// A handler changed the status without writing; send the header only.
		c.writermem.WriteHeaderNow()
		return
	}

	c.writermem.Header()["Content-Type"] = mimePlain
	if _, err := c.Writer.Write(defaultMessage); err != nil {
		debugPrint("cannot write message to writer during serve error: %v", err)
	}
}
// redirectTrailingSlash rewrites the request path by toggling its trailing
// slash and then issues a redirect. When the X-Forwarded-Prefix header cleans
// to something other than ".", the sanitized prefix is put in front of the
// path first.
func redirectTrailingSlash(c *Context) {
	req := c.Request
	target := req.URL.Path

	prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix"))
	if prefix != "." {
		prefix = regSafePrefix.ReplaceAllString(prefix, "")
		prefix = regRemoveRepeatedChar.ReplaceAllString(prefix, "/")
		target = prefix + "/" + req.URL.Path
	}

	if n := len(target); n > 1 && target[n-1] == '/' {
		// Path ends in a slash: redirect to the variant without it.
		req.URL.Path = target[:n-1]
	} else {
		req.URL.Path = target + "/"
	}

	redirectRequest(c)
}
// redirectFixedPath asks root for a case-insensitive match of the cleaned
// request path. On a match it rewrites the request path to the match, issues
// a redirect and returns true; otherwise it returns false and leaves the
// request untouched.
func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool {
	req := c.Request

	fixedPath, found := root.findCaseInsensitivePath(cleanPath(req.URL.Path), trailingSlash)
	if !found {
		return false
	}

	req.URL.Path = bytesconv.BytesToString(fixedPath)
	redirectRequest(c)
	return true
}
// redirectRequest redirects the client to the request's current URL: 301 for
// GET requests and 307 for every other method. The header is then flushed.
func redirectRequest(c *Context) {
	req := c.Request
	rPath := req.URL.Path
	rURL := req.URL.String()

	var code int
	if req.Method == http.MethodGet {
		code = http.StatusMovedPermanently // Permanent redirect, request with GET method
	} else {
		code = http.StatusTemporaryRedirect
	}

	debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
	http.Redirect(c.Writer, req, rURL, code)
	c.writermem.WriteHeaderNow()
}
================================================
FILE: gossh/gin/ginS/README.md
================================================
# Gin Default Server
This is an API experiment for Gin.
```go
package main
import (
"gossh/gin"
"gossh/gin/ginS"
)
func main() {
ginS.GET("/", func(c *gin.Context) { c.String(200, "Hello World") })
ginS.Run()
}
```
================================================
FILE: gossh/gin/ginS/gins.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package ginS
import (
"html/template"
"net/http"
"sync"
"gossh/gin"
)
// once guards the one-time construction of internalEngine.
var once sync.Once
// internalEngine is the package-level engine shared by every wrapper in this
// package. It is assigned the first time engine() is called.
var internalEngine *gin.Engine
// engine returns the shared package-level engine, creating it with
// gin.Default() on the first call.
func engine() *gin.Engine {
once.Do(func() {
internalEngine = gin.Default()
})
return internalEngine
}
// LoadHTMLGlob is a wrapper for Engine.LoadHTMLGlob.
// It passes pattern to the package-level engine.
func LoadHTMLGlob(pattern string) {
engine().LoadHTMLGlob(pattern)
}
// LoadHTMLFiles is a wrapper for Engine.LoadHTMLFiles.
// It passes files to the package-level engine.
func LoadHTMLFiles(files ...string) {
engine().LoadHTMLFiles(files...)
}
// SetHTMLTemplate is a wrapper for Engine.SetHTMLTemplate.
// It passes templ to the package-level engine.
func SetHTMLTemplate(templ *template.Template) {
engine().SetHTMLTemplate(templ)
}
// NoRoute is a wrapper for Engine.NoRoute.
// It adds handlers for NoRoute on the package-level engine. It returns a 404 code by default.
func NoRoute(handlers ...gin.HandlerFunc) {
engine().NoRoute(handlers...)
}
// NoMethod is a wrapper for Engine.NoMethod.
// It passes handlers to the package-level engine.
func NoMethod(handlers ...gin.HandlerFunc) {
engine().NoMethod(handlers...)
}
// Group is a wrapper for Engine.Group on the package-level engine.
// Group creates a new router group. You should add all the routes that have common middlewares or the same path prefix.
// For example, all the routes that use a common middleware for authorization could be grouped.
func Group(relativePath string, handlers ...gin.HandlerFunc) *gin.RouterGroup {
return engine().Group(relativePath, handlers...)
}
// Handle is a wrapper for Engine.Handle.
// It forwards httpMethod, relativePath and handlers to the package-level
// engine and returns that call's result.
func Handle(httpMethod, relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().Handle(httpMethod, relativePath, handlers...)
}
// POST is a shortcut for router.Handle("POST", path, handle)
// on the package-level engine.
func POST(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().POST(relativePath, handlers...)
}
// GET is a shortcut for router.Handle("GET", path, handle)
// on the package-level engine.
func GET(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().GET(relativePath, handlers...)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle)
// on the package-level engine.
func DELETE(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().DELETE(relativePath, handlers...)
}
// PATCH is a shortcut for router.Handle("PATCH", path, handle)
// on the package-level engine.
func PATCH(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().PATCH(relativePath, handlers...)
}
// PUT is a shortcut for router.Handle("PUT", path, handle)
// on the package-level engine.
func PUT(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().PUT(relativePath, handlers...)
}
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
// on the package-level engine.
func OPTIONS(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().OPTIONS(relativePath, handlers...)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handle)
// on the package-level engine.
func HEAD(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().HEAD(relativePath, handlers...)
}
// Any is a wrapper for Engine.Any.
// It forwards relativePath and handlers to the package-level engine and
// returns that call's result.
func Any(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes {
return engine().Any(relativePath, handlers...)
}
// StaticFile is a wrapper for Engine.StaticFile.
// It forwards relativePath and filepath to the package-level engine and
// returns that call's result.
func StaticFile(relativePath, filepath string) gin.IRoutes {
return engine().StaticFile(relativePath, filepath)
}
// Static is a wrapper for Engine.Static on the package-level engine.
// Static serves files from the given file system root.
// Internally a http.FileServer is used, therefore http.NotFound is used instead
// of the Router's NotFound handler.
// To use the operating system's file system implementation,
// use:
//
//	router.Static("/static", "/var/www")
func Static(relativePath, root string) gin.IRoutes {
return engine().Static(relativePath, root)
}
// StaticFS is a wrapper for Engine.StaticFS.
// It forwards relativePath and fs to the package-level engine and returns
// that call's result.
func StaticFS(relativePath string, fs http.FileSystem) gin.IRoutes {
return engine().StaticFS(relativePath, fs)
}
// Use is a wrapper for Engine.Use on the package-level engine.
// Use attaches a global middleware to the router. i.e. the middlewares attached through Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func Use(middlewares ...gin.HandlerFunc) gin.IRoutes {
return engine().Use(middlewares...)
}
// Routes is a wrapper for Engine.Routes on the package-level engine.
// Routes returns a slice of registered routes.
func Routes() gin.RoutesInfo {
return engine().Routes()
}
// Run is a wrapper for Engine.Run on the package-level engine.
// Run attaches to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func Run(addr ...string) (err error) {
return engine().Run(addr...)
}
// RunTLS is a wrapper for Engine.RunTLS on the package-level engine.
// RunTLS attaches to a http.Server and starts listening and serving HTTPS requests.
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func RunTLS(addr, certFile, keyFile string) (err error) {
return engine().RunTLS(addr, certFile, keyFile)
}
// RunUnix is a wrapper for Engine.RunUnix on the package-level engine.
// RunUnix attaches to a http.Server and starts listening and serving HTTP requests
// through the specified unix socket (i.e. a file)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func RunUnix(file string) (err error) {
return engine().RunUnix(file)
}
// RunFd is a wrapper for Engine.RunFd on the package-level engine.
// RunFd attaches the router to a http.Server and starts listening and serving HTTP requests
// through the specified file descriptor.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func RunFd(fd int) (err error) {
return engine().RunFd(fd)
}
================================================
FILE: gossh/gin/internal/bytesconv/bytesconv.go
================================================
// Copyright 2023 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package bytesconv
import (
"unsafe"
)
// StringToBytes returns a byte slice that views the bytes of s without
// copying them. The result shares memory with s.
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
func StringToBytes(s string) []byte {
	data := unsafe.StringData(s)
	return unsafe.Slice(data, len(s))
}
// BytesToString returns a string that views the bytes of b without copying
// them. The result shares memory with b.
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
func BytesToString(b []byte) string {
	data := unsafe.SliceData(b)
	return unsafe.String(data, len(b))
}
================================================
FILE: gossh/gin/internal/json/json.go
================================================
// Copyright 2017 Bo-Yi Wu. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package json
import "encoding/json"
// The variables below re-export the corresponding encoding/json functions
// under this package's name.
var (
// Marshal is exported by gin/json package; it is encoding/json.Marshal.
Marshal = json.Marshal
// Unmarshal is exported by gin/json package; it is encoding/json.Unmarshal.
Unmarshal = json.Unmarshal
// MarshalIndent is exported by gin/json package; it is encoding/json.MarshalIndent.
MarshalIndent = json.MarshalIndent
// NewDecoder is exported by gin/json package; it is encoding/json.NewDecoder.
NewDecoder = json.NewDecoder
// NewEncoder is exported by gin/json package; it is encoding/json.NewEncoder.
NewEncoder = json.NewEncoder
)
================================================
FILE: gossh/gin/jwt/README.md
================================================
# jwt-go
[Build status](https://github.com/golang-jwt/jwt/actions/workflows/build.yml)
[Go reference](https://pkg.go.dev/github.com/golang-jwt/jwt/v5)
[Coverage status](https://coveralls.io/github/golang-jwt/jwt?branch=main)
A [go](http://www.golang.org) (or 'golang' for search engine friendliness)
implementation of [JSON Web
Tokens](https://datatracker.ietf.org/doc/html/rfc7519).
Starting with [v4.0.0](https://github.com/golang-jwt/jwt/releases/tag/v4.0.0)
this project adds Go module support, but maintains backwards compatibility with
older `v3.x.y` tags and upstream `github.com/dgrijalva/jwt-go`. See the
[`MIGRATION_GUIDE.md`](./MIGRATION_GUIDE.md) for more information. Version
v5.0.0 introduces major improvements to the validation of tokens, but is not
entirely backwards compatible.
> After the original author of the library suggested migrating the maintenance
> of `jwt-go`, a dedicated team of open source maintainers decided to clone the
> existing library into this repository. See
> [dgrijalva/jwt-go#462](https://github.com/dgrijalva/jwt-go/issues/462) for a
> detailed discussion on this topic.
**SECURITY NOTICE:** Some older versions of Go have a security issue in
`crypto/elliptic`. The recommendation is to upgrade to at least Go 1.15. See issue
[dgrijalva/jwt-go#216](https://github.com/dgrijalva/jwt-go/issues/216) for more
detail.
**SECURITY NOTICE:** It's important that you [validate the `alg` presented is
what you
expect](https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/).
This library attempts to make it easy to do the right thing by requiring key
types match the expected alg, but you should take the extra step to verify it in
your usage. See the examples provided.
### Supported Go versions
Our support of Go versions is aligned with Go's [version release
policy](https://golang.org/doc/devel/release#policy). So we will support a major
version of Go until there are two newer major releases. We no longer support
building jwt-go with unsupported Go versions, as these contain security
vulnerabilities which will not be fixed.
## What the heck is a JWT?
JWT.io has [a great introduction](https://jwt.io/introduction) to JSON Web
Tokens.
In short, it's a signed JSON object that does something useful (for example,
authentication). It's commonly used for `Bearer` tokens in OAuth 2. A token is
made of three parts, separated by `.`'s. The first two parts are JSON objects
that have been [base64url](https://datatracker.ietf.org/doc/html/rfc4648)
encoded. The last part is the signature, encoded the same way.
The first part is called the header. It contains the necessary information for
verifying the last part, the signature. For example, which encryption method
was used for signing and what key was used.
The part in the middle is the interesting bit. It's called the Claims and
contains the actual stuff you care about. Refer to [RFC
7519](https://datatracker.ietf.org/doc/html/rfc7519) for information about
reserved keys and the proper way to add your own.
## What's in the box?
This library supports the parsing and verification as well as the generation and
signing of JWTs. Current supported signing algorithms are HMAC SHA, RSA,
RSA-PSS, and ECDSA, though hooks are present for adding your own.
## Installation Guidelines
1. To install the jwt package, you first need to have
[Go](https://go.dev/doc/install) installed, then you can use the command
below to add `jwt-go` as a dependency in your Go program.
```sh
go get -u github.com/golang-jwt/jwt/v5
```
2. Import it in your code:
```go
import "github.com/golang-jwt/jwt/v5"
```
## Usage
A detailed usage guide, including how to sign and verify tokens can be found on
our [documentation website](https://golang-jwt.github.io/jwt/usage/create/).
## Examples
See [the project documentation](https://pkg.go.dev/github.com/golang-jwt/jwt/v5)
for examples of usage:
* [Simple example of parsing and validating a
token](https://pkg.go.dev/github.com/golang-jwt/jwt/v5#example-Parse-Hmac)
* [Simple example of building and signing a
token](https://pkg.go.dev/github.com/golang-jwt/jwt/v5#example-New-Hmac)
* [Directory of
Examples](https://pkg.go.dev/github.com/golang-jwt/jwt/v5#pkg-examples)
## Compliance
This library was last reviewed to comply with [RFC
7519](https://datatracker.ietf.org/doc/html/rfc7519) dated May 2015 with a few
notable differences:
* In order to protect against accidental use of [Unsecured
JWTs](https://datatracker.ietf.org/doc/html/rfc7519#section-6), tokens using
`alg=none` will only be accepted if the constant
`jwt.UnsafeAllowNoneSignatureType` is provided as the key.
## Project Status & Versioning
This library is considered production ready. Feedback and feature requests are
appreciated. The API should be considered stable. There should be very few
backwards-incompatible changes outside of major version updates (and only with
good reason).
This project uses [Semantic Versioning 2.0.0](http://semver.org). Accepted pull
requests will land on `main`. Periodically, versions will be tagged from
`main`. You can find all the releases on [the project releases
page](https://github.com/golang-jwt/jwt/releases).
**BREAKING CHANGES:** A full list of breaking changes is available in
`VERSION_HISTORY.md`. See `MIGRATION_GUIDE.md` for more information on updating
your code.
## Extensions
This library publishes all the necessary components for adding your own signing
methods or key functions. Simply implement the `SigningMethod` interface and
register a factory method using `RegisterSigningMethod` or provide a
`jwt.Keyfunc`.
A common use case would be integrating with different 3rd party signature
providers, like key management services from various cloud providers or Hardware
Security Modules (HSMs) or to implement additional standards.
| Extension | Purpose | Repo |
| --------- | -------------------------------------------------------------------------------------------------------- | ------------------------------------------ |
| GCP | Integrates with multiple Google Cloud Platform signing tools (AppEngine, IAM API, Cloud KMS) | https://github.com/someone1/gcp-jwt-go |
| AWS | Integrates with AWS Key Management Service, KMS | https://github.com/matelang/jwt-go-aws-kms |
| JWKS | Provides support for JWKS ([RFC 7517](https://datatracker.ietf.org/doc/html/rfc7517)) as a `jwt.Keyfunc` | https://github.com/MicahParks/keyfunc |
*Disclaimer*: Unless otherwise specified, these integrations are maintained by
third parties and should not be considered as a primary offer by any of the
mentioned cloud providers.
## More
Go package documentation can be found [on
pkg.go.dev](https://pkg.go.dev/github.com/golang-jwt/jwt/v5). Additional
documentation can be found on [our project
page](https://golang-jwt.github.io/jwt/).
The command line utility included in this project (cmd/jwt) provides a
straightforward example of token creation and parsing as well as a useful tool
for debugging your own integration. You'll also find several implementation
examples in the documentation.
[golang-jwt](https://github.com/orgs/golang-jwt) incorporates a modified version
of the JWT logo, which is distributed under the terms of the [MIT
License](https://github.com/jsonwebtoken/jsonwebtoken.github.io/blob/master/LICENSE.txt).
================================================
FILE: gossh/gin/jwt/claims.go
================================================
package jwt
// Claims represent any form of a JWT Claims Set according to
// https://datatracker.ietf.org/doc/html/rfc7519#section-4. In order to have a
// common basis for validation, it is required that an implementation is able to
// supply at least the claim names provided in
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
// `iat`, `nbf`, `iss`, `sub` and `aud`.
//
// Each accessor returns the claim value together with an error.
type Claims interface {
// GetExpirationTime supplies the `exp` claim.
GetExpirationTime() (*NumericDate, error)
// GetIssuedAt supplies the `iat` claim.
GetIssuedAt() (*NumericDate, error)
// GetNotBefore supplies the `nbf` claim.
GetNotBefore() (*NumericDate, error)
// GetIssuer supplies the `iss` claim.
GetIssuer() (string, error)
// GetSubject supplies the `sub` claim.
GetSubject() (string, error)
// GetAudience supplies the `aud` claim.
GetAudience() (ClaimStrings, error)
}
================================================
FILE: gossh/gin/jwt/cmd/jwt/README.md
================================================
`jwt` command-line tool
=======================
This is a simple tool to sign, verify and show JSON Web Tokens from
the command line.
The following will create and sign a token, then verify it and output the original claims:
echo {\"foo\":\"bar\"} | ./jwt -key ../../test/sample_key -alg RS256 -sign - | ./jwt -key ../../test/sample_key.pub -alg RS256 -verify -
Key files should be in PEM format. Other formats are not supported by this tool.
To simply display a token, use:
echo $JWT | ./jwt -show -
You can install this tool with the following command:
go install github.com/golang-jwt/jwt/v5/cmd/jwt
================================================
FILE: gossh/gin/jwt/cmd/jwt/main.go
================================================
// A useful example app. You can use this to debug your tokens on the command line.
// This is also a great place to look at how you might use this library.
//
// Example usage:
// The following will create and sign a token, then verify it and output the original claims.
//
// echo {\"foo\":\"bar\"} | bin/jwt -key test/sample_key -alg RS256 -sign - | bin/jwt -key test/sample_key.pub -verify -
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"os"
"regexp"
"sort"
"strings"
"gossh/gin/jwt"
)
// Command-line flags. The string and bool flags are registered with the flag
// package here; flagClaims and flagHead are registered in main via flag.Var.
var (
// Options
// flagAlg names the signing algorithm; its help text is built by algHelp.
flagAlg = flag.String("alg", "", algHelp())
flagKey = flag.String("key", "", "path to key file or '-' to read from stdin")
flagCompact = flag.Bool("compact", false, "output compact JSON")
flagDebug = flag.Bool("debug", false, "print out all kinds of debug data")
// flagClaims collects repeated -claim key=value arguments.
flagClaims = make(ArgList)
// flagHead collects repeated -header key=value arguments.
flagHead = make(ArgList)
// Modes - exactly one of these is required
flagSign = flag.String("sign", "", "path to claims object to sign, '-' to read from stdin, or '+' to use only -claim args")
flagVerify = flag.String("verify", "", "path to JWT token to verify or '-' to read from stdin")
flagShow = flag.String("show", "", "path to JWT file or '-' to read from stdin")
)
// main registers the repeatable -claim and -header flags, installs a usage
// message, parses the command line and runs the selected mode. Any error is
// printed to stderr and the process exits with status 1.
func main() {
	// Plug in Var flags
	flag.Var(flagClaims, "claim", "add additional claims. may be used more than once")
	flag.Var(flagHead, "header", "add additional header params. may be used more than once")

	// Usage message if you ask for -help or if you mess up inputs.
	// All three modes handled by start are listed, including -show.
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
		fmt.Fprintf(os.Stderr, " One of the following flags is required: sign, verify, show\n")
		flag.PrintDefaults()
	}

	// Parse command line options
	flag.Parse()

	// Do the thing. If something goes wrong, print error to stderr
	// and exit with a non-zero status code
	if err := start(); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
}
// start runs the mode selected on the command line. The flags are checked in
// the order -sign, -verify, -show; the first non-empty one wins. If none is
// set, the usage text is printed and an error is returned.
func start() error {
	if *flagSign != "" {
		return signToken()
	}
	if *flagVerify != "" {
		return verifyToken()
	}
	if *flagShow != "" {
		return showToken()
	}

	flag.Usage()
	return fmt.Errorf("none of the required flags are present. What do you want me to do?")
}
// loadData returns the bytes named by p:
//   - ""  is rejected with an error,
//   - "+" yields an empty JSON object ("{}"),
//   - "-" reads standard input until EOF,
//   - anything else is treated as a file path and read in full.
func loadData(p string) ([]byte, error) {
	switch p {
	case "":
		return nil, fmt.Errorf("no path specified")
	case "+":
		return []byte("{}"), nil
	case "-":
		return io.ReadAll(os.Stdin)
	}

	f, err := os.Open(p)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	return io.ReadAll(f)
}
// printJSON writes j to standard output as JSON: compact when -compact is
// set, indented otherwise. A marshalling error is returned and nothing is
// printed.
func printJSON(j any) error {
	var (
		encoded []byte
		err     error
	)
	if *flagCompact {
		encoded, err = json.Marshal(j)
	} else {
		encoded, err = json.MarshalIndent(j, "", " ")
	}
	if err != nil {
		return err
	}

	fmt.Println(string(encoded))
	return nil
}
// verifyToken loads the token named by -verify, parses it with a key chosen
// according to -alg and -key, and prints its claims as JSON. This is a great
// example of how to verify and view a token.
func verifyToken() error {
	raw, err := loadData(*flagVerify)
	if err != nil {
		return fmt.Errorf("couldn't read token: %w", err)
	}

	// Strip trailing whitespace (e.g. a final newline) from the token.
	raw = regexp.MustCompile(`\s*$`).ReplaceAll(raw, []byte{})
	if *flagDebug {
		fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(raw))
	}

	// lookupKey supplies the verification key: the none-sentinel for
	// alg=none, a parsed public key for ES*/RS*/PS*/EdDSA, or the raw key
	// bytes otherwise.
	lookupKey := func(_ *jwt.Token) (any, error) {
		if isNone() {
			return jwt.UnsafeAllowNoneSignatureType, nil
		}
		keyData, err := loadData(*flagKey)
		if err != nil {
			return nil, err
		}
		if isEs() {
			return jwt.ParseECPublicKeyFromPEM(keyData)
		}
		if isRs() {
			return jwt.ParseRSAPublicKeyFromPEM(keyData)
		}
		if isEd() {
			return jwt.ParseEdPublicKeyFromPEM(keyData)
		}
		return keyData, nil
	}

	token, err := jwt.Parse(string(raw), lookupKey)
	if err != nil {
		return fmt.Errorf("couldn't parse token: %w", err)
	}

	if *flagDebug {
		fmt.Fprintf(os.Stderr, "Header:\n%v\n", token.Header)
		fmt.Fprintf(os.Stderr, "Claims:\n%v\n", token.Claims)
	}

	if err := printJSON(token.Claims); err != nil {
		return fmt.Errorf("failed to output claims: %w", err)
	}
	return nil
}
// signToken builds a token from the claims document named by -sign, merges
// in any -claim and -header arguments, signs it with the key named by -key
// using the -alg signing method, and prints the signed token to standard
// output. This is a great, simple example of how to use this library to
// create and sign a token.
func signToken() error {
	// get the token data from command line arguments
	tokData, err := loadData(*flagSign)
	if err != nil {
		return fmt.Errorf("couldn't read token: %w", err)
	}
	if *flagDebug {
		fmt.Fprintf(os.Stderr, "Token: %v bytes", len(tokData))
	}

	// parse the JSON of the claims
	var claims jwt.MapClaims
	if err := json.Unmarshal(tokData, &claims); err != nil {
		return fmt.Errorf("couldn't parse claims JSON: %w", err)
	}

	// add command line claims
	if len(flagClaims) > 0 {
		// A claims document of JSON `null` decodes to a nil map; writing
		// to it would panic, so allocate one before merging.
		if claims == nil {
			claims = make(jwt.MapClaims, len(flagClaims))
		}
		for k, v := range flagClaims {
			claims[k] = v
		}
	}

	// get the key
	var key any
	if isNone() {
		key = jwt.UnsafeAllowNoneSignatureType
	} else {
		key, err = loadData(*flagKey)
		if err != nil {
			return fmt.Errorf("couldn't read key: %w", err)
		}
	}

	// get the signing alg
	alg := jwt.GetSigningMethod(*flagAlg)
	if alg == nil {
		return fmt.Errorf("couldn't find signing method: %v", *flagAlg)
	}

	// create a new token
	token := jwt.NewWithClaims(alg, claims)

	// add command line headers
	for k, v := range flagHead {
		token.Header[k] = v
	}

	// Asymmetric algorithms need the PEM bytes parsed into a private key;
	// for every other algorithm the key is passed through unchanged.
	switch {
	case isEs():
		k, ok := key.([]byte)
		if !ok {
			return fmt.Errorf("couldn't convert key data to key")
		}
		key, err = jwt.ParseECPrivateKeyFromPEM(k)
		if err != nil {
			return err
		}
	case isRs():
		k, ok := key.([]byte)
		if !ok {
			return fmt.Errorf("couldn't convert key data to key")
		}
		key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
		if err != nil {
			return err
		}
	case isEd():
		k, ok := key.([]byte)
		if !ok {
			return fmt.Errorf("couldn't convert key data to key")
		}
		key, err = jwt.ParseEdPrivateKeyFromPEM(k)
		if err != nil {
			return err
		}
	}

	out, err := token.SignedString(key)
	if err != nil {
		return fmt.Errorf("error signing token: %w", err)
	}
	fmt.Println(out)
	return nil
}
// showToken pretty-prints the header and claims of the token named by -show.
//
// The token is decoded WITHOUT verifying its signature: showing a token needs
// no key, and parsing with a nil key function always reports an error, which
// made this mode fail for every input.
func showToken() error {
	// get the token
	tokData, err := loadData(*flagShow)
	if err != nil {
		return fmt.Errorf("couldn't read token: %w", err)
	}

	// trim possible whitespace from token
	tokData = regexp.MustCompile(`\s*$`).ReplaceAll(tokData, []byte{})
	if *flagDebug {
		fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(tokData))
	}

	// Decode only; the signature is deliberately not checked here.
	token, _, err := jwt.NewParser().ParseUnverified(string(tokData), jwt.MapClaims{})
	if err != nil {
		return fmt.Errorf("malformed token: %w", err)
	}

	// Print the token details
	fmt.Println("Header:")
	if err := printJSON(token.Header); err != nil {
		return fmt.Errorf("failed to output header: %w", err)
	}

	fmt.Println("Claims:")
	if err := printJSON(token.Claims); err != nil {
		return fmt.Errorf("failed to output claims: %w", err)
	}
	return nil
}
// isEs reports whether the -alg flag value starts with "ES".
func isEs() bool {
return strings.HasPrefix(*flagAlg, "ES")
}
// isRs reports whether the -alg flag value starts with "RS" or "PS".
// Both prefixes are handled with RSA keys by signToken and verifyToken.
func isRs() bool {
return strings.HasPrefix(*flagAlg, "RS") || strings.HasPrefix(*flagAlg, "PS")
}
// isEd reports whether the -alg flag value is exactly "EdDSA".
func isEd() bool {
return *flagAlg == "EdDSA"
}
// isNone reports whether the -alg flag value is exactly "none".
func isNone() bool {
return *flagAlg == "none"
}
// algHelp builds the help text for the -alg flag: a heading followed by the
// sorted algorithm names from jwt.GetAlgorithms, comma-separated with a line
// break after every seventh name.
func algHelp() string {
	names := jwt.GetAlgorithms()
	sort.Strings(names)

	var help strings.Builder
	help.WriteString("signing algorithm identifier, one of\n")
	for idx, name := range names {
		switch {
		case idx == 0:
			// No separator before the first name.
		case idx%7 == 0:
			help.WriteString(",\n")
		default:
			help.WriteString(", ")
		}
		help.WriteString(name)
	}
	return help.String()
}
// ArgList collects repeated "key=value" command-line arguments. Its String
// and Set methods make it usable with flag.Var.
type ArgList map[string]string

// String renders the collected arguments as a JSON object.
func (l ArgList) String() string {
	encoded, _ := json.Marshal(l)
	return string(encoded)
}

// Set parses arg as "key=value", splitting on the first "=", and stores the
// pair. An argument without "=" is rejected with an error.
func (l ArgList) Set(arg string) error {
	kv := strings.SplitN(arg, "=", 2)
	if len(kv) == 2 {
		l[kv[0]] = kv[1]
		return nil
	}
	return fmt.Errorf("invalid argument '%v'. Must use format 'key=value'. %v", arg, kv)
}
================================================
FILE: gossh/gin/jwt/ecdsa.go
================================================
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"errors"
"math/big"
)
var (
// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
// ErrECDSAVerification is returned by SigningMethodECDSA.Verify when the
// signature has the wrong length or does not verify.
ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
)
// SigningMethodECDSA implements the ECDSA family of signing methods.
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification
type SigningMethodECDSA struct {
// Name is the algorithm identifier returned by Alg.
Name string
// Hash is the hash function applied to the signing string.
Hash crypto.Hash
// KeySize is the byte length of each of r and s; Verify requires a
// signature of exactly 2*KeySize bytes.
KeySize int
// CurveBits is the curve bit size that Sign requires of the private key.
CurveBits int
}
// Specific instances for EC256 and company.
// They are populated and registered by this file's init function.
var (
SigningMethodES256 *SigningMethodECDSA
SigningMethodES384 *SigningMethodECDSA
SigningMethodES512 *SigningMethodECDSA
)
// init creates the ES256, ES384 and ES512 signing methods and registers each
// under its algorithm name.
func init() {
	// ES256
	SigningMethodES256 = &SigningMethodECDSA{
		Name:      "ES256",
		Hash:      crypto.SHA256,
		KeySize:   32,
		CurveBits: 256,
	}
	RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod { return SigningMethodES256 })

	// ES384
	SigningMethodES384 = &SigningMethodECDSA{
		Name:      "ES384",
		Hash:      crypto.SHA384,
		KeySize:   48,
		CurveBits: 384,
	}
	RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod { return SigningMethodES384 })

	// ES512
	SigningMethodES512 = &SigningMethodECDSA{
		Name:      "ES512",
		Hash:      crypto.SHA512,
		KeySize:   66,
		CurveBits: 521,
	}
	RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod { return SigningMethodES512 })
}
// Alg returns the method's algorithm identifier (its Name field).
func (m *SigningMethodECDSA) Alg() string {
return m.Name
}
// Verify implements token verification for the SigningMethod.
// key must be an *ecdsa.PublicKey and sig must be exactly 2*KeySize bytes:
// r followed by s, each big-endian. It returns nil when the signature is
// valid for signingString and ErrECDSAVerification when it is not.
func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key any) error {
	pub, ok := key.(*ecdsa.PublicKey)
	if !ok {
		return newError("ECDSA verify expects *ecdsa.PublicKey", ErrInvalidKeyType)
	}

	if len(sig) != 2*m.KeySize {
		return ErrECDSAVerification
	}

	// Split the signature into its two halves.
	r := new(big.Int).SetBytes(sig[:m.KeySize])
	s := new(big.Int).SetBytes(sig[m.KeySize:])

	if !m.Hash.Available() {
		return ErrHashUnavailable
	}
	hasher := m.Hash.New()
	hasher.Write([]byte(signingString))

	if !ecdsa.Verify(pub, hasher.Sum(nil), r, s) {
		return ErrECDSAVerification
	}
	return nil
}
// Sign implements token signing for the SigningMethod.
// key must be an *ecdsa.PrivateKey whose curve bit size equals m.CurveBits.
// The result is r followed by s, each left-padded with zeros to the curve's
// byte length.
func (m *SigningMethodECDSA) Sign(signingString string, key any) ([]byte, error) {
	priv, ok := key.(*ecdsa.PrivateKey)
	if !ok {
		return nil, newError("ECDSA sign expects *ecdsa.PrivateKey", ErrInvalidKeyType)
	}

	if !m.Hash.Available() {
		return nil, ErrHashUnavailable
	}
	hasher := m.Hash.New()
	hasher.Write([]byte(signingString))

	r, s, err := ecdsa.Sign(rand.Reader, priv, hasher.Sum(nil))
	if err != nil {
		return nil, err
	}

	curveBits := priv.Curve.Params().BitSize
	if m.CurveBits != curveBits {
		return nil, ErrInvalidKey
	}

	// Round the bit size up to whole bytes (e.g. 521 bits -> 66 bytes).
	keyBytes := (curveBits + 7) / 8

	// Serialize r and s as fixed-width big-endian integers, one per half.
	sig := make([]byte, 2*keyBytes)
	r.FillBytes(sig[:keyBytes])
	s.FillBytes(sig[keyBytes:])
	return sig, nil
}
================================================
FILE: gossh/gin/jwt/ecdsa_utils.go
================================================
package jwt
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"errors"
)
// Errors returned by the PEM helpers in this file when the decoded key is not
// an ECDSA key.
var (
// ErrNotECPublicKey is returned by ParseECPublicKeyFromPEM when the parsed
// key is not an *ecdsa.PublicKey.
ErrNotECPublicKey = errors.New("key is not a valid ECDSA public key")
// ErrNotECPrivateKey is returned by ParseECPrivateKeyFromPEM when the parsed
// key is not an *ecdsa.PrivateKey.
ErrNotECPrivateKey = errors.New("key is not a valid ECDSA private key")
)
// ParseECPrivateKeyFromPEM parses a PEM encoded Elliptic Curve private key.
//
// The first PEM block in key is decoded and its bytes are tried with
// x509.ParseECPrivateKey, then with x509.ParsePKCS8PrivateKey. It returns
// ErrKeyMustBePEMEncoded when no PEM block is found, the PKCS #8 parse error
// when both parsers fail, and ErrNotECPrivateKey when the PKCS #8 key is not
// an ECDSA key.
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	// Fast path: the block holds an EC private key directly.
	if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
		return ecKey, nil
	}

	// Fallback: a PKCS #8 container, which may hold any key type.
	generic, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	ecKey, isECDSA := generic.(*ecdsa.PrivateKey)
	if !isECDSA {
		return nil, ErrNotECPrivateKey
	}
	return ecKey, nil
}
// ParseECPublicKeyFromPEM parses a PEM encoded ECDSA public key.
//
// The first PEM block in key is decoded and its bytes are tried with
// x509.ParsePKIXPublicKey; if that fails they are parsed as an X.509
// certificate and the certificate's public key is used. It returns
// ErrKeyMustBePEMEncoded when no PEM block is found, the certificate parse
// error when both attempts fail, and ErrNotECPublicKey when the resulting key
// is not an *ecdsa.PublicKey.
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	candidate, pkixErr := x509.ParsePKIXPublicKey(block.Bytes)
	if pkixErr != nil {
		// Not a bare public key; try to pull it out of a certificate.
		cert, certErr := x509.ParseCertificate(block.Bytes)
		if certErr != nil {
			return nil, certErr
		}
		candidate = cert.PublicKey
	}

	if ecPub, isECDSA := candidate.(*ecdsa.PublicKey); isECDSA {
		return ecPub, nil
	}
	return nil, ErrNotECPublicKey
}
================================================
FILE: gossh/gin/jwt/ed25519.go
================================================
package jwt
import (
"crypto"
"crypto/ed25519"
"crypto/rand"
"errors"
)
var (
// ErrEd25519Verification is returned by SigningMethodEd25519.Verify when
// ed25519.Verify rejects the signature.
ErrEd25519Verification = errors.New("ed25519: verification error")
)
// SigningMethodEd25519 implements the EdDSA family.
// Sign accepts any crypto.Signer whose public key is an ed25519.PublicKey
// (ed25519.PrivateKey is one such signer); Verify expects an ed25519.PublicKey.
// The type carries no state.
type SigningMethodEd25519 struct{}
// Specific instance for EdDSA. It is assigned in init and registered under the
// name returned by its Alg method.
var (
SigningMethodEdDSA *SigningMethodEd25519
)
func init() {
SigningMethodEdDSA = &SigningMethodEd25519{}
RegisterSigningMethod(SigningMethodEdDSA.Alg(), func() SigningMethod {
return SigningMethodEdDSA
})
}
// Alg returns the fixed algorithm name "EdDSA" under which this method is
// registered.
func (m *SigningMethodEd25519) Alg() string {
return "EdDSA"
}
// Verify implements token verification for the SigningMethod.
//
// key must be an ed25519.PublicKey of length ed25519.PublicKeySize. It returns
// an error wrapping ErrInvalidKeyType for any other key type, ErrInvalidKey
// for a key of the wrong length, and ErrEd25519Verification when the
// signature does not match signingString.
func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key any) error {
	pub, isEd := key.(ed25519.PublicKey)
	switch {
	case !isEd:
		return newError("Ed25519 verify expects ed25519.PublicKey", ErrInvalidKeyType)
	case len(pub) != ed25519.PublicKeySize:
		return ErrInvalidKey
	case !ed25519.Verify(pub, []byte(signingString), sig):
		return ErrEd25519Verification
	}
	return nil
}
// Sign implements token signing for the SigningMethod.
//
// key must implement crypto.Signer and expose an ed25519.PublicKey through
// Public(); ed25519.PrivateKey satisfies both. It returns an error wrapping
// ErrInvalidKeyType when key is not a crypto.Signer, ErrInvalidKey when its
// public key is not Ed25519, or the error reported by the signer.
func (m *SigningMethodEd25519) Sign(signingString string, key any) ([]byte, error) {
	signer, isSigner := key.(crypto.Signer)
	if !isSigner {
		return nil, newError("Ed25519 sign expects crypto.Signer", ErrInvalidKeyType)
	}
	if _, isEd := signer.Public().(ed25519.PublicKey); !isEd {
		return nil, ErrInvalidKey
	}

	// Ed25519 hashes the message itself as part of the algorithm, so the raw
	// signing string is passed with crypto.Hash(0) to mark it as not
	// pre-hashed.
	signature, err := signer.Sign(rand.Reader, []byte(signingString), crypto.Hash(0))
	if err != nil {
		return nil, err
	}
	return signature, nil
}
================================================
FILE: gossh/gin/jwt/ed25519_utils.go
================================================
package jwt
import (
"crypto"
"crypto/ed25519"
"crypto/x509"
"encoding/pem"
"errors"
)
// Errors returned by the PEM helpers in this file when the decoded key is not
// an Ed25519 key.
var (
// ErrNotEdPrivateKey is returned by ParseEdPrivateKeyFromPEM when the parsed
// key is not an ed25519.PrivateKey.
ErrNotEdPrivateKey = errors.New("key is not a valid Ed25519 private key")
// ErrNotEdPublicKey is returned by ParseEdPublicKeyFromPEM when the parsed
// key is not an ed25519.PublicKey.
ErrNotEdPublicKey = errors.New("key is not a valid Ed25519 public key")
)
// ParseEdPrivateKeyFromPEM parses a PEM-encoded Edwards curve private key.
//
// The first PEM block in key is decoded and parsed as PKCS #8. The returned
// value holds an ed25519.PrivateKey. It returns ErrKeyMustBePEMEncoded when no
// PEM block is found, the PKCS #8 parse error on malformed input, and
// ErrNotEdPrivateKey when the key is of another type.
func ParseEdPrivateKeyFromPEM(key []byte) (crypto.PrivateKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}

	edKey, isEd := parsed.(ed25519.PrivateKey)
	if !isEd {
		return nil, ErrNotEdPrivateKey
	}
	return edKey, nil
}
// ParseEdPublicKeyFromPEM parses a PEM-encoded Edwards curve public key.
//
// The first PEM block in key is decoded and parsed as a PKIX public key. The
// returned value holds an ed25519.PublicKey. It returns ErrKeyMustBePEMEncoded
// when no PEM block is found, the PKIX parse error on malformed input, and
// ErrNotEdPublicKey when the key is of another type.
func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}

	edKey, isEd := parsed.(ed25519.PublicKey)
	if !isEd {
		return nil, ErrNotEdPublicKey
	}
	return edKey, nil
}
================================================
FILE: gossh/gin/jwt/errors.go
================================================
package jwt
import (
"errors"
"strings"
)
// Sentinel errors used throughout the package. They are typically wrapped by
// newError together with a detail message, so callers should match them with
// errors.Is rather than by direct comparison.
var (
ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
ErrTokenRequiredClaimMissing = errors.New("token is missing required claim")
ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
ErrTokenInvalidSubject = errors.New("token has invalid subject")
ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim")
)
// joinedError is an error type that works similar to what [errors.Join]
// produces, with the exception that it has a nice error string; mainly its
// error messages are concatenated using a comma, rather than a newline.
type joinedError struct {
// errs holds the combined errors in the order they were supplied.
errs []error
}
// Error returns the messages of all contained errors, in order, separated by
// ", ". With no contained errors it returns the empty string.
func (je joinedError) Error() string {
	parts := make([]string, len(je.errs))
	for i := range je.errs {
		parts[i] = je.errs[i].Error()
	}
	return strings.Join(parts, ", ")
}
// joinErrors combines several errors into a single *joinedError. It is useful
// where multiple errors occur side by side, e.g., in claims validation. The
// result is non-nil even when errs is empty.
func joinErrors(errs ...error) error {
	joined := new(joinedError)
	joined.errs = errs
	return joined
}
================================================
FILE: gossh/gin/jwt/errors_go1_20.go
================================================
//go:build go1.20
// +build go1.20
package jwt
import (
"fmt"
)
// Unwrap implements the multiple error unwrapping for this error type, which is
// possible in Go 1.20. It returns the stored slice itself, not a copy.
func (je joinedError) Unwrap() []error {
return je.errs
}
// newError builds an error whose text starts with err, followed by message
// (when non-empty) and then every error in more, all separated by ": ".
// err and the errors in more are wrapped with %w, relying on Go 1.20's support
// for several %w directives in one [fmt.Errorf] call, so each of them remains
// reachable through [errors.Is] and [errors.As].
//
// For example,
//
//	newError("no keyfunc was provided", ErrTokenUnverifiable)
//
// will produce the error string
//
//	"token is unverifiable: no keyfunc was provided"
func newError(message string, err error, more ...error) error {
	// Start from the primary error and grow the layout piece by piece.
	layout := "%w"
	operands := make([]any, 0, 2+len(more))
	operands = append(operands, err)

	if message != "" {
		layout += ": %s"
		operands = append(operands, message)
	}

	for _, extra := range more {
		layout += ": %w"
		operands = append(operands, extra)
	}

	return fmt.Errorf(layout, operands...)
}
================================================
FILE: gossh/gin/jwt/errors_go_other.go
================================================
//go:build !go1.20
// +build !go1.20
package jwt
import (
"errors"
"fmt"
)
// Is reports whether any contained error matches err according to
// [errors.Is]. It stands in for multiple-error unwrapping, which is not
// available in versions less than Go 1.20.
func (je joinedError) Is(err error) bool {
	for i := 0; i < len(je.errs); i++ {
		if errors.Is(je.errs[i], err) {
			return true
		}
	}
	return false
}
// wrappedErrors is a workaround for wrapping multiple errors in environments
// where Go 1.20 is not available. It reuses the already implemented
// functionality of joinedError to handle multiple errors, while supplying a
// custom error message that is identical to the one we produce in Go 1.20 using
// multiple %w directives.
type wrappedErrors struct {
// msg is the preformatted error string returned by Error.
msg string
// joinedError carries the wrapped errors and contributes its Is method.
joinedError
}
// Error returns the stored error string. It overrides the Error method
// promoted from the embedded joinedError.
func (we wrappedErrors) Error() string {
return we.msg
}
// newError builds an error whose text starts with err, followed by message
// (when non-empty) and then every error in more, all separated by ": ".
// Multiple %w directives in [fmt.Errorf] are unavailable before Go 1.20, so
// the text is rendered with %s and stored in a wrappedErrors value that keeps
// err and the errors in more for matching.
//
// For example,
//
//	newError("no keyfunc was provided", ErrTokenUnverifiable)
//
// will produce the error string
//
//	"token is unverifiable: no keyfunc was provided"
func newError(message string, err error, more ...error) error {
	// The layout mirrors the Go 1.20 variant, with %s in place of %w.
	layout := "%s"
	operands := make([]any, 0, 2+len(more))
	operands = append(operands, err)
	if message != "" {
		layout += ": %s"
		operands = append(operands, message)
	}

	// Track every wrapped error so the result can be matched against them.
	wrapped := make([]error, 0, 1+len(more))
	wrapped = append(wrapped, err)
	for _, extra := range more {
		layout += ": %s"
		operands = append(operands, extra)
		wrapped = append(wrapped, extra)
	}

	return &wrappedErrors{
		msg:         fmt.Sprintf(layout, operands...),
		joinedError: joinedError{errs: wrapped},
	}
}
================================================
FILE: gossh/gin/jwt/hmac.go
================================================
package jwt
import (
"crypto"
"crypto/hmac"
"errors"
)
// SigningMethodHMAC implements the HMAC-SHA family of signing methods.
// Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct {
// Name is the algorithm name returned by Alg (e.g. "HS256").
Name string
// Hash is the hash function the HMAC is built on.
Hash crypto.Hash
}
// Specific instances for HS256 and company. The three method variables are
// assigned and registered in init.
var (
SigningMethodHS256 *SigningMethodHMAC
SigningMethodHS384 *SigningMethodHMAC
SigningMethodHS512 *SigningMethodHMAC
// ErrSignatureInvalid is returned by SigningMethodHMAC.Verify when the
// recomputed MAC does not equal the supplied signature.
ErrSignatureInvalid = errors.New("signature is invalid")
)
// init builds the three HMAC signing methods (HS256, HS384, HS512) and
// registers each under its algorithm name.
func init() {
	SigningMethodHS256 = &SigningMethodHMAC{Name: "HS256", Hash: crypto.SHA256}
	SigningMethodHS384 = &SigningMethodHMAC{Name: "HS384", Hash: crypto.SHA384}
	SigningMethodHS512 = &SigningMethodHMAC{Name: "HS512", Hash: crypto.SHA512}

	// Each factory reads the package-level variable at call time.
	RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod { return SigningMethodHS256 })
	RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod { return SigningMethodHS384 })
	RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod { return SigningMethodHS512 })
}
// Alg returns the algorithm name stored in m.Name. The same value is used as
// the registration key for this signing method.
func (m *SigningMethodHMAC) Alg() string {
return m.Name
}
// Verify implements token verification for the SigningMethod. Returns nil if
// the signature is valid. Key must be []byte.
//
// It returns an error wrapping ErrInvalidKeyType for any other key type,
// ErrHashUnavailable when m.Hash is not available, and ErrSignatureInvalid
// when sig does not match.
//
// Note it is not advised to provide a []byte which was converted from a 'human
// readable' string using a subset of ASCII characters. To maximize entropy, you
// should ideally be providing a []byte key which was produced from a
// cryptographically random source, e.g. crypto/rand. Additional information
// about this, and why we intentionally are not supporting string as a key can
// be found on our usage guide
// https://golang-jwt.github.io/jwt/usage/signing_methods/#signing-methods-and-key-types.
func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key any) error {
	secret, isBytes := key.([]byte)
	switch {
	case !isBytes:
		return newError("HMAC verify expects []byte", ErrInvalidKeyType)
	case !m.Hash.Available():
		return ErrHashUnavailable
	}

	// HMAC is symmetric: recompute the MAC over the signing string and
	// compare it with the supplied signature using hmac.Equal.
	mac := hmac.New(m.Hash.New, secret)
	mac.Write([]byte(signingString))
	expected := mac.Sum(nil)

	if hmac.Equal(sig, expected) {
		return nil
	}
	return ErrSignatureInvalid
}
// Sign implements token signing for the SigningMethod. Key must be []byte.
//
// It returns the HMAC of signingString, an error wrapping ErrInvalidKeyType
// for any other key type, or ErrHashUnavailable when m.Hash is not available.
//
// Note it is not advised to provide a []byte which was converted from a 'human
// readable' string using a subset of ASCII characters. To maximize entropy, you
// should ideally be providing a []byte key which was produced from a
// cryptographically random source, e.g. crypto/rand. Additional information
// about this, and why we intentionally are not supporting string as a key can
// be found on our usage guide https://golang-jwt.github.io/jwt/usage/signing_methods/.
func (m *SigningMethodHMAC) Sign(signingString string, key any) ([]byte, error) {
	secret, isBytes := key.([]byte)
	if !isBytes {
		return nil, newError("HMAC sign expects []byte", ErrInvalidKeyType)
	}
	if !m.Hash.Available() {
		return nil, ErrHashUnavailable
	}

	mac := hmac.New(m.Hash.New, secret)
	mac.Write([]byte(signingString))
	return mac.Sum(nil), nil
}
================================================
FILE: gossh/gin/jwt/map_claims.go
================================================
package jwt
import (
"encoding/json"
"fmt"
)
// MapClaims is a claims type that uses the map[string]any for JSON
// decoding. This is the default claims type if you don't supply one
// (Parser.Parse passes an empty MapClaims).
type MapClaims map[string]any
// GetExpirationTime implements the Claims interface by reading the "exp" key
// through parseNumericDate; a missing key yields (nil, nil).
func (m MapClaims) GetExpirationTime() (*NumericDate, error) {
return m.parseNumericDate("exp")
}
// GetNotBefore implements the Claims interface by reading the "nbf" key
// through parseNumericDate; a missing key yields (nil, nil).
func (m MapClaims) GetNotBefore() (*NumericDate, error) {
return m.parseNumericDate("nbf")
}
// GetIssuedAt implements the Claims interface by reading the "iat" key
// through parseNumericDate; a missing key yields (nil, nil).
func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
return m.parseNumericDate("iat")
}
// GetAudience implements the Claims interface by reading the "aud" key
// through parseClaimsString; a missing key yields (nil, nil).
func (m MapClaims) GetAudience() (ClaimStrings, error) {
return m.parseClaimsString("aud")
}
// GetIssuer implements the Claims interface by reading the "iss" key through
// parseString; a missing key yields ("", nil).
func (m MapClaims) GetIssuer() (string, error) {
return m.parseString("iss")
}
// GetSubject implements the Claims interface by reading the "sub" key through
// parseString; a missing key yields ("", nil).
func (m MapClaims) GetSubject() (string, error) {
return m.parseString("sub")
}
// parseNumericDate tries to parse a key in the map claims type as a numeric
// date. This succeeds if the underlying value is either a [float64] or a
// [json.Number] that converts to a float64.
//
// It returns (nil, nil) when the key is absent or holds the float64 value 0.
// It returns an error wrapping ErrInvalidType when the value has any other
// type, or is a json.Number that cannot be converted (for example a literal
// outside the float64 range).
func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
	v, ok := m[key]
	if !ok {
		return nil, nil
	}

	switch exp := v.(type) {
	case float64:
		if exp == 0 {
			return nil, nil
		}
		return newNumericDateFromSeconds(exp), nil
	case json.Number:
		// Float64 can fail; surface that instead of silently turning the
		// claim into a bogus timestamp built from the fallback value.
		seconds, err := exp.Float64()
		if err != nil {
			return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType, err)
		}
		return newNumericDateFromSeconds(seconds), nil
	}

	return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
}
// parseClaimsString reads key from the map as a [ClaimStrings] value. A string
// becomes a one-element list, a []string is returned as is, and a []any is
// accepted only if every element is a string (otherwise an error wrapping
// ErrInvalidType is returned). A missing key or a value of any other type
// yields (nil, nil).
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
	switch raw := m[key].(type) {
	case string:
		return []string{raw}, nil
	case []string:
		return raw, nil
	case []any:
		// Left nil on purpose so an empty array still yields a nil result.
		var collected []string
		for i := range raw {
			item, isString := raw[i].(string)
			if !isString {
				return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
			}
			collected = append(collected, item)
		}
		return collected, nil
	}
	return nil, nil
}
// parseString reads key from the map as a [string]. A missing key yields the
// empty string with no error; a value of any other type yields an error
// wrapping ErrInvalidType.
func (m MapClaims) parseString(key string) (string, error) {
	raw, present := m[key]
	if !present {
		return "", nil
	}

	if text, isString := raw.(string); isString {
		return text, nil
	}
	return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
}
================================================
FILE: gossh/gin/jwt/none.go
================================================
package jwt
// SigningMethodNone implements the none signing method. This is required by the spec
// but you probably should never use it. It is assigned and registered in init.
var SigningMethodNone *signingMethodNone
// UnsafeAllowNoneSignatureType is the key callers must pass to explicitly opt
// in to the "none" method; its Sign and Verify reject keys of any other type.
const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
// NoneSignatureTypeDisallowedError is returned by the "none" method when the
// supplied key is not of type unsafeNoneMagicConstant. It is built in init.
var NoneSignatureTypeDisallowedError error
// signingMethodNone is the stateless implementation behind SigningMethodNone.
type signingMethodNone struct{}
// unsafeNoneMagicConstant is the dedicated key type of
// UnsafeAllowNoneSignatureType.
type unsafeNoneMagicConstant string
func init() {
SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
return SigningMethodNone
})
}
// Alg returns the fixed algorithm name "none" under which this method is
// registered.
func (m *signingMethodNone) Alg() string {
return "none"
}
// Verify accepts an unsigned token only when the caller opted in by passing a
// key of type unsafeNoneMagicConstant (UnsafeAllowNoneSignatureType) and the
// signature is empty. Otherwise it returns NoneSignatureTypeDisallowedError,
// or an error wrapping ErrTokenUnverifiable for a non-empty signature.
func (m *signingMethodNone) Verify(signingString string, sig []byte, key any) (err error) {
	// The magic key type guards against accepting 'none' by accident.
	_, optedIn := key.(unsafeNoneMagicConstant)

	switch {
	case !optedIn:
		return NoneSignatureTypeDisallowedError
	case len(sig) != 0:
		// With the none method the signature must be an empty string.
		return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
	default:
		return nil
	}
}
// Sign returns an empty, non-nil signature when the caller opted in by passing
// a key of type unsafeNoneMagicConstant (UnsafeAllowNoneSignatureType);
// otherwise it returns NoneSignatureTypeDisallowedError.
func (m *signingMethodNone) Sign(signingString string, key any) ([]byte, error) {
	_, optedIn := key.(unsafeNoneMagicConstant)
	if !optedIn {
		return nil, NoneSignatureTypeDisallowedError
	}
	return []byte{}, nil
}
================================================
FILE: gossh/gin/jwt/parser.go
================================================
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
// Parser holds the configuration used when parsing and validating tokens.
// NewParser creates one and applies ParserOption values to it.
type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string
// Use JSON Number format in JSON decoder.
useJSONNumber bool
// Skip claims validation during token parsing.
skipClaimsValidation bool
// validator performs the claims validation in ParseWithClaims.
validator *Validator
// decodeStrict makes DecodeSegment use the strict variant of the encoding.
decodeStrict bool
// decodePaddingAllowed makes DecodeSegment pad the segment and decode it
// with the padded URL encoding.
decodePaddingAllowed bool
}
// NewParser creates a new Parser with an empty Validator and then applies the
// given options to it in order.
func NewParser(options ...ParserOption) *Parser {
	parser := new(Parser)
	parser.validator = new(Validator)

	for i := range options {
		options[i](parser)
	}
	return parser
}
// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
// It delegates to ParseWithClaims with an empty MapClaims as the claims target.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
// than the default MapClaims implementation of Claims.
//
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
//
// The steps run in a fixed order and the first failure is returned together
// with the token as parsed so far (which may be nil):
//  1. ParseUnverified decodes header and claims and resolves the signing method.
//  2. If validMethods is set, the token's alg must be listed (ErrTokenSignatureInvalid).
//  3. The signature segment is decoded (ErrTokenMalformed).
//  4. keyFunc supplies a key or a VerificationKeySet (ErrTokenUnverifiable when
//     keyFunc is nil, fails, or returns an empty set).
//  5. The signature is verified (ErrTokenSignatureInvalid).
//  6. Unless skipClaimsValidation is set, the claims are validated (ErrTokenInvalidClaims).
//
// token.Valid is set to true only when every step succeeds.
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
}
// Verify signing method is in the required set
if p.validMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.validMethods {
if m == alg {
signingMethodValid = true
break
}
}
if !signingMethodValid {
// signing method is not in the listed set
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
}
}
// Decode signature
token.Signature, err = p.DecodeSegment(parts[2])
if err != nil {
return token, newError("could not base64 decode signature", ErrTokenMalformed, err)
}
// The signed text is the header and claims segments exactly as received.
text := strings.Join(parts[0:2], ".")
// Lookup key(s)
if keyFunc == nil {
// keyFunc was not provided. short circuiting validation
return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
}
got, err := keyFunc(token)
if err != nil {
return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
}
switch have := got.(type) {
case VerificationKeySet:
if len(have.Keys) == 0 {
return token, newError("keyfunc returned empty verification key set", ErrTokenUnverifiable)
}
// Iterate through keys and verify signature, skipping the rest when a match is found.
// Return the last error if no match is found.
for _, key := range have.Keys {
if err = token.Method.Verify(text, token.Signature, key); err == nil {
break
}
}
default:
// Anything else is handed to the signing method as a single key.
err = token.Method.Verify(text, token.Signature, have)
}
if err != nil {
return token, newError("", ErrTokenSignatureInvalid, err)
}
// Validate Claims
if !p.skipClaimsValidation {
// Make sure we have at least a default validator
// NOTE(review): this writes to p; a Parser built without NewParser and
// shared between goroutines would race here — confirm.
if p.validator == nil {
p.validator = NewValidator()
}
if err := p.validator.Validate(claims); err != nil {
return token, newError("", ErrTokenInvalidClaims, err)
}
}
// No errors so far, token is valid.
token.Valid = true
return token, nil
}
// ParseUnverified parses the token but doesn't validate the signature.
//
// WARNING: Don't use this method unless you know what you're doing.
//
// It's only ever useful in cases where you know the signature is valid (since it has already
// been or will be checked elsewhere in the stack) and you want to extract values from it.
//
// It returns the token, the three dot-separated segments of tokenString, and
// an error. The token is nil when tokenString does not have exactly three
// segments; on later failures it is returned partially populated. Failures to
// decode the header or claims wrap ErrTokenMalformed; a missing or unknown
// "alg" header wraps ErrTokenUnverifiable. The claims are decoded into the
// supplied claims value.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
}
token = &Token{Raw: tokenString}
// parse Header
var headerBytes []byte
if headerBytes, err = p.DecodeSegment(parts[0]); err != nil {
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
}
// parse Claims
token.Claims = claims
claimBytes, err := p.DecodeSegment(parts[1])
if err != nil {
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}
// If `useJSONNumber` is enabled then we must use *json.Decoder to decode
// the claims. However, this comes with a performance penalty so only use
// it if we must and, otherwise, simply use json.Unmarshal.
if !p.useJSONNumber {
// JSON Unmarshal. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = json.Unmarshal(claimBytes, &c)
} else {
err = json.Unmarshal(claimBytes, &claims)
}
} else {
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
dec.UseNumber()
// JSON Decode. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
}
if err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
}
// Lookup signature method
// The "alg" header must be a string naming a registered signing method.
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, newError("signing method (alg) is unavailable", ErrTokenUnverifiable)
}
} else {
return token, parts, newError("signing method (alg) is unspecified", ErrTokenUnverifiable)
}
return token, parts, nil
}
// DecodeSegment decodes a JWT specific base64url encoding. It honours the
// [Parser] options [WithPaddingAllowed] (the segment is padded with '=' to a
// multiple of four characters and decoded with the padded URL encoding) and
// [WithStrictDecoding] (the strict variant of the chosen encoding is used).
// Without either option the unpadded URL encoding is used.
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
	codec := base64.RawURLEncoding

	if p.decodePaddingAllowed {
		codec = base64.URLEncoding
		// Top up with '=' so the padded codec accepts unpadded input too.
		if rem := len(seg) % 4; rem != 0 {
			seg += strings.Repeat("=", 4-rem)
		}
	}

	if p.decodeStrict {
		codec = codec.Strict()
	}

	return codec.DecodeString(seg)
}
// Parse parses, validates, verifies the signature and returns the parsed token.
// It is a shortcut for NewParser(options...).Parse(tokenString, keyFunc).
// keyFunc will receive the parsed token and should return the cryptographic key
// for verifying the signature. The caller is strongly encouraged to set the
// WithValidMethods option to validate the 'alg' claim in the token matches the
// expected algorithm. For more details about the importance of validating the
// 'alg' claim, see
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
	parser := NewParser(options...)
	return parser.Parse(tokenString, keyFunc)
}
// ParseWithClaims is a shortcut for NewParser(options...).ParseWithClaims().
//
// Note: If you provide a custom claim implementation that embeds one of the
// standard claims (such as RegisteredClaims), make sure that a) you either
// embed a non-pointer version of the claims or b) if you are using a pointer,
// allocate the proper memory for it before passing in the overall claims,
// otherwise you might run into a panic.
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
	parser := NewParser(options...)
	return parser.ParseWithClaims(tokenString, claims, keyFunc)
}
================================================
FILE: gossh/gin/jwt/parser_option.go
================================================
package jwt
import "time"
// ParserOption is used to implement functional-style options that modify the
// behavior of the parser. To add new options, just create a function (ideally
// beginning with With or Without) that returns an anonymous function that takes
// a *Parser type as input and manipulates its configuration accordingly.
// NewParser applies the options in the order they are given.
type ParserOption func(*Parser)
// WithValidMethods is an option to supply algorithm methods that the parser
// will check. Only those methods will be considered valid. It is heavily
// encouraged to use this option in order to prevent attacks such as
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/.
// The slice is stored as given, not copied.
func WithValidMethods(methods []string) ParserOption {
	apply := func(p *Parser) {
		p.validMethods = methods
	}
	return apply
}
// WithJSONNumber is an option that makes the parser decode claims with a JSON
// decoder configured with UseNumber.
func WithJSONNumber() ParserOption {
	apply := func(p *Parser) {
		p.useJSONNumber = true
	}
	return apply
}
// WithoutClaimsValidation is an option to disable claims validation. This
// option should only be used if you exactly know what you are doing.
func WithoutClaimsValidation() ParserOption {
	apply := func(p *Parser) {
		p.skipClaimsValidation = true
	}
	return apply
}
// WithLeeway returns the ParserOption that sets the validator's leeway window
// to the given duration.
func WithLeeway(leeway time.Duration) ParserOption {
	apply := func(p *Parser) {
		p.validator.leeway = leeway
	}
	return apply
}
// WithTimeFunc returns the ParserOption that sets the validator's time
// function. The primary use-case for this is testing. If you are looking for a
// way to account for clock-skew, WithLeeway should be used instead.
func WithTimeFunc(f func() time.Time) ParserOption {
	apply := func(p *Parser) {
		p.validator.timeFunc = f
	}
	return apply
}
// WithIssuedAt returns the ParserOption that enables verification of the
// issued-at claim on the parser's validator.
func WithIssuedAt() ParserOption {
	apply := func(p *Parser) {
		p.validator.verifyIat = true
	}
	return apply
}
// WithExpirationRequired returns the ParserOption that makes the exp claim
// required. By default the exp claim is optional.
func WithExpirationRequired() ParserOption {
	apply := func(p *Parser) {
		p.validator.requireExp = true
	}
	return apply
}
// WithAudience configures the validator to require the specified audience in
// the `aud` claim. Validation will fail if the audience is not listed in the
// token or the `aud` claim is missing.
//
// NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim,
// if an audience is expected.
func WithAudience(aud string) ParserOption {
	apply := func(p *Parser) {
		p.validator.expectedAud = aud
	}
	return apply
}
// WithIssuer configures the validator to require the specified issuer in the
// `iss` claim. Validation will fail if a different issuer is specified in the
// token or the `iss` claim is missing.
//
// NOTE: While the `iss` claim is OPTIONAL in a JWT, the handling of it is
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim,
// if an issuer is expected.
func WithIssuer(iss string) ParserOption {
	apply := func(p *Parser) {
		p.validator.expectedIss = iss
	}
	return apply
}
// WithSubject configures the validator to require the specified subject in the
// `sub` claim. Validation will fail if a different subject is specified in the
// token or the `sub` claim is missing.
//
// NOTE: While the `sub` claim is OPTIONAL in a JWT, the handling of it is
// application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim,
// if a subject is expected.
func WithSubject(sub string) ParserOption {
	apply := func(p *Parser) {
		p.validator.expectedSub = sub
	}
	return apply
}
// WithPaddingAllowed will enable the codec used for decoding JWTs to allow
// padding. Note that the JWS RFC7515 states that the tokens will utilize a
// Base64url encoding with no padding. Unfortunately, some implementations of
// JWT are producing non-standard tokens, and thus require support for decoding.
func WithPaddingAllowed() ParserOption {
	apply := func(p *Parser) {
		p.decodePaddingAllowed = true
	}
	return apply
}
// WithStrictDecoding will switch the codec used for decoding JWTs into strict
// mode. In this mode, the decoder requires that trailing padding bits are zero,
// as described in RFC 4648 section 3.5.
func WithStrictDecoding() ParserOption {
	apply := func(p *Parser) {
		p.decodeStrict = true
	}
	return apply
}
================================================
FILE: gossh/gin/jwt/registered_claims.go
================================================
package jwt
// RegisteredClaims are a structured version of the JWT Claims Set,
// restricted to Registered Claim Names, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
//
// This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical use-case
// therefore is to embed this in a user-defined claim type.
//
// Every field is tagged omitempty, so unset claims are left out when the
// struct is marshalled to JSON.
//
// See examples for how to use this with your own claim types.
type RegisteredClaims struct {
// the `iss` (Issuer) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1
Issuer string `json:"iss,omitempty"`
// the `sub` (Subject) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2
Subject string `json:"sub,omitempty"`
// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"`
// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"`
// the `nbf` (Not Before) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5
NotBefore *NumericDate `json:"nbf,omitempty"`
// the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
IssuedAt *NumericDate `json:"iat,omitempty"`
// the `jti` (JWT ID) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7
ID string `json:"jti,omitempty"`
}
// GetExpirationTime implements the Claims interface.
func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c RegisteredClaims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c RegisteredClaims) GetSubject() (string, error) {
return c.Subject, nil
}
================================================
FILE: gossh/gin/jwt/request/doc.go
================================================
// Utility package for extracting JWT tokens from
// HTTP requests.
//
// The main function is ParseFromRequest and its WithClaims variant.
// See examples for how to use the various Extractor implementations
// or roll your own.
package request
================================================
FILE: gossh/gin/jwt/request/extractor.go
================================================
package request
import (
"errors"
"net/http"
"strings"
)
// Errors returned by the extractors in this package.
var (
// ErrNoTokenInRequest is the sentinel an Extractor returns when the request
// carries no token in the place the extractor looks.
ErrNoTokenInRequest = errors.New("no token present in request")
)
// Extractor is an interface for extracting a token from an HTTP request.
// The ExtractToken method should return a token string or an error.
// If no token is present, implementations must return ErrNoTokenInRequest so
// that callers such as MultiExtractor can tell "absent" apart from a failure.
type Extractor interface {
ExtractToken(*http.Request) (string, error)
}
// HeaderExtractor is an Extractor that looks for a token in request headers.
// The listed header names are consulted in order and the first non-empty
// value wins.
type HeaderExtractor []string

// ExtractToken returns the value of the first listed header that is not
// empty, or ErrNoTokenInRequest when none of them carries a value.
func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) {
	for i := range e {
		value := req.Header.Get(e[i])
		if value == "" {
			continue
		}
		return value, nil
	}
	return "", ErrNoTokenInRequest
}
// ArgumentExtractor is an Extractor that reads a token from request
// arguments, i.e. a POSTed form or GET URL parameters. Argument names are
// tried in order and the first non-empty value wins.
// This extractor calls `ParseMultipartForm` on the request.
type ArgumentExtractor []string

// ExtractToken parses the request form and returns the first non-empty value
// among the listed argument names, or ErrNoTokenInRequest if there is none.
func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) {
	// Parsing is best effort: the error is deliberately discarded and whatever
	// ended up in req.Form is used for the lookup below.
	_ = req.ParseMultipartForm(10e6)

	for i := range e {
		value := req.Form.Get(e[i])
		if value == "" {
			continue
		}
		return value, nil
	}
	return "", ErrNoTokenInRequest
}
// MultiExtractor chains several Extractors. They are tried in order until one
// yields a token or fails with an error other than ErrNoTokenInRequest.
type MultiExtractor []Extractor

// ExtractToken returns the first non-empty token produced by the wrapped
// extractors. An extractor reporting ErrNoTokenInRequest is skipped; any other
// error value accompanying an empty token is returned immediately. If every
// extractor is skipped, ErrNoTokenInRequest is returned.
func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) {
	for _, candidate := range e {
		tok, err := candidate.ExtractToken(req)
		switch {
		case tok != "":
			return tok, nil
		case !errors.Is(err, ErrNoTokenInRequest):
			return "", err
		}
	}
	return "", ErrNoTokenInRequest
}
// PostExtractionFilter decorates an Extractor with a Filter function that
// post-processes the extracted value before it is handed to the caller.
// See AuthorizationHeaderExtractor for an example.
type PostExtractionFilter struct {
	Extractor
	Filter func(string) (string, error)
}

// ExtractToken runs the embedded Extractor and, when it yields a non-empty
// token, returns the result of passing that token through Filter. For an empty
// token the extractor's error is returned as-is.
func (e *PostExtractionFilter) ExtractToken(req *http.Request) (string, error) {
	tok, err := e.Extractor.ExtractToken(req)
	if tok == "" {
		return "", err
	}
	return e.Filter(tok)
}
// BearerExtractor is an Extractor that reads the token from the Authorization
// header, which must have the form "Bearer XX" where "XX" is the JWT.
type BearerExtractor struct{}

// ExtractToken returns everything after the "Bearer " scheme prefix of the
// Authorization header, or ErrNoTokenInRequest when the header is missing or
// does not start with that prefix.
func (e BearerExtractor) ExtractToken(req *http.Request) (string, error) {
	const scheme = "bearer "

	auth := req.Header.Get("Authorization")
	// "Bearer" is conventionally title-cased, but nothing mandates it, so the
	// scheme is compared case-insensitively (robustness principle).
	if len(auth) >= len(scheme) && strings.EqualFold(auth[:len(scheme)], scheme) {
		return auth[len(scheme):], nil
	}
	return "", ErrNoTokenInRequest
}
================================================
FILE: gossh/gin/jwt/request/oauth2.go
================================================
package request
import (
"strings"
)
// stripBearerPrefixFromTokenString removes a leading "Bearer " scheme prefix
// (matched case-insensitively) from tok. Input without that prefix is
// returned unchanged. The returned error is always nil; it exists so the
// function can be used as a PostExtractionFilter.Filter.
func stripBearerPrefixFromTokenString(tok string) (string, error) {
	const prefix = "bearer "

	if len(tok) < len(prefix) {
		return tok, nil
	}
	head, rest := tok[:len(prefix)], tok[len(prefix):]
	if !strings.EqualFold(head, prefix) {
		return tok, nil
	}
	return rest, nil
}
// AuthorizationHeaderExtractor extracts a bearer token from the Authorization
// header. It is a PostExtractionFilter that reads the header and then strips
// a leading "Bearer " prefix from the value.
var AuthorizationHeaderExtractor = &PostExtractionFilter{
	Extractor: HeaderExtractor{"Authorization"},
	Filter:    stripBearerPrefixFromTokenString,
}
// OAuth2Extractor is an Extractor for OAuth2 access tokens. It first consults
// the 'Authorization' header (via AuthorizationHeaderExtractor) and then falls
// back to the 'access_token' request argument.
var OAuth2Extractor = &MultiExtractor{
AuthorizationHeaderExtractor,
ArgumentExtractor{"access_token"},
}
================================================
FILE: gossh/gin/jwt/request/request.go
================================================
package request
import (
"net/http"
"gossh/gin/jwt"
)
// ParseFromRequest pulls a JWT out of an HTTP request and parses it.
// It works like Parse, except that the token string is obtained by running
// extractor against req. The Extractor interface lets callers decide where the
// token lives; several ready-made implementations are provided.
//
// Options can replace the claims destination (default: jwt.MapClaims) and the
// parser (default: a zero-value jwt.Parser). An extraction error is returned
// unchanged with a nil token.
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...ParseFromRequestOption) (token *jwt.Token, err error) {
	cfg := &fromRequestParser{req: req, extractor: extractor}

	// Let the caller's options override claims and parser.
	for i := range options {
		options[i](cfg)
	}

	// Fill in whatever the options left unset.
	if cfg.claims == nil {
		cfg.claims = jwt.MapClaims{}
	}
	if cfg.parser == nil {
		cfg.parser = &jwt.Parser{}
	}

	raw, extractErr := cfg.extractor.ExtractToken(req)
	if extractErr != nil {
		return nil, extractErr
	}

	return cfg.parser.ParseWithClaims(raw, cfg.claims, keyFunc)
}
// ParseFromRequestWithClaims behaves like ParseFromRequest but decodes into
// the supplied custom Claims value.
//
// Deprecated: use ParseFromRequest and the WithClaims option
func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
	claimsOption := WithClaims(claims)
	return ParseFromRequest(req, extractor, keyFunc, claimsOption)
}
// fromRequestParser bundles the state ParseFromRequest works with: the
// request, the extractor, and the claims/parser that options may override.
type fromRequestParser struct {
// req is the HTTP request the token is extracted from.
req *http.Request
// extractor locates the token string inside req.
extractor Extractor
// claims is the destination the token's claims are parsed into.
claims jwt.Claims
// parser performs the actual parsing of the extracted token string.
parser *jwt.Parser
}
// ParseFromRequestOption customizes a ParseFromRequest call by mutating its
// internal fromRequestParser before extraction and parsing take place.
type ParseFromRequestOption func(*fromRequestParser)
// WithClaims returns an option that makes ParseFromRequest decode the token
// into the given custom claims value.
func WithClaims(claims jwt.Claims) ParseFromRequestOption {
	setClaims := func(cfg *fromRequestParser) {
		cfg.claims = claims
	}
	return setClaims
}
// WithParser returns an option that makes ParseFromRequest use the given
// parser instead of the default one.
func WithParser(parser *jwt.Parser) ParseFromRequestOption {
	setParser := func(cfg *fromRequestParser) {
		cfg.parser = parser
	}
	return setParser
}
================================================
FILE: gossh/gin/jwt/rsa.go
================================================
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// SigningMethodRSA implements the RSA family of signing methods
// (PKCS #1 v1.5 signatures). It expects *rsa.PrivateKey for signing and
// *rsa.PublicKey for validation.
type SigningMethodRSA struct {
// Name is the "alg" identifier reported by Alg.
Name string
// Hash is the hash function applied to the signing string.
Hash crypto.Hash
}
// Specific instances for RS256 and company. They are populated and registered
// by this file's init function.
var (
SigningMethodRS256 *SigningMethodRSA
SigningMethodRS384 *SigningMethodRSA
SigningMethodRS512 *SigningMethodRSA
)
// init creates the RS256, RS384 and RS512 signing methods and registers a
// factory for each of them under its "alg" name.
func init() {
	// RS256
	SigningMethodRS256 = &SigningMethodRSA{Name: "RS256", Hash: crypto.SHA256}
	RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod { return SigningMethodRS256 })

	// RS384
	SigningMethodRS384 = &SigningMethodRSA{Name: "RS384", Hash: crypto.SHA384}
	RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod { return SigningMethodRS384 })

	// RS512
	SigningMethodRS512 = &SigningMethodRSA{Name: "RS512", Hash: crypto.SHA512}
	RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod { return SigningMethodRS512 })
}
// Alg returns the "alg" identifier of this signing method, i.e. its Name.
func (m *SigningMethodRSA) Alg() string {
	name := m.Name
	return name
}
// Verify checks sig against signingString using RSASSA-PKCS1-v1_5.
// key must be an *rsa.PublicKey; any other type yields an ErrInvalidKeyType
// based error. ErrHashUnavailable is returned when the method's hash function
// is not available. A nil return means the signature is valid.
func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key any) error {
	var pub *rsa.PublicKey
	switch k := key.(type) {
	case *rsa.PublicKey:
		pub = k
	default:
		return newError("RSA verify expects *rsa.PublicKey", ErrInvalidKeyType)
	}

	if !m.Hash.Available() {
		return ErrHashUnavailable
	}

	// Hash the signing string; the digest is what was actually signed.
	h := m.Hash.New()
	h.Write([]byte(signingString))
	digest := h.Sum(nil)

	return rsa.VerifyPKCS1v15(pub, m.Hash, digest, sig)
}
// Sign produces an RSASSA-PKCS1-v1_5 signature over signingString.
// key must be an *rsa.PrivateKey; any other type yields an ErrInvalidKeyType
// based error. ErrHashUnavailable is returned when the method's hash function
// is not available.
func (m *SigningMethodRSA) Sign(signingString string, key any) ([]byte, error) {
	var priv *rsa.PrivateKey
	switch k := key.(type) {
	case *rsa.PrivateKey:
		priv = k
	default:
		return nil, newError("RSA sign expects *rsa.PrivateKey", ErrInvalidKeyType)
	}

	if !m.Hash.Available() {
		return nil, ErrHashUnavailable
	}

	// Hash the signing string; the digest is the value that gets signed.
	h := m.Hash.New()
	h.Write([]byte(signingString))
	digest := h.Sum(nil)

	signature, err := rsa.SignPKCS1v15(rand.Reader, priv, m.Hash, digest)
	if err != nil {
		return nil, err
	}
	return signature, nil
}
================================================
FILE: gossh/gin/jwt/rsa_pss.go
================================================
package jwt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
)
// SigningMethodRSAPSS implements the RSAPSS family of signing methods.
type SigningMethodRSAPSS struct {
*SigningMethodRSA
// Options are the PSS options used when signing, and when verifying unless
// VerifyOptions is set.
Options *rsa.PSSOptions
// VerifyOptions is optional. If set, it overrides Options for rsa.VerifyPSS.
// Used to accept tokens signed with rsa.PSSSaltLengthAuto, which doesn't follow
// https://tools.ietf.org/html/rfc7518#section-3.5 but was used previously.
// See https://github.com/dgrijalva/jwt-go/issues/285#issuecomment-437451244 for details.
VerifyOptions *rsa.PSSOptions
}
// Specific instances for PS256 and company. They are populated and registered
// by this file's init function.
var (
SigningMethodPS256 *SigningMethodRSAPSS
SigningMethodPS384 *SigningMethodRSAPSS
SigningMethodPS512 *SigningMethodRSAPSS
)
func init() {
// PS256
SigningMethodPS256 = &SigningMethodRSAPSS{
SigningMethodRSA: &SigningMethodRSA{
Name: "PS256",
Hash: crypto.SHA256,
},
Options: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
},
VerifyOptions: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
return SigningMethodPS256
})
// PS384
SigningMethodPS384 = &SigningMethodRSAPSS{
SigningMethodRSA: &SigningMethodRSA{
Name: "PS384",
Hash: crypto.SHA384,
},
Options: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
},
VerifyOptions: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
return SigningMethodPS384
})
// PS512
SigningMethodPS512 = &SigningMethodRSAPSS{
SigningMethodRSA: &SigningMethodRSA{
Name: "PS512",
Hash: crypto.SHA512,
},
Options: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
},
VerifyOptions: &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
return SigningMethodPS512
})
}
// Verify checks sig against signingString using RSASSA-PSS.
// key must be an *rsa.PublicKey; any other type yields an ErrInvalidKeyType
// based error. ErrHashUnavailable is returned when the method's hash function
// is not available. VerifyOptions, when set, takes precedence over Options.
func (m *SigningMethodRSAPSS) Verify(signingString string, sig []byte, key any) error {
	pub, isPublicKey := key.(*rsa.PublicKey)
	if !isPublicKey {
		return newError("RSA-PSS verify expects *rsa.PublicKey", ErrInvalidKeyType)
	}

	if !m.Hash.Available() {
		return ErrHashUnavailable
	}

	// Hash the signing string; the digest is what was actually signed.
	h := m.Hash.New()
	h.Write([]byte(signingString))
	digest := h.Sum(nil)

	// Prefer the dedicated verification options and fall back to Options.
	pssOpts := m.VerifyOptions
	if pssOpts == nil {
		pssOpts = m.Options
	}

	return rsa.VerifyPSS(pub, m.Hash, digest, sig, pssOpts)
}
// Sign produces an RSASSA-PSS signature over signingString using m.Options.
// key must be an *rsa.PrivateKey; any other type yields an ErrInvalidKeyType
// based error. ErrHashUnavailable is returned when the method's hash function
// is not available.
func (m *SigningMethodRSAPSS) Sign(signingString string, key any) ([]byte, error) {
	priv, isPrivateKey := key.(*rsa.PrivateKey)
	if !isPrivateKey {
		return nil, newError("RSA-PSS sign expects *rsa.PrivateKey", ErrInvalidKeyType)
	}

	if !m.Hash.Available() {
		return nil, ErrHashUnavailable
	}

	// Hash the signing string; the digest is the value that gets signed.
	h := m.Hash.New()
	h.Write([]byte(signingString))
	digest := h.Sum(nil)

	signature, err := rsa.SignPSS(rand.Reader, priv, m.Hash, digest, m.Options)
	if err != nil {
		return nil, err
	}
	return signature, nil
}
================================================
FILE: gossh/gin/jwt/rsa_utils.go
================================================
package jwt
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
// Sentinel errors returned by the RSA key parsing helpers in this file.
var (
// ErrKeyMustBePEMEncoded is returned when the input contains no PEM block.
ErrKeyMustBePEMEncoded = errors.New("invalid key: Key must be a PEM encoded PKCS1 or PKCS8 key")
// ErrNotRSAPrivateKey is returned when the parsed private key is not an RSA key.
ErrNotRSAPrivateKey = errors.New("key is not a valid RSA private key")
// ErrNotRSAPublicKey is returned when the parsed public key is not an RSA key.
ErrNotRSAPublicKey = errors.New("key is not a valid RSA public key")
)
// ParseRSAPrivateKeyFromPEM decodes the first PEM block in key and parses it
// as an RSA private key, trying PKCS #1 first and PKCS #8 second.
//
// It returns ErrKeyMustBePEMEncoded when no PEM block is found, the PKCS #8
// parse error when neither format matches, and ErrNotRSAPrivateKey when the
// PKCS #8 key is not an RSA key.
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	// PKCS #1 always yields an RSA key, so a success can be returned directly.
	if pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
		return pkcs1Key, nil
	}

	// PKCS #8 may carry other key types, hence the type check afterwards.
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	rsaKey, isRSA := parsed.(*rsa.PrivateKey)
	if !isRSA {
		return nil, ErrNotRSAPrivateKey
	}
	return rsaKey, nil
}
// ParseRSAPrivateKeyFromPEMWithPassword decrypts the first PEM block in key
// with password and parses the result as an RSA private key, trying PKCS #1
// first and PKCS #8 second.
//
// Deprecated: This function is deprecated and should not be used anymore. It uses the deprecated x509.DecryptPEMBlock
// function, which was deprecated since RFC 1423 is regarded insecure by design. Unfortunately, there is no alternative
// in the Go standard library for now. See https://github.com/golang/go/issues/8860.
func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	der, err := x509.DecryptPEMBlock(block, []byte(password))
	if err != nil {
		return nil, err
	}

	// PKCS #1 always yields an RSA key, so a success can be returned directly.
	if pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(der); pkcs1Err == nil {
		return pkcs1Key, nil
	}

	// PKCS #8 may carry other key types, hence the type check afterwards.
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	rsaKey, isRSA := parsed.(*rsa.PrivateKey)
	if !isRSA {
		return nil, ErrNotRSAPrivateKey
	}
	return rsaKey, nil
}
// ParseRSAPublicKeyFromPEM decodes the first PEM block in key and extracts an
// RSA public key from it. The block is interpreted, in this order, as a PKIX
// public key, as an X.509 certificate (whose public key is used), and finally
// as a PKCS #1 public key.
//
// It returns ErrKeyMustBePEMEncoded when no PEM block is found, the PKCS #1
// parse error when no interpretation succeeds, and ErrNotRSAPublicKey when the
// key found is not an RSA key.
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, ErrKeyMustBePEMEncoded
	}

	var candidate any
	if pkixKey, pkixErr := x509.ParsePKIXPublicKey(block.Bytes); pkixErr == nil {
		candidate = pkixKey
	} else if cert, certErr := x509.ParseCertificate(block.Bytes); certErr == nil {
		candidate = cert.PublicKey
	} else {
		pkcs1Key, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes)
		if pkcs1Err != nil {
			return nil, pkcs1Err
		}
		candidate = pkcs1Key
	}

	rsaKey, isRSA := candidate.(*rsa.PublicKey)
	if !isRSA {
		return nil, ErrNotRSAPublicKey
	}
	return rsaKey, nil
}
================================================
FILE: gossh/gin/jwt/signing_method.go
================================================
package jwt
import (
"sync"
)
// signingMethods maps an "alg" name to the factory registered for it.
// All access goes through signingMethodLock.
var signingMethods = map[string]func() SigningMethod{}
// signingMethodLock guards signingMethods.
var signingMethodLock = new(sync.RWMutex)
// SigningMethod can be used to add new methods for signing or verifying
// tokens. It takes a decoded signature as an input in the Verify function and
// produces a signature in Sign. The signature is then usually base64 encoded
// as part of a JWT.
type SigningMethod interface {
Verify(signingString string, sig []byte, key any) error // Returns nil if signature is valid
Sign(signingString string, key any) ([]byte, error) // Returns signature or error
Alg() string // returns the alg identifier for this method (example: 'HS256')
}
// RegisterSigningMethod associates the "alg" name with a factory function for
// the corresponding signing method, replacing any factory registered earlier
// under the same name. Implementations typically call this from init().
func RegisterSigningMethod(alg string, f func() SigningMethod) {
	// Exclusive lock: the registry map is being written.
	signingMethodLock.Lock()
	defer signingMethodLock.Unlock()

	signingMethods[alg] = f
}
// GetSigningMethod looks up the factory registered for the "alg" string and
// returns the signing method it produces. The result is nil when no factory
// is registered under that name.
func GetSigningMethod(alg string) (method SigningMethod) {
	signingMethodLock.RLock()
	defer signingMethodLock.RUnlock()

	factory, registered := signingMethods[alg]
	if !registered {
		return nil
	}
	return factory()
}
// GetAlgorithms returns the "alg" names of all registered signing methods.
// The order follows map iteration and is therefore unspecified; the result is
// nil when nothing is registered.
func GetAlgorithms() (algs []string) {
	signingMethodLock.RLock()
	defer signingMethodLock.RUnlock()

	for name := range signingMethods {
		algs = append(algs, name)
	}
	return algs
}
================================================
FILE: gossh/gin/jwt/token.go
================================================
package jwt
import (
"crypto"
"encoding/base64"
"encoding/json"
)
// Keyfunc will be used by the Parse methods as a callback function to supply
// the key for verification. The function receives the parsed, but unverified
// Token. This allows you to use properties in the Header of the token (such as
// `kid`) to identify which key to use.
//
// The returned value may be a single key or a VerificationKeySet containing
// multiple keys.
type Keyfunc func(*Token) (any, error)
// VerificationKey represents a public or secret key for verifying a token's
// signature: either a crypto.PublicKey or a byte slice.
type VerificationKey interface {
crypto.PublicKey | []uint8
}
// VerificationKeySet is a set of public or secret keys. It is used by the
// parser to verify a token; a Keyfunc may return it instead of a single key.
type VerificationKeySet struct {
// Keys holds the candidate verification keys.
Keys []VerificationKey
}
// Token represents a JWT. Which fields are used depends on whether the token
// is being created or parsed/verified.
type Token struct {
Raw string // Raw contains the raw token. Populated when you [Parse] a token
Method SigningMethod // Method is the signing method used or to be used
Header map[string]any // Header is the first segment of the token in decoded form
Claims Claims // Claims is the second segment of the token in decoded form
Signature []byte // Signature is the third segment of the token in decoded form. Populated when you [Parse] a token
Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token
}
// New builds a [Token] that uses the given signing method and starts out with
// an empty MapClaims. The options are passed on to NewWithClaims, which
// currently ignores them.
func New(method SigningMethod, opts ...TokenOption) *Token {
	emptyClaims := MapClaims{}
	return NewWithClaims(method, emptyClaims, opts...)
}
// NewWithClaims builds a [Token] for the given signing method and claims.
// The header is initialised with "typ" set to "JWT" and "alg" set to the
// method's Alg value. Options are accepted but currently unused.
func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token {
	header := map[string]any{
		"alg": method.Alg(),
		"typ": "JWT",
	}
	return &Token{
		Method: method,
		Header: header,
		Claims: claims,
	}
}
// SignedString builds the complete, signed JWT: the signing string, a dot,
// and the encoded signature produced by the token's SigningMethod with key.
// See
// https://golang-jwt.github.io/jwt/usage/signing_methods/#signing-methods-and-key-types
// for which key type each signing method expects.
//
// Errors from building the signing string or from signing are returned
// unchanged together with an empty string.
func (t *Token) SignedString(key any) (string, error) {
	signingInput, err := t.SigningString()
	if err != nil {
		return "", err
	}

	signature, err := t.Method.Sign(signingInput, key)
	if err != nil {
		return "", err
	}

	encodedSignature := t.EncodeSegment(signature)
	return signingInput + "." + encodedSignature, nil
}
// SigningString returns the part of the JWT that gets signed: the encoded
// JSON header and the encoded JSON claims joined by a dot. This is the most
// expensive step of producing a token, so unless the signing string itself is
// needed, call SignedString directly.
//
// A JSON marshalling error (header first, then claims) is returned unchanged.
func (t *Token) SigningString() (string, error) {
	parts := [2]any{t.Header, t.Claims}

	var encoded [2]string
	for i, part := range parts {
		raw, err := json.Marshal(part)
		if err != nil {
			return "", err
		}
		encoded[i] = t.EncodeSegment(raw)
	}
	return encoded[0] + "." + encoded[1], nil
}
// EncodeSegment applies the JWT flavour of base64url encoding, i.e. without
// padding, to seg. It is a method of [Token] rather than a free function so
// that a future [TokenOption] could influence the encoding.
func (*Token) EncodeSegment(seg []byte) string {
	encoding := base64.RawURLEncoding
	return encoding.EncodeToString(seg)
}
================================================
FILE: gossh/gin/jwt/token_option.go
================================================
package jwt
// TokenOption is a reserved type, which provides some forward compatibility,
// if we ever want to introduce token creation-related options. New and
// NewWithClaims accept values of this type but do not apply them yet.
type TokenOption func(*Token)
================================================
FILE: gossh/gin/jwt/types.go
================================================
package jwt
import (
"encoding/json"
"fmt"
"math"
"strconv"
"time"
)
// TimePrecision sets the precision of times and dates within this library. This
// has an influence on the precision of times when comparing expiry or other
// related time fields. Furthermore, it is also the precision of times when
// serializing.
//
// For backwards compatibility the default precision is set to seconds, so that
// no fractional timestamps are generated.
var TimePrecision = time.Second
// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains exactly one element. Otherwise, it will
// serialize to an array of strings.
var MarshalSingleStringAsArray = true
// NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
// It embeds time.Time, so all of time.Time's methods are available on it.
type NumericDate struct {
time.Time
}
// NewNumericDate wraps a standard library time.Time in a *NumericDate after
// truncating it to the precision configured in TimePrecision.
func NewNumericDate(t time.Time) *NumericDate {
	truncated := t.Truncate(TimePrecision)
	return &NumericDate{Time: truncated}
}
// newNumericDateFromSeconds converts a UNIX epoch given as float64 seconds,
// whose fractional part carries the sub-second portion, into a *NumericDate.
func newNumericDateFromSeconds(f float64) *NumericDate {
	whole, fraction := math.Modf(f)
	seconds := int64(whole)
	nanos := int64(fraction * 1e9)
	return NewNumericDate(time.Unix(seconds, nanos))
}
// MarshalJSON implements json.Marshaler. It serializes the UNIX epoch
// represented by the NumericDate as a JSON number, using the precision
// specified in TimePrecision. The returned error is always nil.
func (date NumericDate) MarshalJSON() (b []byte, err error) {
// prec is the number of fractional digits to emit. It stays 0 for precisions
// of one second or coarser and is derived from TimePrecision otherwise.
var prec int
if TimePrecision < time.Second {
prec = int(math.Log10(float64(time.Second) / float64(TimePrecision)))
}
truncatedDate := date.Truncate(TimePrecision)
// For very large timestamps, UnixNano would overflow an int64, but this
// function requires nanosecond level precision, so we have to use the
// following technique to get round the issue:
//
// 1. Take the normal unix timestamp to form the whole number part of the
// output,
// 2. Take the result of the Nanosecond function, which returns the offset
// within the second of the particular unix time instance, to form the
// decimal part of the output
// 3. Concatenate them to produce the final result
//
// NOTE(review): for instants before the epoch with a non-zero sub-second
// part, Unix() is negative while Nanosecond() is a non-negative offset, so
// the concatenation may not denote the intended value — confirm whether such
// inputs can occur.
seconds := strconv.FormatInt(truncatedDate.Unix(), 10)
nanosecondsOffset := strconv.FormatFloat(float64(truncatedDate.Nanosecond())/float64(time.Second), 'f', prec, 64)
// [1:] drops the first character of the formatted fraction (its leading
// integer digit), keeping only the "." and the decimals; with prec == 0
// nothing remains and the output is the integer seconds.
output := append([]byte(seconds), []byte(nanosecondsOffset)[1:]...)
return output, nil
}
// UnmarshalJSON implements json.Unmarshaler. It decodes a [NumericDate] from
// its JSON representation, a [json.Number] holding a UNIX epoch in integer or
// fractional seconds. Decoding and conversion failures are returned wrapped
// with context.
func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
	var raw json.Number
	if err = json.Unmarshal(b, &raw); err != nil {
		return fmt.Errorf("could not parse NumericData: %w", err)
	}

	seconds, convErr := raw.Float64()
	if convErr != nil {
		return fmt.Errorf("could not convert json number value to float: %w", convErr)
	}

	*date = *newNumericDateFromSeconds(seconds)
	return nil
}
// ClaimStrings is basically just a slice of strings, but it can be
// deserialized from either a JSON string array or a single JSON string. This
// type is necessary, since the "aud" claim can either be a single string or
// an array.
type ClaimStrings []string
// UnmarshalJSON implements json.Unmarshaler. It accepts either a single JSON
// string or a JSON array of strings and stores the result in *s.
//
// A JSON null leaves *s untouched. Any other JSON value, or an array holding
// a non-string element, yields ErrInvalidType. Malformed JSON yields the
// decoder's error.
func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
	var value any
	if err = json.Unmarshal(data, &value); err != nil {
		return err
	}

	var aud []string

	// Decoding into an `any` only ever produces string, []any, nil, or another
	// scalar/map type, so a []string can never appear here and needs no case.
	switch v := value.(type) {
	case string:
		aud = append(aud, v)
	case []any:
		for _, element := range v {
			str, ok := element.(string)
			if !ok {
				return ErrInvalidType
			}
			aud = append(aud, str)
		}
	case nil:
		return nil
	default:
		return ErrInvalidType
	}

	*s = aud
	return nil
}
// MarshalJSON implements json.Marshaler. The value is normally encoded as a
// JSON array of strings. The JWT RFC also permits a one-element list (e.g.
// "aud") to be written as a plain string; that form is used only when the
// slice has exactly one element and MarshalSingleStringAsArray is false.
func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
	if MarshalSingleStringAsArray || len(s) != 1 {
		return json.Marshal([]string(s))
	}
	return json.Marshal(s[0])
}
================================================
FILE: gossh/gin/jwt/validator.go
================================================
package jwt
import (
"crypto/subtle"
"fmt"
"time"
)
// ClaimsValidator is an interface that can be implemented by custom claims
// that wish to execute additional claims validation based on
// application-specific logic. The Validate function is then executed in
// addition to the regular claims validation and any error returned is appended
// to the final validation result.
//
//	type MyCustomClaims struct {
//		Foo string `json:"foo"`
//		jwt.RegisteredClaims
//	}
//
//	func (m MyCustomClaims) Validate() error {
//		if m.Foo != "bar" {
//			return errors.New("must be foobar")
//		}
//		return nil
//	}
type ClaimsValidator interface {
Claims
Validate() error
}
// Validator is the core of the new Validation API. It is automatically used by
// a [Parser] during parsing and can be modified with various parser options.
//
// The [NewValidator] function should be used to create an instance of this
// struct.
type Validator struct {
// leeway is an optional leeway that can be provided to account for clock skew.
leeway time.Duration
// timeFunc is used to supply the current time that is needed for
// validation. If unspecified, this defaults to time.Now.
timeFunc func() time.Time
// requireExp specifies whether the exp claim is required
requireExp bool
// verifyIat specifies whether the iat (Issued At) claim will be verified.
// According to https://www.rfc-editor.org/rfc/rfc7519#section-4.1.6 this
// only specifies the age of the token, but no validation check is
// necessary. However, if wanted, it can be checked if the iat is
// unrealistic, i.e., in the future.
verifyIat bool
// expectedAud contains the audience the token is expected to carry.
// Supplying an empty string will disable aud checking.
expectedAud string
// expectedIss contains the issuer the token is expected to carry.
// Supplying an empty string will disable iss checking.
expectedIss string
// expectedSub contains the subject the token is expected to carry.
// Supplying an empty string will disable sub checking.
expectedSub string
}
// NewValidator builds a stand-alone validator configured by the supplied
// parser options, by creating a parser with them and returning its validator.
// The result can be used to validate claims that have already been parsed.
//
// Note: Under normal circumstances, explicitly creating a validator is not
// needed and can potentially be dangerous; instead functions of the [Parser]
// class should be used.
//
// The [Validator] only checks the *validity* of the claims, such as the
// expiration time; it does NOT perform *signature verification* of the
// token.
func NewValidator(opts ...ParserOption) *Validator {
	return NewParser(opts...).validator
}
// Validate checks the given claims against the validator's configuration and,
// if claims implements [ClaimsValidator], also runs its custom validation.
// All failures are collected and returned as one joined error; nil means
// every check passed.
//
// Note: It will NOT perform any *signature verification* on the token that
// contains the claims and expects that the [Claim] was already successfully
// verified.
func (v *Validator) Validate(claims Claims) error {
	// Use the injected clock when there is one, the wall clock otherwise.
	clock := time.Now
	if v.timeFunc != nil {
		clock = v.timeFunc
	}
	now := clock()

	failures := make([]error, 0, 6)
	record := func(checkErr error) {
		if checkErr != nil {
			failures = append(failures, checkErr)
		}
	}

	// exp is always checked; the claim itself is OPTIONAL unless requireExp
	// makes it mandatory.
	record(v.verifyExpiresAt(claims, now, v.requireExp))

	// nbf is always checked, but the claim itself is OPTIONAL.
	record(v.verifyNotBefore(claims, now, false))

	// iat is only checked when the option is enabled.
	if v.verifyIat {
		record(v.verifyIssuedAt(claims, now, false))
	}

	// An expected audience, issuer or subject makes the matching claim
	// required.
	if v.expectedAud != "" {
		record(v.verifyAudience(claims, v.expectedAud, true))
	}
	if v.expectedIss != "" {
		record(v.verifyIssuer(claims, v.expectedIss, true))
	}
	if v.expectedSub != "" {
		record(v.verifySubject(claims, v.expectedSub, true))
	}

	// Last, give the claims a chance to run their own validation logic.
	if custom, ok := claims.(ClaimsValidator); ok {
		record(custom.Validate())
	}

	if len(failures) == 0 {
		return nil
	}
	return joinErrors(failures...)
}
// verifyExpiresAt checks the exp claim against cmp. The check passes while
// cmp is strictly before exp plus the configured leeway; otherwise
// ErrTokenExpired is returned.
//
// A missing exp passes unless required is true, in which case the error from
// errorIfRequired is returned. An error from reading the claim is returned
// as-is.
func (v *Validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
	exp, err := claims.GetExpirationTime()
	switch {
	case err != nil:
		return err
	case exp == nil:
		return errorIfRequired(required, "exp")
	}

	deadline := exp.Time.Add(v.leeway)
	return errorIfFalse(cmp.Before(deadline), ErrTokenExpired)
}
// verifyIssuedAt checks the iat claim against cmp. The check passes when cmp
// is not before iat minus the configured leeway; otherwise
// ErrTokenUsedBeforeIssued is returned.
//
// A missing iat passes unless required is true, in which case the error from
// errorIfRequired is returned. An error from reading the claim is returned
// as-is.
func (v *Validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
	iat, err := claims.GetIssuedAt()
	switch {
	case err != nil:
		return err
	case iat == nil:
		return errorIfRequired(required, "iat")
	}

	earliest := iat.Add(-v.leeway)
	return errorIfFalse(!cmp.Before(earliest), ErrTokenUsedBeforeIssued)
}
// verifyNotBefore checks the nbf claim against cmp. The check passes when cmp
// is not before nbf minus the configured leeway; otherwise
// ErrTokenNotValidYet is returned.
//
// A missing nbf passes unless required is true, in which case the error from
// errorIfRequired is returned. An error from reading the claim is returned
// as-is.
func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
	nbf, err := claims.GetNotBefore()
	switch {
	case err != nil:
		return err
	case nbf == nil:
		return errorIfRequired(required, "nbf")
	}

	earliest := nbf.Add(-v.leeway)
	return errorIfFalse(!cmp.Before(earliest), ErrTokenNotValidYet)
}
// verifyAudience checks that at least one entry of the aud claim equals cmp;
// otherwise ErrTokenInvalidAudience is returned.
//
// An absent aud, an empty list, or a list made up solely of empty strings is
// treated as "not set": it passes unless required is true, in which case the
// error from errorIfRequired is returned. An error from reading the claim is
// returned as-is.
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error {
	aud, err := claims.GetAudience()
	if err != nil {
		return err
	}
	if len(aud) == 0 {
		return errorIfRequired(required, "aud")
	}

	// Every entry is compared (no early exit) so the loop does not stop at
	// the first match.
	want := []byte(cmp)
	matched, sawNonEmpty := false, false
	for _, candidate := range aud {
		if subtle.ConstantTimeCompare([]byte(candidate), want) == 1 {
			matched = true
		}
		if candidate != "" {
			sawNonEmpty = true
		}
	}

	// "" sent in one or many aud entries counts as a missing claim.
	if !sawNonEmpty {
		return errorIfRequired(required, "aud")
	}
	return errorIfFalse(matched, ErrTokenInvalidAudience)
}
// verifyIssuer checks that the iss claim equals cmp; otherwise
// ErrTokenInvalidIssuer is returned.
//
// An empty iss passes unless required is true, in which case the error from
// errorIfRequired is returned. An error from reading the claim is returned
// as-is.
func (v *Validator) verifyIssuer(claims Claims, cmp string, required bool) error {
	iss, err := claims.GetIssuer()
	switch {
	case err != nil:
		return err
	case iss == "":
		return errorIfRequired(required, "iss")
	default:
		return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
	}
}
// verifySubject checks that the sub claim equals cmp; otherwise
// ErrTokenInvalidSubject is returned.
//
// An empty sub passes unless required is true, in which case the error from
// errorIfRequired is returned. An error from reading the claim is returned
// as-is.
func (v *Validator) verifySubject(claims Claims, cmp string, required bool) error {
	sub, err := claims.GetSubject()
	switch {
	case err != nil:
		return err
	case sub == "":
		return errorIfRequired(required, "sub")
	default:
		return errorIfFalse(sub == cmp, ErrTokenInvalidSubject)
	}
}
// errorIfFalse returns err when value is false and nil when value is true.
// It lets a verification be written as "condition that must hold" plus the
// error to report when it does not.
func errorIfFalse(value bool, err error) error {
	if value {
		return nil
	}
	return err
}
// errorIfRequired reports a missing claim. When required is false the claim
// is optional and nil is returned; otherwise the result is an error built
// with newError from ErrTokenRequiredClaimMissing and a message naming the
// claim.
func errorIfRequired(required bool, claim string) error {
	if !required {
		return nil
	}
	return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
}
================================================
FILE: gossh/gin/logger.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"fmt"
"io"
"net/http"
"os"
"time"
)
// consoleColorModeValue selects how the logger decides whether to emit
// colour escape sequences (see LogFormatterParams.IsOutputColor).
type consoleColorModeValue int
const (
// autoColor colours output only when the log destination was detected as a
// terminal.
autoColor consoleColorModeValue = iota
// disableColor never colours output.
disableColor
// forceColor always colours output.
forceColor
)
// ANSI SGR escape sequences used to highlight status codes and methods.
// Each constant is named after the background colour it selects; reset
// clears all attributes.
const (
green = "\033[97;42m"
white = "\033[90;47m"
yellow = "\033[90;43m"
red = "\033[97;41m"
blue = "\033[97;44m"
magenta = "\033[97;45m"
cyan = "\033[97;46m"
reset = "\033[0m"
)
// consoleColorMode is the package-wide colour policy; it is changed by
// DisableConsoleColor and ForceConsoleColor.
var consoleColorMode = autoColor
// LoggerConfig defines the config for Logger middleware.
// Every field is optional; LoggerWithConfig substitutes the defaults noted
// below for zero values.
type LoggerConfig struct {
// Optional. Default value is gin.defaultLogFormatter
Formatter LogFormatter
// Output is a writer where logs are written.
// Optional. Default value is gin.DefaultWriter.
Output io.Writer
// SkipPaths is an url path array which logs are not written.
// Entries are matched exactly against the request's URL path (without
// the query string).
// Optional.
SkipPaths []string
// Skip is a Skipper that indicates which logs should not be written.
// It is consulted after the request has been processed.
// Optional.
Skip Skipper
}
// Skipper is a function to skip logs based on provided Context.
// Returning true suppresses the log entry for that request.
type Skipper func(c *Context) bool
// LogFormatter gives the signature of the formatter function passed to LoggerWithFormatter.
// It receives the per-request parameters and returns the complete text that
// the logger writes to its output.
type LogFormatter func(params LogFormatterParams) string
// LogFormatterParams is the structure any formatter will be handed when time to log comes.
// The values are filled in by the logger middleware after the handler
// chain has run.
type LogFormatterParams struct {
// Request is the request being logged.
Request *http.Request
// TimeStamp shows the time after the server returns a response.
TimeStamp time.Time
// StatusCode is HTTP response code.
StatusCode int
// Latency is how much time the server cost to process a certain request.
Latency time.Duration
// ClientIP equals Context's ClientIP method.
ClientIP string
// Method is the HTTP method given to the request.
Method string
// Path is a path the client requests.
// The raw query string is appended after "?" when it is not empty.
Path string
// ErrorMessage is set if error has occurred in processing the request.
ErrorMessage string
// isTerm shows whether gin's output descriptor refers to a terminal.
isTerm bool
// BodySize is the size of the Response Body
BodySize int
// Keys are the keys set on the request's context.
Keys map[string]any
}
// StatusCodeColor picks the ANSI colour used to print the response status on
// a terminal: white for 1xx and 3xx, green for 2xx, yellow for 4xx, and red
// for everything else (5xx and any code outside 100-499).
func (p *LogFormatterParams) StatusCodeColor() string {
	code := p.StatusCode

	// Anything below 100 or from 500 upwards is reported as an error.
	if code < http.StatusContinue || code >= http.StatusInternalServerError {
		return red
	}
	if code < http.StatusOK { // 1xx
		return white
	}
	if code < http.StatusMultipleChoices { // 2xx
		return green
	}
	if code < http.StatusBadRequest { // 3xx
		return white
	}
	return yellow // 4xx
}
// MethodColor picks the ANSI colour used to print the request method on a
// terminal. Methods without a dedicated colour get the reset sequence.
func (p *LogFormatterParams) MethodColor() string {
	switch p.Method {
	case http.MethodGet:
		return blue
	case http.MethodPost:
		return cyan
	case http.MethodPut:
		return yellow
	case http.MethodDelete:
		return red
	case http.MethodPatch:
		return green
	case http.MethodHead:
		return magenta
	case http.MethodOptions:
		return white
	}
	// Unknown or uncoloured method.
	return reset
}
// ResetColor returns the ANSI sequence that resets all escape attributes.
// It is meant to follow text coloured with StatusCodeColor or MethodColor.
func (p *LogFormatterParams) ResetColor() string {
return reset
}
// IsOutputColor reports whether colour sequences should be written to the
// log: always when colour is forced, never when it is disabled, and in auto
// mode only when the output was detected as a terminal.
func (p *LogFormatterParams) IsOutputColor() bool {
	switch consoleColorMode {
	case forceColor:
		return true
	case autoColor:
		return p.isTerm
	default:
		return false
	}
}
// defaultLogFormatter renders one log line in gin's standard layout:
// timestamp, status code, latency, client IP, method and path, followed by
// the error message on the next line. Colour sequences are inserted around
// the status and method only when IsOutputColor allows it.
var defaultLogFormatter = func(param LogFormatterParams) string {
	var statusColor, methodColor, resetColor string
	if param.IsOutputColor() {
		statusColor, methodColor, resetColor =
			param.StatusCodeColor(), param.MethodColor(), param.ResetColor()
	}

	// Sub-second precision is noise once a request took over a minute.
	latency := param.Latency
	if latency > time.Minute {
		latency = latency.Truncate(time.Second)
	}

	timestamp := param.TimeStamp.Format("2006/01/02 - 15:04:05")
	return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n%s",
		timestamp,
		statusColor, param.StatusCode, resetColor,
		latency,
		param.ClientIP,
		methodColor, param.Method, resetColor,
		param.Path,
		param.ErrorMessage,
	)
}
// DisableConsoleColor disables color output in the console.
// It switches the package-wide colour mode, so it affects every logger.
func DisableConsoleColor() {
consoleColorMode = disableColor
}
// ForceConsoleColor forces color output in the console, even when the log
// destination was not detected as a terminal.
// It switches the package-wide colour mode, so it affects every logger.
func ForceConsoleColor() {
consoleColorMode = forceColor
}
// ErrorLogger returns a HandlerFunc for any error type.
// It is shorthand for ErrorLoggerT(ErrorTypeAny).
func ErrorLogger() HandlerFunc {
return ErrorLoggerT(ErrorTypeAny)
}
// ErrorLoggerT returns a HandlerFunc that first runs the rest of the handler
// chain and then, if the context collected any errors of the given type,
// writes them to the response as JSON. The status code argument -1 is passed
// through to Context.JSON unchanged.
func ErrorLoggerT(typ ErrorType) HandlerFunc {
	return func(c *Context) {
		c.Next()

		if matched := c.Errors.ByType(typ); len(matched) > 0 {
			c.JSON(-1, matched)
		}
	}
}
// Logger instances a Logger middleware that will write the logs to gin.DefaultWriter.
// By default, gin.DefaultWriter = os.Stdout.
// It is equivalent to LoggerWithConfig with a zero-value LoggerConfig.
func Logger() HandlerFunc {
return LoggerWithConfig(LoggerConfig{})
}
// LoggerWithFormatter creates a Logger middleware that renders each entry
// with f; every other setting keeps its default.
func LoggerWithFormatter(f LogFormatter) HandlerFunc {
	var conf LoggerConfig
	conf.Formatter = f
	return LoggerWithConfig(conf)
}
// LoggerWithWriter creates a Logger middleware that writes to out, for
// example os.Stdout, a file opened in write mode, or a socket. Requests
// whose path is listed in notlogged are not logged.
func LoggerWithWriter(out io.Writer, notlogged ...string) HandlerFunc {
	var conf LoggerConfig
	conf.Output = out
	conf.SkipPaths = notlogged
	return LoggerWithConfig(conf)
}
// LoggerWithConfig instance a Logger middleware with config.
//
// Zero-value fields of conf fall back to defaults: defaultLogFormatter for
// Formatter and DefaultWriter for Output. Requests whose URL path is listed
// in conf.SkipPaths, or for which conf.Skip returns true, are processed but
// not logged.
//
// Colour output in auto mode is only enabled when the output is an *os.File
// that refers to a character device (a terminal) and TERM is not "dumb", so
// regular files and pipes do not receive ANSI escape sequences. Use
// ForceConsoleColor to colour the output regardless.
func LoggerWithConfig(conf LoggerConfig) HandlerFunc {
	formatter := conf.Formatter
	if formatter == nil {
		formatter = defaultLogFormatter
	}

	out := conf.Output
	if out == nil {
		out = DefaultWriter
	}

	notlogged := conf.SkipPaths

	// Being an *os.File is not enough to be a terminal: log files and
	// redirected stdout are *os.File too. Ask the file itself.
	isTerm := false
	if f, ok := out.(*os.File); ok && os.Getenv("TERM") != "dumb" {
		if info, err := f.Stat(); err == nil {
			isTerm = info.Mode()&os.ModeCharDevice != 0
		}
	}

	// Build the skip set once so the per-request check is a map lookup.
	var skip map[string]struct{}
	if length := len(notlogged); length > 0 {
		skip = make(map[string]struct{}, length)
		for _, path := range notlogged {
			skip[path] = struct{}{}
		}
	}

	return func(c *Context) {
		// Start timer
		start := time.Now()
		// Capture path and query up front; they are read before the
		// handlers run.
		path := c.Request.URL.Path
		raw := c.Request.URL.RawQuery

		// Process request
		c.Next()

		// Log only when it is not being skipped
		if _, ok := skip[path]; ok || (conf.Skip != nil && conf.Skip(c)) {
			return
		}

		param := LogFormatterParams{
			Request: c.Request,
			isTerm:  isTerm,
			Keys:    c.Keys,
		}

		// Stop timer
		param.TimeStamp = time.Now()
		param.Latency = param.TimeStamp.Sub(start)

		param.ClientIP = c.ClientIP()
		param.Method = c.Request.Method
		param.StatusCode = c.Writer.Status()
		param.ErrorMessage = c.Errors.ByType(ErrorTypePrivate).String()
		param.BodySize = c.Writer.Size()

		if raw != "" {
			path = path + "?" + raw
		}
		param.Path = path

		fmt.Fprint(out, formatter(param))
	}
}
================================================
FILE: gossh/gin/mode.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"flag"
"io"
"os"
"gossh/gin/binding"
)
// EnvGinMode indicates environment name for gin mode.
// Its value is read once at package initialisation.
const EnvGinMode = "GIN_MODE"
// Mode names accepted by SetMode and returned by Mode.
const (
// DebugMode indicates gin mode is debug.
DebugMode = "debug"
// ReleaseMode indicates gin mode is release.
ReleaseMode = "release"
// TestMode indicates gin mode is test.
TestMode = "test"
)
// Numeric counterparts of the mode names; SetMode stores one of them in
// ginMode.
const (
debugCode = iota
releaseCode
testCode
)
// DefaultWriter is the default io.Writer used by Gin for debug output and
// middleware output like Logger() or Recovery().
// Note that both Logger and Recovery provides custom ways to configure their
// output io.Writer.
// To support coloring in Windows use:
//
// import "github.com/mattn/go-colorable"
// gin.DefaultWriter = colorable.NewColorableStdout()
var DefaultWriter io.Writer = os.Stdout
// DefaultErrorWriter is the default io.Writer used by Gin to debug errors
var DefaultErrorWriter io.Writer = os.Stderr
// Current mode, kept in two forms that SetMode updates together: ginMode is
// the numeric code and modeName the string returned by Mode.
var (
ginMode = debugCode
modeName = DebugMode
)
// init seeds the mode from the GIN_MODE environment variable. An unset or
// empty variable selects SetMode's default; an unrecognised value makes
// SetMode panic during package initialisation.
func init() {
mode := os.Getenv(EnvGinMode)
SetMode(mode)
}
// SetMode sets gin mode according to input string.
func SetMode(value string) {
if value == "" {
if flag.Lookup("test.v") != nil {
value = TestMode
} else {
value = DebugMode
}
}
switch value {
case DebugMode:
ginMode = debugCode
case ReleaseMode:
ginMode = releaseCode
case TestMode:
ginMode = testCode
default:
panic("gin mode unknown: " + value + " (available mode: debug release test)")
}
modeName = value
}
// DisableBindValidation closes the default validator by setting
// binding.Validator to nil. The change is package-wide.
func DisableBindValidation() {
binding.Validator = nil
}
// EnableJsonDecoderUseNumber sets true for binding.EnableDecoderUseNumber to
// call the UseNumber method on the JSON Decoder instance.
// The flag is package-wide and this function never turns it off again.
func EnableJsonDecoderUseNumber() {
binding.EnableDecoderUseNumber = true
}
// EnableJsonDecoderDisallowUnknownFields sets true for binding.EnableDecoderDisallowUnknownFields to
// call the DisallowUnknownFields method on the JSON Decoder instance.
// The flag is package-wide and this function never turns it off again.
func EnableJsonDecoderDisallowUnknownFields() {
binding.EnableDecoderDisallowUnknownFields = true
}
// Mode returns current gin mode as one of the mode name strings
// (DebugMode, ReleaseMode or TestMode) last accepted by SetMode.
func Mode() string {
return modeName
}
================================================
FILE: gossh/gin/path.go
================================================
// Copyright 2013 Julien Schmidt. All rights reserved.
// Based on the path package, Copyright 2009 The Go Authors.
// Use of this source code is governed by a BSD-style license that can be found
// at https://github.com/julienschmidt/httprouter/blob/master/LICENSE.
package gin
// cleanPath is the URL version of path.Clean, it returns a canonical URL path
// for p, eliminating . and .. elements.
//
// The following rules are applied iteratively until no further processing can
// be done:
// 1. Replace multiple slashes with a single slash.
// 2. Eliminate each . path name element (the current directory).
// 3. Eliminate each inner .. path name element (the parent directory)
// along with the non-.. element that precedes it.
// 4. Eliminate .. elements that begin a rooted path:
// that is, replace "/.." by "/" at the beginning of a path.
//
// If the result of this process is an empty string, "/" is returned.
//
// A path that does not start with '/' gets one prepended. A trailing slash
// in p (or a trailing "." element) is kept as a trailing slash in the result.
// The output buffer is created lazily: as long as the cleaned path is a
// prefix of p, no bytes are copied and a substring of p is returned.
func cleanPath(p string) string {
const stackBufSize = 128
// Turn empty string into "/"
if p == "" {
return "/"
}
// Reasonably sized buffer on stack to avoid allocations in the common case.
// If a larger buffer is required, it gets allocated dynamically.
// Its length stays 0 until the output first differs from p; len(buf) == 0
// is what the rest of the function uses to mean "output is still p[:w]".
buf := make([]byte, 0, stackBufSize)
n := len(p)
// Invariants:
// reading from path; r is index of next byte to process.
// writing to buf; w is index of next byte to write.
// path must start with '/'
r := 1
w := 1
if p[0] != '/' {
// A leading '/' has to be inserted, so the output can never be a
// substring of p: materialise the buffer right away.
r = 0
if n+1 > stackBufSize {
buf = make([]byte, n+1)
} else {
buf = buf[:n+1]
}
buf[0] = '/'
}
trailing := n > 1 && p[n-1] == '/'
// A bit more clunky without a 'lazybuf' like the path package, but the loop
// gets completely inlined (bufApp calls).
// So in contrast to the path package this loop has no expensive function
// calls (except make, if needed).
for r < n {
switch {
case p[r] == '/':
// empty path element, trailing slash is added after the end
r++
case p[r] == '.' && r+1 == n:
// "." as the last element: drop it but keep a trailing slash
trailing = true
r++
case p[r] == '.' && p[r+1] == '/':
// . element
r += 2
case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
// .. element: remove to last /
r += 3
if w > 1 {
// can backtrack
w--
// Scan back through whichever storage currently holds the output.
if len(buf) == 0 {
for w > 1 && p[w] != '/' {
w--
}
} else {
for w > 1 && buf[w] != '/' {
w--
}
}
}
default:
// Real path element.
// Add slash if needed
if w > 1 {
bufApp(&buf, p, w, '/')
w++
}
// Copy element
for r < n && p[r] != '/' {
bufApp(&buf, p, w, p[r])
w++
r++
}
}
}
// Re-append trailing slash
if trailing && w > 1 {
bufApp(&buf, p, w, '/')
w++
}
// If the original string was not modified (or only shortened at the end),
// return the respective substring of the original string.
// Otherwise return a new string from the buffer.
if len(buf) == 0 {
return p[:w]
}
return string(buf[:w])
}
// bufApp records byte c at output position w, materialising the output
// buffer only when it becomes necessary.
//
// While *buf is empty the output so far equals s[:w]. If c also equals s[w]
// nothing needs to be stored. On the first differing byte the buffer is
// sized to len(s) (reusing its capacity when large enough, allocating
// otherwise), the unchanged prefix s[:w] is copied in, and c is written.
// Once the buffer is non-empty, c is simply stored at index w.
func bufApp(buf *[]byte, s string, w int, c byte) {
	// Buffer already in use: plain store.
	if len(*buf) > 0 {
		(*buf)[w] = c
		return
	}

	// Still identical to the source string; nothing to do yet.
	if s[w] == c {
		return
	}

	// First divergence: bring the buffer to full length and seed it with
	// the part of s that has been accepted unchanged so far.
	if n := len(s); n <= cap(*buf) {
		*buf = (*buf)[:n]
	} else {
		*buf = make([]byte, n)
	}
	copy(*buf, s[:w])
	(*buf)[w] = c
}
================================================
FILE: gossh/gin/recovery.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"os"
"runtime"
"strings"
"time"
)
// Byte constants used when formatting stack traces.
var (
// dunno is the placeholder returned when a source line or function name
// cannot be determined.
dunno = []byte("???")
// centerDot is replaced by dot in function names.
centerDot = []byte("·")
dot = []byte(".")
slash = []byte("/")
)
// RecoveryFunc defines the function passable to CustomRecovery.
// It receives the request context and the value recovered from the panic.
type RecoveryFunc func(c *Context, err any)
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// Panic details are logged to DefaultErrorWriter.
func Recovery() HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter)
}
// CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
// Panic details are logged to DefaultErrorWriter.
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter, handle)
}
// RecoveryWithWriter returns a middleware that recovers from panics and logs
// them to out. Without a recovery func the request is answered with a 500;
// when one or more are supplied, only the first is used to handle the panic.
func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
	handle := RecoveryFunc(defaultHandleRecovery)
	if len(recovery) > 0 {
		handle = recovery[0]
	}
	return CustomRecoveryWithWriter(out, handle)
}
// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
//
// When out is nil nothing is logged, but panics are still recovered. A panic
// caused by a broken client connection is recorded on the context and the
// chain is aborted without calling handle, since no response can be written.
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
var logger *log.Logger
if out != nil {
// The prefix starts every entry with blank lines and an ANSI colour
// sequence; the entries below end with reset.
logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
}
return func(c *Context) {
defer func() {
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
var se *os.SyscallError
if errors.As(ne, &se) {
seStr := strings.ToLower(se.Error())
if strings.Contains(seStr, "broken pipe") ||
strings.Contains(seStr, "connection reset by peer") {
brokenPipe = true
}
}
}
if logger != nil {
// Skip the frames belonging to the recovery machinery itself.
stack := stack(3)
// Dump only the request line and headers (no body).
httpRequest, _ := httputil.DumpRequest(c.Request, false)
headers := strings.Split(string(httpRequest), "\r\n")
// Mask the Authorization header so credentials never reach the log.
for idx, header := range headers {
current := strings.Split(header, ":")
if current[0] == "Authorization" {
headers[idx] = current[0] + ": *"
}
}
headersToStr := strings.Join(headers, "\r\n")
if brokenPipe {
// Broken connection: error and headers, no stack trace.
logger.Printf("%s\n%s%s", err, headersToStr, reset)
} else if IsDebugging() {
// Debug mode: include the request headers with the stack trace.
logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
timeFormat(time.Now()), headersToStr, err, stack, reset)
} else {
logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
timeFormat(time.Now()), err, stack, reset)
}
}
if brokenPipe {
// If the connection is dead, we can't write a status to it.
// The assertion cannot fail here: brokenPipe is only set when err
// is a *net.OpError.
c.Error(err.(error)) //nolint: errcheck
c.Abort()
} else {
handle(c, err)
}
}
}()
c.Next()
}
}
// defaultHandleRecovery is the RecoveryFunc used when none is supplied: it
// aborts the request with a 500 Internal Server Error and ignores the
// recovered value.
func defaultHandleRecovery(c *Context, _ any) {
c.AbortWithStatus(http.StatusInternalServerError)
}
// stack formats the current call stack, omitting the first skip frames. Each
// frame contributes a "file:line (pc)" line and, when the source file can be
// read, a second line with the function name and the trimmed source text.
func stack(skip int) []byte {
	var (
		out        bytes.Buffer // the formatted trace
		fileLines  [][]byte     // lines of the most recently loaded file
		loadedFile string       // name of that file
	)

	for depth := skip; ; depth++ {
		pc, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}

		// Always print the location, even if the source is unavailable.
		fmt.Fprintf(&out, "%s:%d (0x%x)\n", file, line, pc)

		// Consecutive frames often share a file; only reload on change.
		if file != loadedFile {
			data, err := os.ReadFile(file)
			if err != nil {
				continue
			}
			fileLines = bytes.Split(data, []byte{'\n'})
			loadedFile = file
		}
		fmt.Fprintf(&out, "\t%s: %s\n", function(pc), source(fileLines, line))
	}
	return out.Bytes()
}
// source returns line n of lines with surrounding whitespace removed. n is
// 1-based, as in a stack trace; when it falls outside the slice the "???"
// placeholder is returned.
func source(lines [][]byte, n int) []byte {
	idx := n - 1 // convert to the 0-based slice index
	if idx >= 0 && idx < len(lines) {
		return bytes.TrimSpace(lines[idx])
	}
	return dunno
}
// function returns the short name of the function containing pc, or the
// "???" placeholder when the runtime has no information for it.
//
// The runtime reports fully qualified names. The package path (everything up
// to the last '/') and the package name (up to the first '.' after that) are
// removed because the file name is printed alongside, and any center dots
// are replaced by plain dots. That is, we see
//
//	runtime/debug.*T·ptrmethod
//
// and want
//
//	*T.ptrmethod
func function(pc uintptr) []byte {
	fn := runtime.FuncForPC(pc)
	if fn == nil {
		return dunno
	}

	short := []byte(fn.Name())
	// The package path may itself contain dots (e.g. code.google.com/...),
	// so strip it before looking for the package separator.
	if i := bytes.LastIndex(short, slash); i != -1 {
		short = short[i+1:]
	}
	if i := bytes.Index(short, dot); i != -1 {
		short = short[i+1:]
	}
	return bytes.ReplaceAll(short, centerDot, dot)
}
// timeFormat returns a customized time string for logger, laid out as
// "YYYY/MM/DD - hh:mm:ss".
func timeFormat(t time.Time) string {
return t.Format("2006/01/02 - 15:04:05")
}
================================================
FILE: gossh/gin/render/data.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import "net/http"
// Data contains ContentType and bytes data.
// It renders the bytes as the response body without any transformation.
type Data struct {
// ContentType is the value written as the response Content-Type.
ContentType string
// Data is the raw response body.
Data []byte
}
// Render (Data) sets the custom ContentType on w and then writes the raw
// bytes as the body, returning any error from the write.
func (r Data) Render(w http.ResponseWriter) (err error) {
	r.WriteContentType(w)
	if _, err = w.Write(r.Data); err != nil {
		return err
	}
	return nil
}
// WriteContentType (Data) writes custom ContentType.
// The value of the ContentType field is passed to writeContentType.
func (r Data) WriteContentType(w http.ResponseWriter) {
writeContentType(w, []string{r.ContentType})
}
================================================
FILE: gossh/gin/render/html.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"html/template"
"net/http"
)
// Delims represents a set of Left and Right delimiters for HTML template rendering.
// HTMLDebug passes both values to template.Delims when loading templates.
type Delims struct {
// Left delimiter, defaults to {{.
Left string
// Right delimiter, defaults to }}.
Right string
}
// HTMLRender interface is to be implemented by HTMLProduction and HTMLDebug.
// It is a factory for per-response HTML renders.
type HTMLRender interface {
// Instance returns an HTML instance for the template with the given
// name, to be executed with the given data.
Instance(string, any) Render
}
// HTMLProduction contains template reference and its delims.
// The same Template is reused for every Instance call.
type HTMLProduction struct {
Template *template.Template
Delims Delims
}
// HTMLDebug contains template delims and pattern and function with file list.
// Templates are parsed again on every Instance call. Files takes precedence
// over Glob when both are set.
type HTMLDebug struct {
// Files is an explicit list of template files to parse.
Files []string
// Glob is a file pattern to parse; used only when Files is empty.
Glob string
Delims Delims
FuncMap template.FuncMap
}
// HTML contains template reference and its name with given interface object.
type HTML struct {
Template *template.Template
// Name selects the template to execute; when empty, Template itself is
// executed.
Name string
// Data is the value the template is executed with.
Data any
}
// htmlContentType is the Content-Type value written for HTML responses.
var htmlContentType = []string{"text/html; charset=utf-8"}
// Instance (HTMLProduction) builds the HTML render for the named template
// and the given data, reusing the already parsed Template.
func (r HTMLProduction) Instance(name string, data any) Render {
	instance := HTML{Template: r.Template}
	instance.Name = name
	instance.Data = data
	return instance
}
// Instance (HTMLDebug) builds the HTML render for the named template and the
// given data. The templates are parsed afresh on every call.
func (r HTMLDebug) Instance(name string, data any) Render {
	tmpl := r.loadTemplate()
	return HTML{
		Template: tmpl,
		Name:     name,
		Data:     data,
	}
}
// loadTemplate parses the configured templates with the configured
// delimiters and function map. Files is used when non-empty, otherwise Glob.
// It panics if parsing fails or if neither source is configured.
func (r HTMLDebug) loadTemplate() *template.Template {
	funcs := r.FuncMap
	if funcs == nil {
		funcs = template.FuncMap{}
	}

	switch {
	case len(r.Files) > 0:
		return template.Must(template.New("").Delims(r.Delims.Left, r.Delims.Right).Funcs(funcs).ParseFiles(r.Files...))
	case r.Glob != "":
		return template.Must(template.New("").Delims(r.Delims.Left, r.Delims.Right).Funcs(funcs).ParseGlob(r.Glob))
	}
	panic("the HTML debug render was created without files or glob pattern")
}
// Render (HTML) sets the HTML ContentType and executes the template into w.
// A non-empty Name selects the associated template of that name; otherwise
// the Template itself is executed.
func (r HTML) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)

	if r.Name != "" {
		return r.Template.ExecuteTemplate(w, r.Name, r.Data)
	}
	return r.Template.Execute(w, r.Data)
}
// WriteContentType (HTML) writes HTML ContentType.
// It passes htmlContentType to writeContentType.
func (r HTML) WriteContentType(w http.ResponseWriter) {
writeContentType(w, htmlContentType)
}
================================================
FILE: gossh/gin/render/json.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"bytes"
"fmt"
"html/template"
"net/http"
"gossh/gin/internal/bytesconv"
"gossh/gin/internal/json"
)
// JSON contains the given interface object.
// Data is marshalled to JSON when rendered.
type JSON struct {
Data any
}
// IndentedJSON contains the given interface object.
// Data is marshalled with json.MarshalIndent when rendered.
type IndentedJSON struct {
Data any
}
// SecureJSON contains the given interface object and its prefix.
type SecureJSON struct {
// Prefix is written before the body when the marshalled Data is a JSON
// array.
Prefix string
Data any
}
// JsonpJSON contains the given interface object and its callback.
type JsonpJSON struct {
// Callback is the JavaScript function name the JSON is wrapped in; when
// empty, plain JSON is written.
Callback string
Data any
}
// AsciiJSON contains the given interface object.
// When rendered, non-ASCII characters in the marshalled Data are replaced by
// \u escape sequences.
type AsciiJSON struct {
Data any
}
// PureJSON contains the given interface object.
// When rendered, Data is encoded with HTML escaping turned off.
type PureJSON struct {
Data any
}
// Content-Type values written by the JSON renders in this file.
var (
jsonContentType = []string{"application/json; charset=utf-8"}
// jsonpContentType is used by JsonpJSON.
jsonpContentType = []string{"application/javascript; charset=utf-8"}
// jsonASCIIContentType is used by AsciiJSON; it carries no charset
// parameter.
jsonASCIIContentType = []string{"application/json"}
)
// Render (JSON) writes data with custom ContentType.
// It delegates to WriteJSON, which sets the content type and marshals Data.
func (r JSON) Render(w http.ResponseWriter) error {
return WriteJSON(w, r.Data)
}
// WriteContentType (JSON) writes JSON ContentType.
// It passes jsonContentType to writeContentType.
func (r JSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
// WriteJSON sets the JSON ContentType on w, marshals obj and writes the
// result as the body. The content type is set before marshalling, so it is
// applied even when marshalling fails. The marshal or write error, if any,
// is returned.
func WriteJSON(w http.ResponseWriter, obj any) error {
	writeContentType(w, jsonContentType)

	payload, err := json.Marshal(obj)
	if err == nil {
		_, err = w.Write(payload)
	}
	return err
}
// Render (IndentedJSON) marshals the given interface object and writes it with custom ContentType.
// The content type is set before marshalling, so it is applied even when
// marshalling fails. The marshal or write error, if any, is returned.
func (r IndentedJSON) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
// No line prefix; the third argument is the per-level indent string.
jsonBytes, err := json.MarshalIndent(r.Data, "", " ")
if err != nil {
return err
}
_, err = w.Write(jsonBytes)
return err
}
// WriteContentType (IndentedJSON) writes JSON ContentType.
// It passes jsonContentType to writeContentType.
func (r IndentedJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
// Render (SecureJSON) sets the JSON ContentType, marshals Data and writes it
// to w. When the marshalled value is a JSON array (starts with "[" and ends
// with "]"), Prefix is written first. The first marshal or write error is
// returned.
func (r SecureJSON) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)

	body, err := json.Marshal(r.Data)
	if err != nil {
		return err
	}

	// Only array values get the prefix.
	isArray := bytes.HasPrefix(body, bytesconv.StringToBytes("[")) &&
		bytes.HasSuffix(body, bytesconv.StringToBytes("]"))
	if isArray {
		if _, err = w.Write(bytesconv.StringToBytes(r.Prefix)); err != nil {
			return err
		}
	}

	_, err = w.Write(body)
	return err
}
// WriteContentType (SecureJSON) writes JSON ContentType.
// It passes jsonContentType to writeContentType.
func (r SecureJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
// Render (JsonpJSON) sets the JavaScript ContentType, marshals Data and
// writes it wrapped as "callback(json);". The callback name is escaped with
// template.JSEscapeString. With an empty Callback the plain JSON is written
// instead. The first marshal or write error is returned.
func (r JsonpJSON) Render(w http.ResponseWriter) (err error) {
	r.WriteContentType(w)

	ret, err := json.Marshal(r.Data)
	if err != nil {
		return err
	}

	if r.Callback == "" {
		_, err = w.Write(ret)
		return err
	}

	// Emit the four pieces in order, stopping at the first failed write.
	callback := template.JSEscapeString(r.Callback)
	pieces := [][]byte{
		bytesconv.StringToBytes(callback),
		bytesconv.StringToBytes("("),
		ret,
		bytesconv.StringToBytes(");"),
	}
	for _, piece := range pieces {
		if _, err = w.Write(piece); err != nil {
			return err
		}
	}
	return nil
}
// WriteContentType (JsonpJSON) writes Javascript ContentType.
// It passes jsonpContentType to writeContentType.
func (r JsonpJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonpContentType)
}
// Render (AsciiJSON) sets the ASCII JSON ContentType, marshals Data and
// writes it to w with every non-ASCII character replaced by a JSON \u escape,
// so the body contains only ASCII bytes.
//
// Characters outside the Basic Multilingual Plane (above U+FFFF) are written
// as a UTF-16 surrogate pair (two \uXXXX escapes), which is the only form a
// JSON parser accepts for them. The first marshal or write error is returned.
func (r AsciiJSON) Render(w http.ResponseWriter) (err error) {
	r.WriteContentType(w)

	ret, err := json.Marshal(r.Data)
	if err != nil {
		return err
	}

	var buffer bytes.Buffer
	buffer.Grow(len(ret))
	for _, c := range bytesconv.BytesToString(ret) {
		switch {
		case c < 0x80:
			// Plain ASCII is copied through unchanged.
			buffer.WriteByte(byte(c))
		case c > 0xFFFF:
			// \u takes exactly four hex digits, so a supplementary
			// character must be split into a high and a low surrogate.
			c -= 0x10000
			buffer.WriteString(fmt.Sprintf("\\u%04x\\u%04x", 0xd800+(c>>10), 0xdc00+(c&0x3ff)))
		default:
			buffer.WriteString(fmt.Sprintf("\\u%04x", c))
		}
	}

	_, err = w.Write(buffer.Bytes())
	return err
}
// WriteContentType (AsciiJSON) writes JSON ContentType.
// It passes jsonASCIIContentType to writeContentType.
func (r AsciiJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonASCIIContentType)
}
// Render (PureJSON) writes custom ContentType and encodes the given interface object.
// The value is streamed to w through an encoder with HTML escaping turned
// off. The encoder's error, if any, is returned.
func (r PureJSON) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
return encoder.Encode(r.Data)
}
// WriteContentType (PureJSON) writes custom ContentType.
// It passes jsonContentType to writeContentType.
func (r PureJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
================================================
FILE: gossh/gin/render/reader.go
================================================
// Copyright 2018 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"io"
"net/http"
"strconv"
)
// Reader contains the IO reader and its length, and custom ContentType and other headers.
type Reader struct {
// ContentType is the value written as the response Content-Type.
ContentType string
// ContentLength, when >= 0, is sent as the Content-Length header.
ContentLength int64
// Reader supplies the response body.
Reader io.Reader
// Headers are extra response headers; each is only set when the response
// does not already carry a value for it.
Headers map[string]string
}
// Render (Reader) writes the custom ContentType and headers to w and then
// copies the reader's content into the response body.
//
// When ContentLength is not negative a Content-Length header is added
// (replacing any Content-Length entry in Headers). Headers are only applied
// where the response has no value for that header yet. The caller's Headers
// map is never modified, so one map can safely be shared between renders.
// The error from copying the body, if any, is returned.
func (r Reader) Render(w http.ResponseWriter) (err error) {
	r.WriteContentType(w)

	headers := r.Headers
	if r.ContentLength >= 0 {
		// Work on a private copy: the map in r.Headers belongs to the
		// caller and may be in use by other requests.
		headers = make(map[string]string, len(r.Headers)+1)
		for k, v := range r.Headers {
			headers[k] = v
		}
		headers["Content-Length"] = strconv.FormatInt(r.ContentLength, 10)
	}
	r.writeHeaders(w, headers)

	_, err = io.Copy(w, r.Reader)
	return err
}
// WriteContentType (Reader) writes custom ContentType.
// The value of the ContentType field is passed to writeContentType.
func (r Reader) WriteContentType(w http.ResponseWriter) {
writeContentType(w, []string{r.ContentType})
}
// writeHeaders copies the given headers onto the response. A header is only
// set when the response has no (non-empty) value for it yet, so values that
// are already present win.
func (r Reader) writeHeaders(w http.ResponseWriter, headers map[string]string) {
	dst := w.Header()
	for name, value := range headers {
		if dst.Get(name) != "" {
			continue
		}
		dst.Set(name, value)
	}
}
================================================
FILE: gossh/gin/render/redirect.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"fmt"
"net/http"
)
// Redirect describes an HTTP redirect response: the status code to send, the
// request being answered and the target location.
type Redirect struct {
	// Code is the HTTP status code used for the redirect.
	Code int
	// Request is the request being redirected; it is handed to http.Redirect.
	Request *http.Request
	// Location is the redirect target.
	Location string
}

// Render (Redirect) answers the request with a redirect to r.Location.
//
// It panics unless r.Code lies in the 3xx range 300..308 or equals
// 201 (Created). On success it always returns nil.
func (r Redirect) Render(w http.ResponseWriter) error {
	inRedirectRange := r.Code >= http.StatusMultipleChoices && r.Code <= http.StatusPermanentRedirect
	if !inRedirectRange && r.Code != http.StatusCreated {
		panic(fmt.Sprintf("Cannot redirect with status code %d", r.Code))
	}

	http.Redirect(w, r.Request, r.Location, r.Code)
	return nil
}

// WriteContentType (Redirect) is intentionally a no-op: a redirect carries no
// ContentType of its own.
func (r Redirect) WriteContentType(http.ResponseWriter) {}
================================================
FILE: gossh/gin/render/render.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import "net/http"
// Render is the interface implemented by every response renderer in this
// package (JSON, XML, HTML, YAML and so on).
type Render interface {
// Render writes data with custom ContentType.
// It returns any error met while encoding or writing the response.
Render(http.ResponseWriter) error
// WriteContentType writes custom ContentType.
WriteContentType(w http.ResponseWriter)
}
// Compile-time assertions that the renderer types satisfy Render (or
// HTMLRender for the HTML template holders). They have no runtime effect.
var (
_ Render = (*JSON)(nil)
_ Render = (*IndentedJSON)(nil)
_ Render = (*SecureJSON)(nil)
_ Render = (*JsonpJSON)(nil)
_ Render = (*XML)(nil)
_ Render = (*String)(nil)
_ Render = (*Redirect)(nil)
_ Render = (*Data)(nil)
_ Render = (*HTML)(nil)
_ HTMLRender = (*HTMLDebug)(nil)
_ HTMLRender = (*HTMLProduction)(nil)
_ Render = (*Reader)(nil)
_ Render = (*AsciiJSON)(nil)
)
// writeContentType installs value as the response's "Content-Type" header
// values, but only when the response has no Content-Type values yet; an
// existing Content-Type is left untouched.
func writeContentType(w http.ResponseWriter, value []string) {
	header := w.Header()
	if len(header["Content-Type"]) > 0 {
		return
	}
	header["Content-Type"] = value
}
================================================
FILE: gossh/gin/render/text.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"fmt"
"net/http"
"gossh/gin/internal/bytesconv"
)
// String contains the given interface object slice and its format.
type String struct {
// Format is the format string (or the literal body when Data is empty).
Format string
// Data holds the arguments applied to Format.
Data []any
}
// plainContentType is the Content-Type offered for plain-text responses.
var plainContentType = []string{"text/plain; charset=utf-8"}
// Render (String) writes data with custom ContentType by delegating to
// WriteString with r.Format and r.Data.
func (r String) Render(w http.ResponseWriter) error {
return WriteString(w, r.Format, r.Data)
}
// WriteContentType (String) passes plainContentType to writeContentType.
func (r String) WriteContentType(w http.ResponseWriter) {
writeContentType(w, plainContentType)
}
// WriteString offers the plain-text ContentType and then writes the body.
//
// With no data the format string is written verbatim (so '%' verbs are not
// interpreted); otherwise format is expanded with data via fmt.Fprintf.
// It returns the error from the underlying write, if any.
func WriteString(w http.ResponseWriter, format string, data []any) (err error) {
	writeContentType(w, plainContentType)

	if len(data) == 0 {
		_, err = w.Write(bytesconv.StringToBytes(format))
		return err
	}

	_, err = fmt.Fprintf(w, format, data...)
	return err
}
================================================
FILE: gossh/gin/render/xml.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package render
import (
"encoding/xml"
"net/http"
)
// XML contains the given interface object.
type XML struct {
// Data is the value encoded as XML by Render.
Data any
}
// xmlContentType is the Content-Type offered for XML responses.
var xmlContentType = []string{"application/xml; charset=utf-8"}
// Render (XML) offers the XML ContentType and streams r.Data to w as XML.
// It returns the encoder's error, if any.
func (r XML) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)

	enc := xml.NewEncoder(w)
	return enc.Encode(r.Data)
}
// WriteContentType (XML) passes xmlContentType to writeContentType, which
// installs it as the response Content-Type only when none is set yet.
func (r XML) WriteContentType(w http.ResponseWriter) {
writeContentType(w, xmlContentType)
}
================================================
FILE: gossh/gin/response_writer.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"bufio"
"io"
"net"
"net/http"
)
const (
// noWritten is the sentinel stored in responseWriter.size while nothing
// (not even the header) has been written.
noWritten = -1
// defaultStatus is the status a responseWriter reports after reset.
defaultStatus = http.StatusOK
)
// ResponseWriter extends http.ResponseWriter with hijacking, flushing and
// close notification, plus accessors for the status code and the number of
// body bytes written so far.
type ResponseWriter interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
// Status returns the HTTP response status code of the current request.
Status() int
// Size returns the number of bytes already written into the response http body.
// See Written()
Size() int
// WriteString writes the string into the response body.
WriteString(string) (int, error)
// Written returns true if the response body was already written.
Written() bool
// WriteHeaderNow forces to write the http header (status code + headers).
WriteHeaderNow()
// Pusher get the http.Pusher for server push
Pusher() http.Pusher
}
// responseWriter is the package's ResponseWriter implementation. It wraps an
// http.ResponseWriter and tracks the pending status code and the body size.
type responseWriter struct {
http.ResponseWriter
// size is noWritten until the header is written, then the number of body
// bytes written so far.
size int
// status is the status code that will be (or was) sent.
status int
}
// Compile-time check that *responseWriter implements ResponseWriter.
var _ ResponseWriter = (*responseWriter)(nil)
// Unwrap returns the wrapped http.ResponseWriter.
func (w *responseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
// reset rebinds w to writer and restores the pristine state: nothing written
// and the default status code.
func (w *responseWriter) reset(writer http.ResponseWriter) {
	w.size, w.status = noWritten, defaultStatus
	w.ResponseWriter = writer
}
// WriteHeader records code as the status to send. It does not touch the
// wrapped writer; the header is emitted later by WriteHeaderNow.
//
// Non-positive codes and codes equal to the current status are ignored. If
// the header has already been written the change is rejected and a debug
// warning is printed instead.
func (w *responseWriter) WriteHeader(code int) {
	if code <= 0 || w.status == code {
		return
	}
	if w.Written() {
		debugPrint("[WARNING] Headers were already written. Wanted to override status code %d with %d", w.status, code)
		return
	}
	w.status = code
}
// WriteHeaderNow sends the recorded status code through the wrapped writer
// the first time it is called; later calls are no-ops.
func (w *responseWriter) WriteHeaderNow() {
	if w.Written() {
		return
	}
	// Leaving the noWritten state marks the header as sent.
	w.size = 0
	w.ResponseWriter.WriteHeader(w.status)
}
// Write makes sure the header has been sent, writes data to the wrapped
// writer and adds the number of bytes written to the running size.
func (w *responseWriter) Write(data []byte) (n int, err error) {
	w.WriteHeaderNow()
	n, err = w.ResponseWriter.Write(data)
	w.size += n
	return n, err
}
// WriteString makes sure the header has been sent, writes s to the wrapped
// writer via io.WriteString and adds the number of bytes written to the
// running size.
func (w *responseWriter) WriteString(s string) (n int, err error) {
	w.WriteHeaderNow()
	n, err = io.WriteString(w.ResponseWriter, s)
	w.size += n
	return n, err
}
// Status returns the status code currently recorded for the response.
func (w *responseWriter) Status() int {
return w.status
}
// Size returns the number of body bytes written so far, or noWritten (-1)
// when nothing has been written yet.
func (w *responseWriter) Size() int {
return w.size
}
// Written reports whether the header (and possibly body) has been written,
// i.e. whether size has left the noWritten state.
func (w *responseWriter) Written() bool {
return w.size != noWritten
}
// Hijack implements the http.Hijacker interface.
//
// The response is marked as written before the connection is handed over.
// The wrapped writer must itself implement http.Hijacker; otherwise the type
// assertion panics.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if w.size < 0 {
		w.size = 0
	}
	hijacker := w.ResponseWriter.(http.Hijacker)
	return hijacker.Hijack()
}
// CloseNotify implements the http.CloseNotifier interface by delegating to
// the wrapped writer, which must implement http.CloseNotifier (the type
// assertion panics otherwise).
func (w *responseWriter) CloseNotify() <-chan bool {
return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush implements the http.Flusher interface.
//
// The header is sent first if it has not been yet, then the wrapped writer is
// flushed. The wrapped writer must implement http.Flusher; otherwise the type
// assertion panics.
func (w *responseWriter) Flush() {
	w.WriteHeaderNow()
	flusher := w.ResponseWriter.(http.Flusher)
	flusher.Flush()
}
// Pusher returns the wrapped writer as an http.Pusher, or nil when the
// wrapped writer does not support server push.
func (w *responseWriter) Pusher() (pusher http.Pusher) {
	// On a failed assertion p is the nil interface value, which is exactly
	// what must be returned.
	p, _ := w.ResponseWriter.(http.Pusher)
	return p
}
================================================
FILE: gossh/gin/routergroup.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"net/http"
"path"
"regexp"
"strings"
)
var (
// regEnLetter matches an HTTP method name made only of one or more
// upper-case English letters (A-Z).
regEnLetter = regexp.MustCompile("^[A-Z]+$")
// anyMethods lists the HTTP methods registered by RouterGroup.Any.
anyMethods = []string{
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
http.MethodHead, http.MethodOptions, http.MethodDelete, http.MethodConnect,
http.MethodTrace,
}
)
// IRouter defines the full router interface: every route-registration method
// of IRoutes plus the ability to create sub-groups.
type IRouter interface {
IRoutes
// Group creates a new RouterGroup for the given relative path and handlers.
Group(string, ...HandlerFunc) *RouterGroup
}
// IRoutes defines the route-registration interface. Every method returns an
// IRoutes so that registrations can be chained.
type IRoutes interface {
// Use appends middleware.
Use(...HandlerFunc) IRoutes
// Handle registers handlers for an HTTP method and path.
Handle(string, string, ...HandlerFunc) IRoutes
// Any registers handlers for all methods listed in anyMethods.
Any(string, ...HandlerFunc) IRoutes
GET(string, ...HandlerFunc) IRoutes
POST(string, ...HandlerFunc) IRoutes
DELETE(string, ...HandlerFunc) IRoutes
PATCH(string, ...HandlerFunc) IRoutes
PUT(string, ...HandlerFunc) IRoutes
OPTIONS(string, ...HandlerFunc) IRoutes
HEAD(string, ...HandlerFunc) IRoutes
// Match registers handlers for each of the given methods.
Match([]string, string, ...HandlerFunc) IRoutes
// StaticFile and StaticFileFS serve a single file; Static and StaticFS
// serve a directory tree.
StaticFile(string, string) IRoutes
StaticFileFS(string, string, http.FileSystem) IRoutes
Static(string, string) IRoutes
StaticFS(string, http.FileSystem) IRoutes
}
// RouterGroup is used internally to configure router, a RouterGroup is associated with
// a prefix and an array of handlers (middleware).
type RouterGroup struct {
// Handlers is the middleware chain prepended to every route of the group.
Handlers HandlersChain
// basePath is the path prefix joined in front of every relative path.
basePath string
// engine is the Engine on which routes are registered.
engine *Engine
// root marks the engine's own group; when set, chaining methods return
// the engine instead of the group (see returnObj).
root bool
}
// Compile-time check that *RouterGroup implements IRouter.
var _ IRouter = (*RouterGroup)(nil)
// Use appends middleware to the group's handler chain and returns the group
// (or the engine for the root group) so calls can be chained.
// See the example code in GitHub.
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
// Group creates a new router group. You should add all the routes that have
// common middlewares or the same path prefix to it. For example, all the
// routes that use a common middleware for authorization could be grouped.
//
// The child inherits the parent's handlers followed by the given ones, uses
// the parent's base path joined with relativePath, and shares the engine.
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
	child := new(RouterGroup)
	child.Handlers = group.combineHandlers(handlers)
	child.basePath = group.calculateAbsolutePath(relativePath)
	child.engine = group.engine
	return child
}
// BasePath returns the base path of the router group.
// For example, if v := router.Group("/rest/n/v1/api"), v.BasePath() is "/rest/n/v1/api".
func (group *RouterGroup) BasePath() string {
return group.basePath
}
// handle registers a route on the engine: the relative path is resolved
// against the group's base path and the group's handlers are put in front of
// the given ones. It performs no validation of httpMethod.
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers HandlersChain) IRoutes {
	fullPath := group.calculateAbsolutePath(relativePath)
	chain := group.combineHandlers(handlers)
	group.engine.addRoute(httpMethod, fullPath, chain)
	return group.returnObj()
}
// Handle registers a new request handle and middleware with the given path
// and method. The last handler should be the real handler; the others should
// be middleware that can and should be shared among different routes.
// See the example code in GitHub.
//
// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut
// functions can be used.
//
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
//
// It panics when httpMethod is not made solely of upper-case English letters.
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
	if !regEnLetter.MatchString(httpMethod) {
		panic("http method " + httpMethod + " is not valid")
	}
	return group.handle(httpMethod, relativePath, handlers)
}
// POST is a shortcut for router.Handle("POST", path, handlers).
// Unlike Handle it calls the internal handle directly, so the method name is
// not re-validated.
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPost, relativePath, handlers)
}
// GET is a shortcut for router.Handle("GET", path, handlers).
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodGet, relativePath, handlers)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handlers).
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodDelete, relativePath, handlers)
}
// PATCH is a shortcut for router.Handle("PATCH", path, handlers).
func (group *RouterGroup) PATCH(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPatch, relativePath, handlers)
}
// PUT is a shortcut for router.Handle("PUT", path, handlers).
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPut, relativePath, handlers)
}
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handlers).
func (group *RouterGroup) OPTIONS(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodOptions, relativePath, handlers)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handlers).
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodHead, relativePath, handlers)
}
// Any registers the same handlers for relativePath under every HTTP method in
// anyMethods: GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE.
func (group *RouterGroup) Any(relativePath string, handlers ...HandlerFunc) IRoutes {
	for i := range anyMethods {
		group.handle(anyMethods[i], relativePath, handlers)
	}
	return group.returnObj()
}
// Match registers the same handlers for relativePath under each of the given
// HTTP methods. The method names are passed through without validation.
func (group *RouterGroup) Match(methods []string, relativePath string, handlers ...HandlerFunc) IRoutes {
	for i := range methods {
		group.handle(methods[i], relativePath, handlers)
	}
	return group.returnObj()
}
// StaticFile registers GET and HEAD routes that serve a single file of the
// local filesystem.
//
//	router.StaticFile("favicon.ico", "./resources/favicon.ico")
func (group *RouterGroup) StaticFile(relativePath, filepath string) IRoutes {
	serve := func(c *Context) {
		c.File(filepath)
	}
	return group.staticFileHandler(relativePath, serve)
}
// StaticFileFS works just like `StaticFile` but a custom `http.FileSystem`
// can be used instead.
//
//	router.StaticFileFS("favicon.ico", "./resources/favicon.ico", Dir{".", false})
//
// Gin by default uses: gin.Dir()
func (group *RouterGroup) StaticFileFS(relativePath, filepath string, fs http.FileSystem) IRoutes {
	serve := func(c *Context) {
		c.FileFromFS(filepath, fs)
	}
	return group.staticFileHandler(relativePath, serve)
}
// staticFileHandler registers handler for GET and HEAD on relativePath.
// It panics when relativePath contains a ':' or '*' URL parameter marker.
func (group *RouterGroup) staticFileHandler(relativePath string, handler HandlerFunc) IRoutes {
	if strings.ContainsAny(relativePath, ":*") {
		panic("URL parameters can not be used when serving a static file")
	}
	group.GET(relativePath, handler)
	group.HEAD(relativePath, handler)
	return group.returnObj()
}
// Static serves files from the given file system root.
// Internally a http.FileServer is used, therefore http.NotFound is used instead
// of the Router's NotFound handler.
// To use the operating system's file system implementation,
// use :
//
// router.Static("/static", "/var/www")
//
// It is StaticFS with the file system Dir(root, false).
func (group *RouterGroup) Static(relativePath, root string) IRoutes {
return group.StaticFS(relativePath, Dir(root, false))
}
// StaticFS works just like `Static()` but a custom `http.FileSystem` can be
// used instead. Gin by default uses: gin.Dir()
//
// GET and HEAD routes are registered on relativePath + "/*filepath".
// It panics when relativePath contains a ':' or '*' URL parameter marker.
func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) IRoutes {
	if strings.ContainsAny(relativePath, ":*") {
		panic("URL parameters can not be used when serving a static folder")
	}

	serve := group.createStaticHandler(relativePath, fs)
	pattern := path.Join(relativePath, "/*filepath")

	// The same handler answers both GET and HEAD.
	group.GET(pattern, serve)
	group.HEAD(pattern, serve)
	return group.returnObj()
}
// createStaticHandler builds the handler used by StaticFS. The handler serves
// the "filepath" route parameter from fs through an http.FileServer whose
// URL prefix (the group's absolute path for relativePath) is stripped.
//
// When the requested file cannot be opened, the status is set to 404 and the
// context is switched to the engine's noRoute handlers with its index reset.
func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileSystem) HandlerFunc {
	prefix := group.calculateAbsolutePath(relativePath)
	server := http.StripPrefix(prefix, http.FileServer(fs))
	// fs never changes, so the listing check can be done once up front.
	_, noListing := fs.(*onlyFilesFS)

	return func(c *Context) {
		if noListing {
			c.Writer.WriteHeader(http.StatusNotFound)
		}

		name := c.Param("filepath")
		// Check if file exists and/or if we have permission to access it
		opened, err := fs.Open(name)
		if err != nil {
			c.Writer.WriteHeader(http.StatusNotFound)
			c.handlers = group.engine.noRoute
			// Reset index
			c.index = -1
			return
		}
		opened.Close()

		server.ServeHTTP(c.Writer, c.Request)
	}
}
// combineHandlers returns a fresh chain holding the group's handlers followed
// by the given handlers. It asserts that the combined length stays below
// abortIndex.
func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
	total := len(group.Handlers) + len(handlers)
	assert1(total < int(abortIndex), "too many handlers")

	merged := make(HandlersChain, 0, total)
	merged = append(merged, group.Handlers...)
	merged = append(merged, handlers...)
	return merged
}
// calculateAbsolutePath joins the group's base path with relativePath using
// joinPaths.
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
return joinPaths(group.basePath, relativePath)
}
// returnObj yields the value that chaining methods hand back: the group
// itself, or the engine when this is the root group.
func (group *RouterGroup) returnObj() IRoutes {
	if !group.root {
		return group
	}
	return group.engine
}
================================================
FILE: gossh/gin/sessions/context/context.go
================================================
package context
import (
"net/http"
"sync"
"time"
)
var (
	// mutex guards data and datat.
	mutex sync.RWMutex
	// data holds the values stored for each request.
	data = make(map[*http.Request]map[any]any)
	// datat records, per request, the Unix time (seconds) at which its first
	// value was stored; Purge uses it to find stale entries.
	datat = make(map[*http.Request]int64)
)

// Set stores a value for a given key in a given request.
//
// The lock is released with defer so that a panic inside the critical section
// (for instance a key of an unhashable type) cannot leave the package-wide
// mutex locked.
func Set(r *http.Request, key, val any) {
	mutex.Lock()
	defer mutex.Unlock()

	if data[r] == nil {
		data[r] = make(map[any]any)
		datat[r] = time.Now().Unix()
	}
	data[r][key] = val
}

// Get returns a value stored for a given key in a given request, or nil when
// the request or the key is unknown.
func Get(r *http.Request, key any) any {
	mutex.RLock()
	defer mutex.RUnlock()

	if ctx := data[r]; ctx != nil {
		return ctx[key]
	}
	return nil
}

// GetOk returns stored value and presence state like multi-value return of map access.
func GetOk(r *http.Request, key any) (any, bool) {
	mutex.RLock()
	defer mutex.RUnlock()

	if ctx, ok := data[r]; ok {
		value, ok := ctx[key]
		return value, ok
	}
	return nil, false
}

// GetAll returns a copy of all stored values for the request as a map. Nil is
// returned for requests that have no stored values.
func GetAll(r *http.Request) map[any]any {
	mutex.RLock()
	defer mutex.RUnlock()

	context, ok := data[r]
	if !ok {
		return nil
	}
	result := make(map[any]any, len(context))
	for k, v := range context {
		result[k] = v
	}
	return result
}

// GetAllOk returns a copy of all stored values for the request as a map and a
// boolean value that indicates if the request was registered. For an
// unregistered request the map is empty but non-nil.
func GetAllOk(r *http.Request) (map[any]any, bool) {
	mutex.RLock()
	defer mutex.RUnlock()

	context, ok := data[r]
	result := make(map[any]any, len(context))
	for k, v := range context {
		result[k] = v
	}
	return result, ok
}

// Delete removes a value stored for a given key in a given request.
func Delete(r *http.Request, key any) {
	mutex.Lock()
	defer mutex.Unlock()

	if data[r] != nil {
		delete(data[r], key)
	}
}

// Clear removes all values stored for a given request.
//
// This is usually called by a handler wrapper to clean up request
// variables at the end of a request lifetime. See ClearHandler().
func Clear(r *http.Request) {
	mutex.Lock()
	defer mutex.Unlock()
	clear(r)
}

// clear is Clear without the lock; the caller must hold mutex for writing.
func clear(r *http.Request) {
	delete(data, r)
	delete(datat, r)
}

// Purge removes request data stored for longer than maxAge, in seconds.
// It returns the amount of requests removed.
//
// If maxAge <= 0, all request data is removed.
//
// This is only used for sanity check: in case context cleaning was not
// properly set some request data can be kept forever, consuming an increasing
// amount of memory. In case this is detected, Purge() must be called
// periodically until the problem is fixed.
func Purge(maxAge int) int {
	mutex.Lock()
	defer mutex.Unlock()

	if maxAge <= 0 {
		count := len(data)
		data = make(map[*http.Request]map[any]any)
		datat = make(map[*http.Request]int64)
		return count
	}

	count := 0
	// Entries first stored before cutoff are considered stale.
	cutoff := time.Now().Unix() - int64(maxAge)
	for r := range data {
		if datat[r] < cutoff {
			clear(r)
			count++
		}
	}
	return count
}

// ClearHandler wraps an http.Handler and clears request values at the end
// of a request lifetime.
func ClearHandler(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer Clear(r)
		h.ServeHTTP(w, r)
	})
}
================================================
FILE: gossh/gin/sessions/cookie/cookie.go
================================================
package cookie
import (
"gossh/gin/sessions"
gsessions "gossh/gin/sessions/sessions"
)
// Store is the cookie-backed session store; it exposes exactly the
// sessions.Store interface.
type Store interface {
sessions.Store
}
// NewStore returns a Store wrapping a CookieStore built from keyPairs.
//
// Keys are defined in pairs to allow key rotation, but the common case is to set a single
// authentication key and optionally an encryption key.
//
// The first key in a pair is used for authentication and the second for encryption. The
// encryption key can be set to nil or omitted in the last pair, but the authentication key
// is required in all pairs.
//
// It is recommended to use an authentication key with 32 or 64 bytes. The encryption key,
// if set, must be either 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256 modes.
func NewStore(keyPairs ...[]byte) Store {
return &store{gsessions.NewCookieStore(keyPairs...)}
}
// store adapts *gsessions.CookieStore to the Store interface.
type store struct {
*gsessions.CookieStore
}
// Options replaces the embedded CookieStore's options with the result of
// options.ToGorillaOptions().
func (c *store) Options(options sessions.Options) {
c.CookieStore.Options = options.ToGorillaOptions()
}
================================================
FILE: gossh/gin/sessions/memstore/cache.go
================================================
package memstore
import (
"sync"
)
// cache is a mutex-protected map from a name to a session value map.
type cache struct {
// data holds the cached values keyed by name.
data map[string]valueType
// mutex guards data.
mutex sync.RWMutex
}
// newCache returns an empty, ready-to-use cache.
func newCache() *cache {
return &cache{
data: make(map[string]valueType),
}
}
// value returns the entry stored under name and whether it exists.
// It takes the read lock for the duration of the lookup.
func (c *cache) value(name string) (valueType, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
v, ok := c.data[name]
return v, ok
}
// setValue stores value under name, replacing any previous entry.
// It takes the write lock for the duration of the update.
func (c *cache) setValue(name string, value valueType) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.data[name] = value
}
// delete removes the entry stored under name, if there is one.
// It takes the write lock for the duration of the removal.
func (c *cache) delete(name string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	// The built-in delete is a no-op for absent keys, so no existence check
	// is needed first.
	delete(c.data, name)
}
================================================
FILE: gossh/gin/sessions/memstore/memstore.go
================================================
package memstore
import (
"bytes"
"encoding/base32"
"encoding/gob"
"fmt"
"net/http"
"strings"
"gossh/gin/sessions/securecookie"
"gossh/gin/sessions/sessions"
)
// MemStore is an in-memory implementation of gorilla/sessions, suitable
// for use in tests and development environments. Do not use in production.
// Values are cached in a map. The cache is protected and can be used by
// multiple goroutines.
type MemStore struct {
// Codecs encode and decode the session ID carried in the cookie.
Codecs []securecookie.Codec
// Options are the default options copied into every new session.
Options *sessions.Options
// cache holds the session values keyed by session ID.
cache *cache
}
// valueType is the shape of a session's values as kept in the cache.
type valueType map[any]any
// NewMemStore returns a new MemStore with an empty cache, path "/" and a
// max age of 30 days (86400 * 30 seconds).
//
// Keys are defined in pairs to allow key rotation, but the common case is
// to set a single authentication key and optionally an encryption key.
//
// The first key in a pair is used for authentication and the second for
// encryption. The encryption key can be set to nil or omitted in the last
// pair, but the authentication key is required in all pairs.
//
// It is recommended to use an authentication key with 32 or 64 bytes.
// The encryption key, if set, must be either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256 modes.
//
// Use the convenience function securecookie.GenerateRandomKey() to create
// strong keys.
func NewMemStore(keyPairs ...[]byte) *MemStore {
	m := &MemStore{
		Codecs: securecookie.CodecsFromPairs(keyPairs...),
		Options: &sessions.Options{
			Path:   "/",
			MaxAge: 86400 * 30,
		},
		cache: newCache(),
	}
	// Propagate the default max age to the codecs as well.
	m.MaxAge(m.Options.MaxAge)
	return m
}
// Get returns a session for the given name after adding it to the registry.
//
// It returns a new session if the session doesn't exist. Access IsNew on
// the session to check if it is an existing session or a new one.
//
// It returns a new session and an error if the session exists but could
// not be decoded.
func (m *MemStore) Get(r *http.Request, name string) (*sessions.Session, error) {
return sessions.GetRegistry(r).Get(m, name)
}
// New returns a session for the given name without adding it to the registry.
//
// The difference between New() and Get() is that calling New() twice will
// decode the session data twice, while Get() registers and reuses the same
// decoded session after the first call.
//
// The session is reported as new (IsNew == true) unless the request carries
// a cookie whose decoded ID has values in the cache. A cookie that fails to
// decode yields a new session together with the decode error.
func (m *MemStore) New(r *http.Request, name string) (*sessions.Session, error) {
	session := sessions.NewSession(m, name)
	// Give the session its own copy of the store defaults.
	opts := *m.Options
	session.Options = &opts
	session.IsNew = true

	cookie, err := r.Cookie(name)
	if err != nil {
		// Cookie not found, this is a new session
		return session, nil
	}

	if err = securecookie.DecodeMulti(name, cookie.Value, &session.ID, m.Codecs...); err != nil {
		// Value could not be decrypted, consider this is a new session
		return session, err
	}

	cached, found := m.cache.value(session.ID)
	if !found {
		// No value found in cache, don't set any values in session object,
		// consider a new session
		return session, nil
	}

	// Values found in session, this is not a new session
	session.Values = m.copy(cached)
	session.IsNew = false
	return session, nil
}
// Save adds a single session to the response.
// Set Options.MaxAge to -1 or call MaxAge(-1) before saving the session to
// delete all values in it.
//
// For a negative MaxAge the cached entry and the session's values are removed
// and a cookie with an empty value is set. Otherwise an ID is generated when
// missing, the ID is encoded into the cookie and a copy of the values is
// cached. It returns the encoding error, if any.
func (m *MemStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error {
	if s.Options.MaxAge < 0 {
		m.cache.delete(s.ID)
		for k := range s.Values {
			delete(s.Values, k)
		}
		http.SetCookie(w, sessions.NewCookie(s.Name(), "", s.Options))
		return nil
	}

	if s.ID == "" {
		randomID := base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
		s.ID = strings.TrimRight(randomID, "=")
	}

	encoded, err := securecookie.EncodeMulti(s.Name(), s.ID, m.Codecs...)
	if err != nil {
		return err
	}
	m.cache.setValue(s.ID, m.copy(s.Values))

	http.SetCookie(w, sessions.NewCookie(s.Name(), encoded, s.Options))
	return nil
}
// MaxAge sets the maximum age for the store and the underlying cookie
// implementation. Individual sessions can be deleted by setting Options.MaxAge
// = -1 for that session.
func (m *MemStore) MaxAge(age int) {
	m.Options.MaxAge = age

	// Set the maxAge for each securecookie instance; codecs of any other
	// type are left alone.
	for i := range m.Codecs {
		sc, ok := m.Codecs[i].(*securecookie.SecureCookie)
		if !ok {
			continue
		}
		sc.MaxAge(age)
	}
}
// copy returns a copy of v produced by a gob encode/decode round trip.
// It panics when either the encoding or the decoding step fails.
func (m *MemStore) copy(v valueType) valueType {
	var buf bytes.Buffer

	if err := gob.NewEncoder(&buf).Encode(v); err != nil {
		panic(fmt.Errorf("could not copy memstore value. Encoding to gob failed: %v", err))
	}

	var cloned valueType
	if err := gob.NewDecoder(&buf).Decode(&cloned); err != nil {
		panic(fmt.Errorf("could not copy memstore value. Decoding from gob failed: %v", err))
	}
	return cloned
}
================================================
FILE: gossh/gin/sessions/memstore/store.go
================================================
package memstore
import (
"gossh/gin/sessions"
)
// Store is the in-memory session store; it exposes exactly the
// sessions.Store interface.
type Store interface {
sessions.Store
}
// NewStore returns a Store wrapping a MemStore built from keyPairs.
//
// Keys are defined in pairs to allow key rotation, but the common case is to set a single
// authentication key and optionally an encryption key.
//
// The first key in a pair is used for authentication and the second for encryption. The
// encryption key can be set to nil or omitted in the last pair, but the authentication key
// is required in all pairs.
//
// It is recommended to use an authentication key with 32 or 64 bytes. The encryption key,
// if set, must be either 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256 modes.
func NewStore(keyPairs ...[]byte) Store {
return &store{NewMemStore(keyPairs...)}
}
// store adapts *MemStore to the Store interface.
type store struct {
*MemStore
}
// Options replaces the embedded MemStore's options with the result of
// options.ToGorillaOptions().
func (c *store) Options(options sessions.Options) {
c.MemStore.Options = options.ToGorillaOptions()
}
================================================
FILE: gossh/gin/sessions/securecookie/securecookie.go
================================================
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package securecookie
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"encoding/json"
"fmt"
"hash"
"io"
"strconv"
"strings"
"time"
)
// Error is the interface of all errors returned by functions in this library.
// The Is* methods classify the error; more than one may report true for the
// same value.
type Error interface {
error
// IsUsage returns true for errors indicating the client code probably
// uses this library incorrectly. For example, the client may have
// failed to provide a valid hash key, or may have failed to configure
// the Serializer adequately for encoding value.
IsUsage() bool
// IsDecode returns true for errors indicating that a cookie could not
// be decoded and validated. Since cookies are usually untrusted
// user-provided input, errors of this type should be expected.
// Usually, the proper action is simply to reject the request.
IsDecode() bool
// IsInternal returns true for unexpected errors occurring in the
// securecookie implementation.
IsInternal() bool
// Cause, if it returns a non-nil value, indicates that this error was
// propagated from some underlying library. If this method returns nil,
// this error was raised directly by this library.
//
// Cause is provided principally for debugging/logging purposes; it is
// rare that application logic should perform meaningfully different
// logic based on Cause. See, for example, the caveats described on
// (MultiError).Cause().
Cause() error
}
// errorType is a bitmask giving the error type(s) of a cookieError value.
type errorType int

// Individual error-type bits; they may be OR-ed together.
const (
	usageError = errorType(1 << iota)
	decodeError
	internalError
)

// cookieError is the concrete error value used by this package.
type cookieError struct {
	// typ holds the error-type bits.
	typ errorType
	// msg is the human-readable message; empty means a generic "error".
	msg string
	// cause is the underlying error, if any.
	cause error
}

// IsUsage reports whether the usageError bit is set.
func (e cookieError) IsUsage() bool { return e.typ&usageError != 0 }

// IsDecode reports whether the decodeError bit is set.
func (e cookieError) IsDecode() bool { return e.typ&decodeError != 0 }

// IsInternal reports whether the internalError bit is set.
func (e cookieError) IsInternal() bool { return e.typ&internalError != 0 }

// Cause returns the underlying error, or nil when there is none.
func (e cookieError) Cause() error { return e.cause }

// Error formats the error as "securecookie: <msg>", using "error" when msg is
// empty and appending " - caused by: <cause>" when a cause is present.
func (e cookieError) Error() string {
	var b strings.Builder
	b.WriteString("securecookie: ")

	switch e.msg {
	case "":
		b.WriteString("error")
	default:
		b.WriteString(e.msg)
	}

	if cause := e.Cause(); cause != nil {
		b.WriteString(" - caused by: ")
		b.WriteString(cause.Error())
	}
	return b.String()
}
// Predefined error values of this package, each tagged with the error type
// (usage, decode or internal) that its Is* methods report.
var (
errGeneratingIV = cookieError{typ: internalError, msg: "failed to generate random iv"}
errNoCodecs = cookieError{typ: usageError, msg: "no codecs provided"}
errHashKeyNotSet = cookieError{typ: usageError, msg: "hash key is not set"}
errBlockKeyNotSet = cookieError{typ: usageError, msg: "block key is not set"}
errEncodedValueTooLong = cookieError{typ: usageError, msg: "the value is too long"}
errValueToDecodeTooLong = cookieError{typ: decodeError, msg: "the value is too long"}
errTimestampInvalid = cookieError{typ: decodeError, msg: "invalid timestamp"}
errTimestampTooNew = cookieError{typ: decodeError, msg: "timestamp is too new"}
errTimestampExpired = cookieError{typ: decodeError, msg: "expired timestamp"}
errDecryptionFailed = cookieError{typ: decodeError, msg: "the value could not be decrypted"}
errValueNotByte = cookieError{typ: decodeError, msg: "value not a []byte."}
errValueNotBytePtr = cookieError{typ: decodeError, msg: "value not a pointer to []byte."}
// ErrMacInvalid indicates that cookie decoding failed because the HMAC
// could not be extracted and verified. Direct use of this error
// variable is deprecated; it is public only for legacy compatibility,
// and may be privatized in the future, as it is rarely useful to
// distinguish between this error and other Error implementations.
ErrMacInvalid = cookieError{typ: decodeError, msg: "the value is not valid"}
)
// Codec defines an interface to encode and decode cookie values.
type Codec interface {
// Encode turns value into a cookie string for the cookie called name.
Encode(name string, value any) (string, error)
// Decode reads the cookie string value for the cookie called name into dst.
Decode(name, value string, dst any) error
}
// New returns a new SecureCookie.
//
// hashKey is required, used to authenticate values using HMAC. Create it using
// GenerateRandomKey(). It is recommended to use a key with 32 or 64 bytes.
// An empty hashKey is recorded on the returned value as errHashKeyNotSet.
//
// blockKey is optional, used to encrypt values. Create it using
// GenerateRandomKey(). The key length must correspond to the key size
// of the encryption algorithm. For AES, used by default, valid lengths are
// 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
// The default encoder used for cookie serialization is encoding/gob.
//
// Note that keys created using GenerateRandomKey() are not automatically
// persisted. New keys will be created when the application is restarted, and
// previously issued cookies will not be able to be decoded.
func New(hashKey, blockKey []byte) *SecureCookie {
	// Decide up front whether the hash key is usable.
	var keyErr error
	if len(hashKey) == 0 {
		keyErr = errHashKeyNotSet
	}

	sc := &SecureCookie{
		hashKey:   hashKey,
		blockKey:  blockKey,
		hashFunc:  sha256.New,
		maxAge:    86400 * 30,
		maxLength: 4096,
		sz:        GobEncoder{},
		err:       keyErr,
	}

	// Encryption is enabled only when a block key was supplied.
	if blockKey != nil {
		sc.BlockFunc(aes.NewCipher)
	}
	return sc
}
// SecureCookie encodes and decodes authenticated and optionally encrypted
// cookie values.
type SecureCookie struct {
// hashKey is the HMAC key; Encode and Decode refuse to run when it is nil.
hashKey []byte
// hashFunc constructs the hash used for the HMAC (sha256.New after New).
hashFunc func() hash.Hash
// blockKey is the key handed to the cipher constructor in BlockFunc.
blockKey []byte
// block is the cipher used to encrypt/decrypt; nil means no encryption.
block cipher.Block
// maxLength caps the encoded value length; 0 disables the check.
maxLength int
// maxAge is the maximum accepted cookie age in seconds; 0 disables the check.
maxAge int64
// minAge is the minimum accepted cookie age in seconds; 0 disables the check.
minAge int64
// err holds a configuration error; Encode and Decode return it immediately.
err error
// sz serializes values before encoding and deserializes them after decoding.
sz Serializer
// For testing purposes, the function that returns the current timestamp.
// If not set, it will use time.Now().UTC().Unix().
timeFunc func() int64
}
// Serializer is the interface for plugging custom serializers for cookie
// values into a SecureCookie (see SetSerializer).
type Serializer interface {
// Serialize converts src into the bytes that will be encoded.
Serialize(src any) ([]byte, error)
// Deserialize reads the bytes in src back into dst.
Deserialize(src []byte, dst any) error
}
// GobEncoder encodes cookie values using encoding/gob. This is the simplest
// encoder and can handle complex types via gob.Register.
type GobEncoder struct{}
// JSONEncoder encodes cookie values using encoding/json. Users who wish to
// encode complex types need to satisfy the json.Marshaler and
// json.Unmarshaler interfaces.
type JSONEncoder struct{}
// NopEncoder does not encode cookie values, and instead simply accepts a []byte
// (passed as any) and returns a []byte. This is particularly useful when
// you are encoding an object upstream and do not wish to re-encode it.
type NopEncoder struct{}
// MaxLength restricts the maximum length, in bytes, for the cookie value.
//
// Default is 4096, which is the maximum value accepted by Internet Explorer.
// Set it to 0 to disable the length check in Encode and Decode.
// It returns s so calls can be chained.
func (s *SecureCookie) MaxLength(value int) *SecureCookie {
s.maxLength = value
return s
}
// MaxAge restricts the maximum age, in seconds, for the cookie value.
//
// Default is 86400 * 30. Set it to 0 for no restriction.
// Decode rejects values whose timestamp is older than this many seconds.
// It returns s so calls can be chained.
func (s *SecureCookie) MaxAge(value int) *SecureCookie {
s.maxAge = int64(value)
return s
}
// MinAge restricts the minimum age, in seconds, for the cookie value.
//
// Default is 0 (no restriction).
// Decode rejects values whose timestamp is newer than this many seconds ago.
// It returns s so calls can be chained.
func (s *SecureCookie) MinAge(value int) *SecureCookie {
s.minAge = int64(value)
return s
}
// HashFunc sets the hash function used to create HMAC.
//
// Default is crypto/sha256.New.
// It returns s so calls can be chained.
func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie {
s.hashFunc = f
return s
}
// BlockFunc installs the cipher used for encryption by calling f with the
// configured block key.
//
// New uses crypto/aes.NewCipher when a block key is supplied.
//
// If no block key was configured, or f fails, the problem is stored on s and
// returned by later Encode/Decode calls. It returns s so calls can be chained.
func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie {
	if s.blockKey == nil {
		s.err = errBlockKeyNotSet
		return s
	}
	block, err := f(s.blockKey)
	if err != nil {
		s.err = cookieError{cause: err, typ: usageError}
		return s
	}
	s.block = block
	return s
}
// SetSerializer sets the encoding/serialization method for cookies.
//
// Default is encoding/gob. To encode special structures using encoding/gob,
// they must be registered first using gob.Register().
// It returns s so calls can be chained.
func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
s.sz = sz
return s
}
// Encode encodes a cookie value.
//
// It serializes, optionally encrypts, signs with a message authentication code,
// and finally encodes the value.
//
// The name argument is the cookie name. It is covered by the MAC but is not
// part of the returned string. The value argument is the value to be encoded.
// It can be any value that can be encoded using the currently selected
// serializer; see SetSerializer().
//
// It is the client's responsibility to ensure that value, when encoded using
// the current serialization/encryption settings on s and then base64-encoded,
// is shorter than the maximum permissible length. When it is not, the returned
// error wraps the package's "too long" sentinel, so errors.Is and errors.As
// can identify it.
//
// A configuration error recorded on s (for example a missing hash key) is
// returned before any work is done.
func (s *SecureCookie) Encode(name string, value any) (string, error) {
	if s.err != nil {
		return "", s.err
	}
	if s.hashKey == nil {
		s.err = errHashKeyNotSet
		return "", s.err
	}
	var err error
	var b []byte
	// 1. Serialize.
	if b, err = s.sz.Serialize(value); err != nil {
		return "", cookieError{cause: err, typ: usageError}
	}
	// 2. Encrypt (optional).
	if s.block != nil {
		if b, err = encrypt(s.block, b); err != nil {
			return "", cookieError{cause: err, typ: usageError}
		}
	}
	b = encode(b)
	// 3. Create MAC for "name|date|value". Extra pipe to be used later.
	b = []byte(fmt.Sprintf("%s|%d|%s|", name, s.timestamp(), b))
	// The trailing pipe is excluded from the MAC input.
	mac := createMac(hmac.New(s.hashFunc, s.hashKey), b[:len(b)-1])
	// Append mac, remove name.
	b = append(b, mac...)[len(name)+1:]
	// 4. Encode to base64.
	b = encode(b)
	// 5. Check length.
	if s.maxLength != 0 && len(b) > s.maxLength {
		// %w keeps the sentinel reachable; the message text is unchanged.
		return "", fmt.Errorf("%w: %d", errEncodedValueTooLong, len(b))
	}
	// Done.
	return string(b), nil
}
// Decode decodes a cookie value.
//
// It decodes, verifies a message authentication code, optionally decrypts and
// finally deserializes the value.
//
// The name argument is the cookie name. It must be the same name used when
// it was stored. The value argument is the encoded cookie value. The dst
// argument is where the cookie will be decoded. It must be a pointer.
//
// A configuration error recorded on s (for example a missing hash key) is
// returned before any work is done. When value exceeds the configured maximum
// length, the returned error wraps the package's "too long" sentinel, so
// errors.Is and errors.As can identify it.
func (s *SecureCookie) Decode(name, value string, dst any) error {
	if s.err != nil {
		return s.err
	}
	if s.hashKey == nil {
		s.err = errHashKeyNotSet
		return s.err
	}
	// 1. Check length.
	if s.maxLength != 0 && len(value) > s.maxLength {
		// %w keeps the sentinel reachable; the message text is unchanged.
		return fmt.Errorf("%w: %d", errValueToDecodeTooLong, len(value))
	}
	// 2. Decode from base64.
	b, err := decode([]byte(value))
	if err != nil {
		return err
	}
	// 3. Verify MAC. Value is "date|value|mac".
	parts := bytes.SplitN(b, []byte("|"), 3)
	if len(parts) != 3 {
		return ErrMacInvalid
	}
	h := hmac.New(s.hashFunc, s.hashKey)
	// Rebuild the signed message "name|date|value" (without the mac and the
	// pipe in front of it).
	b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...)
	if err = verifyMac(h, b, parts[2]); err != nil {
		return err
	}
	// 4. Verify date ranges.
	var t1 int64
	if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
		return errTimestampInvalid
	}
	t2 := s.timestamp()
	if s.minAge != 0 && t1 > t2-s.minAge {
		return errTimestampTooNew
	}
	if s.maxAge != 0 && t1 < t2-s.maxAge {
		return errTimestampExpired
	}
	// 5. Decrypt (optional).
	b, err = decode(parts[1])
	if err != nil {
		return err
	}
	if s.block != nil {
		if b, err = decrypt(s.block, b); err != nil {
			return err
		}
	}
	// 6. Deserialize.
	if err = s.sz.Deserialize(b, dst); err != nil {
		return cookieError{cause: err, typ: decodeError}
	}
	// Done.
	return nil
}
// timestamp reports the current time as Unix seconds.
//
// Tests can override the clock through timeFunc; when it is unset the value
// comes from time.Now().UTC().Unix().
func (s *SecureCookie) timestamp() int64 {
	if clock := s.timeFunc; clock != nil {
		return clock()
	}
	return time.Now().UTC().Unix()
}
// Authentication -------------------------------------------------------------
// createMac feeds value into h and returns the resulting digest, which is the
// message authentication code (MAC) when h is an HMAC.
func createMac(h hash.Hash, value []byte) []byte {
	// The result of Write is discarded, exactly as before.
	_, _ = h.Write(value)
	digest := h.Sum(nil)
	return digest
}
// verifyMac recomputes the MAC of value with h and compares it to mac,
// returning ErrMacInvalid when they differ and nil when they match.
func verifyMac(h hash.Hash, value []byte, mac []byte) error {
	expected := createMac(h, value)
	// The explicit length comparison is kept because
	// subtle.ConstantTimeCompare did not check lengths prior to Go 1.4.
	if len(mac) != len(expected) || subtle.ConstantTimeCompare(mac, expected) != 1 {
		return ErrMacInvalid
	}
	return nil
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector ( https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Initialization_vector_(IV) ) with the length of the
// block size is prepended to the resulting ciphertext.
//
// value is left untouched: the ciphertext is written to a new buffer. This
// matters because a serializer may hand back the caller's own slice (the
// NopEncoder does), and encrypting in place would corrupt it.
//
// It returns errGeneratingIV when no random IV could be produced.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
	iv := GenerateRandomKey(block.BlockSize())
	if iv == nil {
		return nil, errGeneratingIV
	}
	// Lay out iv + ciphertext in one freshly allocated buffer.
	out := make([]byte, len(iv)+len(value))
	copy(out, iv)
	// Encrypt it; source and destination do not overlap.
	stream := cipher.NewCTR(block, iv)
	stream.XORKeyStream(out[len(iv):], value)
	return out, nil
}
// decrypt reverses encrypt using the given block in counter mode.
//
// value must start with an initialization vector
// ( https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Initialization_vector_(IV) ) of one block size, followed by at least one byte of
// ciphertext; otherwise errDecryptionFailed is returned. Decryption happens
// in place, so the returned slice shares storage with value.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
	ivLen := block.BlockSize()
	if len(value) <= ivLen {
		return nil, errDecryptionFailed
	}
	iv, ciphertext := value[:ivLen], value[ivLen:]
	cipher.NewCTR(block, iv).XORKeyStream(ciphertext, ciphertext)
	return ciphertext, nil
}
// Serialization --------------------------------------------------------------
// Serialize gob-encodes src and returns the bytes. Encoding failures are
// reported as a usage error wrapping the gob error.
func (e GobEncoder) Serialize(src any) ([]byte, error) {
	var out bytes.Buffer
	if err := gob.NewEncoder(&out).Encode(src); err != nil {
		return nil, cookieError{cause: err, typ: usageError}
	}
	return out.Bytes(), nil
}
// Deserialize gob-decodes src into dst. Decoding failures are reported as a
// decode error wrapping the gob error.
func (e GobEncoder) Deserialize(src []byte, dst any) error {
	err := gob.NewDecoder(bytes.NewBuffer(src)).Decode(dst)
	if err == nil {
		return nil
	}
	return cookieError{cause: err, typ: decodeError}
}
// Serialize writes src through a json.Encoder and returns the bytes it
// produced. Encoding failures are reported as a usage error wrapping the
// JSON error.
func (e JSONEncoder) Serialize(src any) ([]byte, error) {
	var out bytes.Buffer
	if err := json.NewEncoder(&out).Encode(src); err != nil {
		return nil, cookieError{cause: err, typ: usageError}
	}
	return out.Bytes(), nil
}
// Deserialize reads one JSON value from src into dst through a json.Decoder.
// Decoding failures are reported as a decode error wrapping the JSON error.
func (e JSONEncoder) Deserialize(src []byte, dst any) error {
	err := json.NewDecoder(bytes.NewReader(src)).Decode(dst)
	if err == nil {
		return nil
	}
	return cookieError{cause: err, typ: decodeError}
}
// Serialize returns src unchanged when it is a []byte (the same slice, not a
// copy) and errValueNotByte for any other type.
func (e NopEncoder) Serialize(src any) ([]byte, error) {
	raw, isBytes := src.([]byte)
	if !isBytes {
		return nil, errValueNotByte
	}
	return raw, nil
}
// Deserialize stores src (the same slice, not a copy) into dst when dst is a
// *[]byte and returns errValueNotBytePtr for any other type.
func (e NopEncoder) Deserialize(src []byte, dst any) error {
	target, isBytesPtr := dst.(*[]byte)
	if !isBytesPtr {
		return errValueNotBytePtr
	}
	*target = src
	return nil
}
// Encoding -------------------------------------------------------------------
// encode returns the padded URL-safe base64 encoding of value.
func encode(value []byte) []byte {
	enc := base64.URLEncoding
	out := make([]byte, enc.EncodedLen(len(value)))
	enc.Encode(out, value)
	return out
}
// decode reverses encode: it parses value as padded URL-safe base64 and
// returns the decoded bytes, or a decode error wrapping the base64 failure.
func decode(value []byte) ([]byte, error) {
	enc := base64.URLEncoding
	buf := make([]byte, enc.DecodedLen(len(value)))
	n, err := enc.Decode(buf, value)
	if err != nil {
		return nil, cookieError{cause: err, typ: decodeError, msg: "base64 decode failed"}
	}
	// DecodedLen is an upper bound; keep only what was actually written.
	return buf[:n], nil
}
// Helpers --------------------------------------------------------------------
// GenerateRandomKey returns length bytes read from crypto/rand, or nil when
// the read fails.
//
// Keys created with `GenerateRandomKey()` are not persisted automatically. If
// new keys are generated on every start, cookies issued before a restart can
// no longer be decoded.
//
// Callers should check for a nil return explicitly, treat it as a failure of
// the system random number generator, and not continue.
func GenerateRandomKey(length int) []byte {
	key := make([]byte, length)
	_, err := io.ReadFull(rand.Reader, key)
	if err != nil {
		return nil
	}
	return key
}
// CodecsFromPairs builds one SecureCookie per (hash key, block key) pair and
// returns them as Codecs.
//
// It is a convenience for key rotation. A trailing hash key without a block
// key yields a codec without encryption. Every codec gets the default
// options; to change them, iterate over the result and type-assert each
// element to *SecureCookie.
//
// Example:
//
//	codecs := securecookie.CodecsFromPairs(
//	     []byte("new-hash-key"),
//	     []byte("new-block-key"),
//	     []byte("old-hash-key"),
//	     []byte("old-block-key"),
//	 )
//
//	// Modify each instance.
//	for _, s := range codecs {
//	       if cookie, ok := s.(*securecookie.SecureCookie); ok {
//	           cookie.MaxAge(86400 * 7)
//	           cookie.SetSerializer(securecookie.JSONEncoder{})
//	           cookie.HashFunc(sha512.New512_256)
//	       }
//	   }
func CodecsFromPairs(keyPairs ...[]byte) []Codec {
	total := len(keyPairs)
	// One codec per pair, rounding up for an unpaired trailing hash key.
	codecs := make([]Codec, 0, (total+1)/2)
	for i := 0; i < total; i += 2 {
		var blockKey []byte
		if next := i + 1; next < total {
			blockKey = keyPairs[next]
		}
		codecs = append(codecs, New(keyPairs[i], blockKey))
	}
	return codecs
}
// EncodeMulti encodes a cookie value with the first codec that succeeds.
//
// Codecs are tried in the order given; accepting several of them allows key
// rotation.
//
// With no codecs it returns errNoCodecs. When every codec fails it returns a
// MultiError holding one error per codec.
func EncodeMulti(name string, value any, codecs ...Codec) (string, error) {
	if len(codecs) == 0 {
		return "", errNoCodecs
	}
	var errs MultiError
	for _, c := range codecs {
		encoded, err := c.Encode(name, value)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		return encoded, nil
	}
	return "", errs
}
// DecodeMulti decodes a cookie value with the first codec that succeeds.
//
// Codecs are tried in the order given; accepting several of them allows key
// rotation.
//
// With no codecs it returns errNoCodecs. When every codec fails it returns a
// MultiError holding one error per codec.
func DecodeMulti(name string, value string, dst any, codecs ...Codec) error {
	if len(codecs) == 0 {
		return errNoCodecs
	}
	var errs MultiError
	for _, c := range codecs {
		err := c.Decode(name, value, dst)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		return nil
	}
	return errs
}
// MultiError groups multiple errors.
type MultiError []error
// IsUsage reports whether any contained Error reports IsUsage.
func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) }
// IsDecode reports whether any contained Error reports IsDecode.
func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) }
// IsInternal reports whether any contained Error reports IsInternal.
func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) }
// Cause returns nil for MultiError; there is no unique underlying cause in the
// general case.
//
// Note: we could conceivably return a non-nil Cause only when there is exactly
// one child error with a Cause. However, it would be brittle for client code
// to rely on the arity of causes inside a MultiError, so we have opted not to
// provide this functionality. Clients which really wish to access the Causes
// of the underlying errors are free to iterate through the errors themselves.
func (m MultiError) Cause() error { return nil }
// Error returns the message of the first non-nil error followed by a count of
// the remaining non-nil errors, or "(0 errors)" when there are none.
func (m MultiError) Error() string {
	first, count := "", 0
	for _, e := range m {
		if e == nil {
			continue
		}
		if count == 0 {
			first = e.Error()
		}
		count++
	}
	switch {
	case count == 0:
		return "(0 errors)"
	case count == 1:
		return first
	case count == 2:
		return first + " (and 1 other error)"
	default:
		return fmt.Sprintf("%s (and %d other errors)", first, count-1)
	}
}
// any reports whether pred holds for at least one element of m that
// implements Error; elements of other types are skipped.
func (m MultiError) any(pred func(Error) bool) bool {
	for _, e := range m {
		ourErr, ok := e.(Error)
		if !ok {
			continue
		}
		if pred(ourErr) {
			return true
		}
	}
	return false
}
================================================
FILE: gossh/gin/sessions/session_options_go1.10.go
================================================
//go:build !go1.11
// +build !go1.11
package sessions
import (
gsessions "gossh/gin/sessions/sessions"
)
// Options stores configuration for a session or session store.
// Fields are a subset of http.Cookie fields.
//
// This variant is built for Go versions before 1.11 and therefore has no
// SameSite field.
type Options struct {
Path string
Domain string
// MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
}
// ToGorillaOptions copies every field into a newly allocated
// gsessions.Options and returns a pointer to it.
func (options Options) ToGorillaOptions() *gsessions.Options {
	converted := gsessions.Options{
		Path:     options.Path,
		Domain:   options.Domain,
		MaxAge:   options.MaxAge,
		Secure:   options.Secure,
		HttpOnly: options.HttpOnly,
	}
	return &converted
}
================================================
FILE: gossh/gin/sessions/session_options_go1.11.go
================================================
//go:build go1.11
// +build go1.11
package sessions
import (
"net/http"
gsessions "gossh/gin/sessions/sessions"
)
// Options stores configuration for a session or session store.
// Fields are a subset of http.Cookie fields.
//
// This variant is built for Go 1.11 and later and includes SameSite.
type Options struct {
Path string
Domain string
// MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// SameSite is the cookie's SameSite attribute, intended to help prevent CSRF.
// RFC draft: https://tools.ietf.org/html/draft-west-first-party-cookies-07
// See also: https://godoc.org/net/http
// Background: https://www.sjoerdlangkemper.nl/2016/04/14/preventing-csrf-with-samesite-cookie-attribute/
SameSite http.SameSite
}
// ToGorillaOptions copies every field, including SameSite, into a newly
// allocated gsessions.Options and returns a pointer to it.
func (options Options) ToGorillaOptions() *gsessions.Options {
	converted := gsessions.Options{
		Path:     options.Path,
		Domain:   options.Domain,
		MaxAge:   options.MaxAge,
		Secure:   options.Secure,
		HttpOnly: options.HttpOnly,
		SameSite: options.SameSite,
	}
	return &converted
}
================================================
FILE: gossh/gin/sessions/sessions/cookie.go
================================================
//go:build !go1.11
// +build !go1.11
package sessions
import "net/http"
// newCookieFromOptions builds an http.Cookie named name holding value, with
// Path, Domain, MaxAge, Secure and HttpOnly taken from options.
// options must not be nil.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
	cookie := new(http.Cookie)
	cookie.Name, cookie.Value = name, value
	cookie.Path = options.Path
	cookie.Domain = options.Domain
	cookie.MaxAge = options.MaxAge
	cookie.Secure = options.Secure
	cookie.HttpOnly = options.HttpOnly
	return cookie
}
================================================
FILE: gossh/gin/sessions/sessions/cookie_go111.go
================================================
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions builds an http.Cookie named name holding value, with
// Path, Domain, MaxAge, Secure, HttpOnly and SameSite taken from options.
// options must not be nil.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
	cookie := new(http.Cookie)
	cookie.Name, cookie.Value = name, value
	cookie.Path = options.Path
	cookie.Domain = options.Domain
	cookie.MaxAge = options.MaxAge
	cookie.Secure = options.Secure
	cookie.HttpOnly = options.HttpOnly
	cookie.SameSite = options.SameSite
	return cookie
}
================================================
FILE: gossh/gin/sessions/sessions/doc.go
================================================
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package sessions provides cookie and filesystem sessions and
infrastructure for custom session backends.
The key features are:
- Simple API: use it as an easy way to set signed (and optionally
encrypted) cookies.
- Built-in backends to store sessions in cookies or the filesystem.
- Flash messages: session values that last until read.
- Convenient way to switch session persistency (aka "remember me") and set
other attributes.
- Mechanism to rotate authentication and encryption keys.
- Multiple sessions per request, even using different backends.
- Interfaces and infrastructure for custom session backends: sessions from
different stores can be retrieved and batch-saved using a common API.
Let's start with an example that shows the sessions API in a nutshell:
import (
"net/http"
"os"
"github.com/gorilla/sessions"
)
// Note: Don't store your key in your source code. Pass it via an
// environmental variable, or flag (or both), and don't accidentally commit it
// alongside your code. Ensure your key is sufficiently random - i.e. use Go's
// crypto/rand or securecookie.GenerateRandomKey(32) and persist the result.
// Ensure SESSION_KEY exists in the environment, or sessions will fail.
var store = sessions.NewCookieStore([]byte(os.Getenv("SESSION_KEY")))
func MyHandler(w http.ResponseWriter, r *http.Request) {
// Get a session. Get() always returns a session, even if empty.
session, err := store.Get(r, "session-name")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Set some session values.
session.Values["foo"] = "bar"
session.Values[42] = 43
// Save it before we write to the response/return from the handler.
err = session.Save(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
First we initialize a session store calling NewCookieStore() and passing a
secret key used to authenticate the session. Inside the handler, we call
store.Get() to retrieve an existing session or a new one. Then we set some
session values in session.Values, which is a map[interface{}]interface{}.
And finally we call session.Save() to save the session in the response.
Note that in production code, we should check for errors when calling
session.Save(r, w), and either display an error message or otherwise handle it.
Save must be called before writing to the response, otherwise the session
cookie will not be sent to the client.
That's all you need to know for the basic usage. Let's take a look at other
options, starting with flash messages.
Flash messages are session values that last until read. The term appeared with
Ruby On Rails a few years back. When we request a flash message, it is removed
from the session. To add a flash, call session.AddFlash(), and to get all
flashes, call session.Flashes(). Here is an example:
func MyHandler(w http.ResponseWriter, r *http.Request) {
// Get a session.
session, err := store.Get(r, "session-name")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Get the previous flashes, if any.
if flashes := session.Flashes(); len(flashes) > 0 {
// Use the flash values.
} else {
// Set a new flash.
session.AddFlash("Hello, flash messages world!")
}
err = session.Save(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
Flash messages are useful to set information to be read after a redirection,
like after form submissions.
There may also be cases where you want to store a complex datatype within a
session, such as a struct. Sessions are serialised using the encoding/gob package,
so it is easy to register new datatypes for storage in sessions:
import(
"encoding/gob"
"github.com/gorilla/sessions"
)
type Person struct {
FirstName string
LastName string
Email string
Age int
}
type M map[string]interface{}
func init() {
gob.Register(&Person{})
gob.Register(&M{})
}
As it's not possible to pass a raw type as a parameter to a function, gob.Register()
relies on us passing it a value of the desired type. In the example above we've passed
it a pointer to a struct and a pointer to a custom type representing a
map[string]interface. (We could have passed non-pointer values if we wished.) This will
then allow us to serialise/deserialise values of those types to and from our sessions.
Note that because session values are stored in a map[interface{}]interface{}, there's
a need to type-assert data when retrieving it. We'll use the Person struct we registered above:
func MyHandler(w http.ResponseWriter, r *http.Request) {
session, err := store.Get(r, "session-name")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Retrieve our struct and type-assert it
val := session.Values["person"]
var person = &Person{}
if person, ok := val.(*Person); !ok {
// Handle the case that it's not an expected type
}
// Now we can use our person object
}
By default, session cookies last for a month. This is probably too long for
some cases, but it is easy to change this and other attributes during
runtime. Sessions can be configured individually or the store can be
configured and then all sessions saved using it will use that configuration.
We access session.Options or store.Options to set a new configuration. The
fields are basically a subset of http.Cookie fields. Let's change the
maximum age of a session to one week:
session.Options = &sessions.Options{
Path: "/",
MaxAge: 86400 * 7,
HttpOnly: true,
}
Sometimes we may want to change authentication and/or encryption keys without
breaking existing sessions. The CookieStore supports key rotation, and to use
it you just need to set multiple authentication and encryption keys, in pairs,
to be tested in order:
var store = sessions.NewCookieStore(
[]byte("new-authentication-key"),
[]byte("new-encryption-key"),
[]byte("old-authentication-key"),
[]byte("old-encryption-key"),
)
New sessions will be saved using the first pair. Old sessions can still be
read because the first pair will fail, and the second will be tested. This
makes it easy to "rotate" secret keys and still be able to validate existing
sessions. Note: for all pairs the encryption key is optional; set it to nil
or omit it and encryption won't be used.
Multiple sessions can be used in the same request, even with different
session backends. When this happens, calling Save() on each session
individually would be cumbersome, so we have a way to save all sessions
at once: it's sessions.Save(). Here's an example:
var store = sessions.NewCookieStore([]byte("something-very-secret"))
func MyHandler(w http.ResponseWriter, r *http.Request) {
// Get a session and set a value.
session1, _ := store.Get(r, "session-one")
session1.Values["foo"] = "bar"
// Get another session and set another value.
session2, _ := store.Get(r, "session-two")
session2.Values[42] = 43
// Save all sessions.
err := sessions.Save(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
This is possible because when we call Get() from a session store, it adds the
session to a common registry. Save() uses it to save all registered sessions.
*/
package sessions
================================================
FILE: gossh/gin/sessions/sessions/lex.go
================================================
// This file contains code adapted from the Go standard library
// https://github.com/golang/go/blob/39ad0fd0789872f9469167be7fe9578625ff246e/src/net/http/lex.go
package sessions
import "strings"
// isTokenTable marks, by ASCII code, the characters accepted by isToken.
// Entries not listed default to false.
var isTokenTable = [127]bool{
	'!':  true,
	'#':  true,
	'$':  true,
	'%':  true,
	'&':  true,
	'\'': true,
	'*':  true,
	'+':  true,
	'-':  true,
	'.':  true,
	'0':  true,
	'1':  true,
	'2':  true,
	'3':  true,
	'4':  true,
	'5':  true,
	'6':  true,
	'7':  true,
	'8':  true,
	'9':  true,
	'A':  true,
	'B':  true,
	'C':  true,
	'D':  true,
	'E':  true,
	'F':  true,
	'G':  true,
	'H':  true,
	'I':  true,
	'J':  true,
	'K':  true,
	'L':  true,
	'M':  true,
	'N':  true,
	'O':  true,
	'P':  true,
	'Q':  true,
	'R':  true,
	'S':  true,
	'T':  true,
	'U':  true,
	'V':  true,
	'W':  true,
	'X':  true,
	'Y':  true,
	'Z':  true,
	'^':  true,
	'_':  true,
	'`':  true,
	'a':  true,
	'b':  true,
	'c':  true,
	'd':  true,
	'e':  true,
	'f':  true,
	'g':  true,
	'h':  true,
	'i':  true,
	'j':  true,
	'k':  true,
	'l':  true,
	'm':  true,
	'n':  true,
	'o':  true,
	'p':  true,
	'q':  true,
	'r':  true,
	's':  true,
	't':  true,
	'u':  true,
	'v':  true,
	'w':  true,
	'x':  true,
	'y':  true,
	'z':  true,
	'|':  true,
	'~':  true,
}

// isToken reports whether r is one of the characters marked in isTokenTable.
// Runes outside the table's range, including negative values, are rejected
// rather than used as an index.
func isToken(r rune) bool {
	return r >= 0 && int(r) < len(isTokenTable) && isTokenTable[r]
}

// isNotToken is the negation of isToken, shaped for strings.IndexFunc.
func isNotToken(r rune) bool {
	return !isToken(r)
}

// isCookieNameValid reports whether raw is non-empty and consists solely of
// characters accepted by isToken.
func isCookieNameValid(raw string) bool {
	if raw == "" {
		return false
	}
	return strings.IndexFunc(raw, isNotToken) < 0
}
================================================
FILE: gossh/gin/sessions/sessions/options.go
================================================
//go:build !go1.11
// +build !go1.11
package sessions
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields; newCookieFromOptions copies each
// one to the http.Cookie field of the same name.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
}
================================================
FILE: gossh/gin/sessions/sessions/options_go111.go
================================================
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields; newCookieFromOptions copies each
// one to the http.Cookie field of the same name.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
================================================
FILE: gossh/gin/sessions/sessions/sessions.go
================================================
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
import (
"context"
"encoding/gob"
"fmt"
"net/http"
"time"
)
// flashesKey is the default flashes key: the Session.Values key that Flashes
// and AddFlash use when the caller does not supply one.
const flashesKey = "_flash"
// Session --------------------------------------------------------------------
// NewSession is called by session stores to create a new session instance.
// The result has an empty Values map, zero-valued Options and remembers the
// given store and name.
func NewSession(store Store, name string) *Session {
	session := new(Session)
	session.Values = make(map[any]any)
	session.Options = new(Options)
	session.store = store
	session.name = name
	return session
}
// Session stores the values and optional configuration for a session.
type Session struct {
// The ID of the session, generated by stores. It should not be used for
// user data.
ID string
// Values contains the user-data for the session.
Values map[any]any
// Options holds the cookie options for this session; NewSession starts it
// out as a zero-valued Options.
Options *Options
// IsNew is not set anywhere in this file; presumably stores set it when
// they create a fresh session - confirm against the store implementation.
IsNew bool
// store is the Store this session is saved through (see Save).
store Store
// name is the name the session was registered under (see Name).
name string
}
// Flashes returns the flash messages stored in the session and removes them
// from it. It returns nil when nothing is stored under the key.
//
// One optional argument selects the flash key; without it "_flash" is used.
// The stored value must be a []any, otherwise the type assertion panics.
func (s *Session) Flashes(vars ...string) []any {
	key := flashesKey
	if len(vars) > 0 {
		key = vars[0]
	}
	stored, found := s.Values[key]
	if !found {
		return nil
	}
	// Drop the flashes from the session before handing them out.
	delete(s.Values, key)
	return stored.([]any)
}
// AddFlash appends value to the flash messages stored in the session.
//
// One optional argument selects the flash key; without it "_flash" is used.
// A value already stored under the key must be a []any, otherwise the type
// assertion panics.
func (s *Session) AddFlash(value any, vars ...string) {
	key := flashesKey
	if len(vars) > 0 {
		key = vars[0]
	}
	var existing []any
	if stored, found := s.Values[key]; found {
		existing = stored.([]any)
	}
	s.Values[key] = append(existing, value)
}
// Save is a convenience method to save this session. It is the same as calling
// store.Save(request, response, session). You should call Save before writing to
// the response or returning from the handler.
//
// It returns whatever error the session's store reports.
func (s *Session) Save(r *http.Request, w http.ResponseWriter) error {
return s.store.Save(r, w, s)
}
// Name returns the name used to register the session, as given to NewSession
// or assigned by Registry.Get.
func (s *Session) Name() string {
return s.name
}
// Store returns the session store used to register the session, as given to
// NewSession or assigned by Registry.Get.
func (s *Session) Store() Store {
return s.store
}
// Registry -------------------------------------------------------------------
// sessionInfo stores a session tracked by the registry, together with the
// error the store returned when the session was created.
type sessionInfo struct {
// s is the cached session.
s *Session
// e is the error returned by store.New for this session, replayed on later
// Registry.Get calls.
e error
}
// contextKey is the type used to store the registry in the context. Being an
// unexported type, it cannot collide with context keys from other packages.
type contextKey int
// registryKey is the key used to store the registry in the context.
const registryKey contextKey = 0
// GetRegistry returns the Registry attached to the request's context,
// creating and attaching a new one when the request has none yet.
func GetRegistry(r *http.Request) *Registry {
	ctx := r.Context()
	if existing := ctx.Value(registryKey); existing != nil {
		return existing.(*Registry)
	}
	created := &Registry{
		request:  r,
		sessions: make(map[string]sessionInfo),
	}
	// Overwrite the request in place so the caller's *http.Request carries
	// the context that holds the new registry.
	*r = *r.WithContext(context.WithValue(ctx, registryKey, created))
	return created
}
// Registry stores sessions used during a request.
type Registry struct {
// request is the request the registry was created for; it is passed to the
// stores when sessions are created and saved.
request *http.Request
// sessions caches, by session name, the result of the first store.New call.
sessions map[string]sessionInfo
}
// Get registers and returns a session for the given name and session store.
//
// It returns a new session if there are no sessions registered for the name.
// The session and the error produced by store.New are cached, so later calls
// with the same name return the same pair. The returned session's store is
// always set to the store passed to this call.
//
// An error is returned, with a nil session, when name is not a valid cookie
// name or when the store fails to provide a session at all. The latter case
// is not cached.
func (s *Registry) Get(store Store, name string) (session *Session, err error) {
	if !isCookieNameValid(name) {
		return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name)
	}
	if info, ok := s.sessions[name]; ok {
		session, err = info.s, info.e
	} else {
		session, err = store.New(s.request, name)
		if session == nil {
			// Guard against stores that return no session: report an error
			// instead of dereferencing nil below.
			if err == nil {
				err = fmt.Errorf("sessions: store returned no session for %q", name)
			}
			return nil, err
		}
		session.name = name
		s.sessions[name] = sessionInfo{s: session, e: err}
	}
	session.store = store
	return session, err
}
// Save saves every session registered for the current request through its own
// store. Failures do not stop the loop; they are collected and returned
// together as a MultiError. It returns nil when everything was saved.
func (s *Registry) Save(w http.ResponseWriter) error {
	var failures MultiError
	for name, info := range s.sessions {
		sess := info.s
		if sess.store == nil {
			failures = append(failures, fmt.Errorf(
				"sessions: missing store for session %q", name))
			continue
		}
		if err := sess.store.Save(s.request, w, sess); err != nil {
			failures = append(failures, fmt.Errorf(
				"sessions: error saving session %q -- %v", name, err))
		}
	}
	// Return an untyped nil on success so callers' err != nil checks work.
	if len(failures) == 0 {
		return nil
	}
	return failures
}
// Helpers --------------------------------------------------------------------
// init registers the []any type with encoding/gob.
// NOTE(review): presumably needed so slices stored in session values (e.g.
// flash messages) can be gob-encoded — confirm against the codec in use.
func init() {
gob.Register([]any{})
}
// Save saves all sessions used during the current request by delegating to
// the Save method of the request's registry (see GetRegistry).
func Save(r *http.Request, w http.ResponseWriter) error {
return GetRegistry(r).Save(w)
}
// NewCookie builds an http.Cookie from name, value and options.
//
// In addition to what newCookieFromOptions sets, the Expires field is
// derived from options.MaxAge for Internet Explorer compatibility: a
// positive MaxAge yields now+MaxAge seconds, a negative MaxAge yields a
// timestamp in the past, and zero leaves Expires untouched.
func NewCookie(name, value string, options *Options) *http.Cookie {
	cookie := newCookieFromOptions(name, value, options)
	switch age := options.MaxAge; {
	case age > 0:
		lifetime := time.Duration(age) * time.Second
		cookie.Expires = time.Now().Add(lifetime)
	case age < 0:
		// Set it to the past to expire now.
		cookie.Expires = time.Unix(1, 0)
	}
	return cookie
}
// Error ----------------------------------------------------------------------
// MultiError stores multiple errors.
//
// Borrowed from the App Engine SDK.
type MultiError []error

// Error reports the message of the first non-nil error, followed by a count
// of how many further non-nil errors the slice holds. With no non-nil
// errors it returns "(0 errors)".
func (m MultiError) Error() string {
	first := ""
	count := 0
	for _, e := range m {
		if e == nil {
			continue
		}
		if count == 0 {
			first = e.Error()
		}
		count++
	}

	switch {
	case count == 0:
		return "(0 errors)"
	case count == 1:
		return first
	case count == 2:
		return first + " (and 1 other error)"
	default:
		return fmt.Sprintf("%s (and %d other errors)", first, count-1)
	}
}
================================================
FILE: gossh/gin/sessions/sessions/store.go
================================================
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
import (
"encoding/base32"
"net/http"
"os"
"path/filepath"
"sync"
"gossh/gin/sessions/securecookie"
)
// Store is an interface for custom session stores.
//
// See CookieStore and FilesystemStore for examples.
type Store interface {
// Get should return a cached session.
Get(r *http.Request, name string) (*Session, error)
// New should create and return a new session.
//
// Note that New should never return a nil session, even in the case of
// an error, if using the Registry infrastructure to cache the session:
// Registry.Get writes to the returned session without a nil check.
New(r *http.Request, name string) (*Session, error)
// Save should persist session to the underlying store implementation.
Save(r *http.Request, w http.ResponseWriter, s *Session) error
}
// CookieStore ----------------------------------------------------------------
// NewCookieStore returns a new CookieStore.
//
// Keys are defined in pairs to allow key rotation, but the common case is
// to set a single authentication key and optionally an encryption key.
//
// The first key in a pair is used for authentication and the second for
// encryption. The encryption key can be set to nil or omitted in the last
// pair, but the authentication key is required in all pairs.
//
// It is recommended to use an authentication key with 32 or 64 bytes.
// The encryption key, if set, must be either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256 modes.
//
// The store starts with Path "/" and a MaxAge of 30 days, and that MaxAge
// is also pushed to the codecs via MaxAge.
func NewCookieStore(keyPairs ...[]byte) *CookieStore {
	defaults := &Options{
		Path:   "/",
		MaxAge: 86400 * 30,
	}
	store := &CookieStore{
		Codecs:  securecookie.CodecsFromPairs(keyPairs...),
		Options: defaults,
	}
	store.MaxAge(defaults.MaxAge)
	return store
}
// CookieStore stores sessions using secure cookies: the session values are
// encoded with Codecs and written into the cookie itself.
type CookieStore struct {
Codecs []securecookie.Codec // codecs used to encode and decode the cookie value
Options *Options // default configuration, copied into each new session
}
// Get returns a session for the given name after adding it to the registry
// of the request.
//
// It returns a new session if the session doesn't exist. Access IsNew on
// the session to check if it is an existing session or a new one.
//
// It returns a new session and an error if the session exists but could
// not be decoded.
func (s *CookieStore) Get(r *http.Request, name string) (*Session, error) {
return GetRegistry(r).Get(s, name)
}
// New returns a session for the given name without adding it to the registry.
//
// The difference between New() and Get() is that calling New() twice will
// decode the session data twice, while Get() registers and reuses the same
// decoded session after the first call.
//
// The session gets its own copy of the store's Options. It is flagged IsNew
// unless the request carries a cookie of that name that decodes cleanly; a
// decode failure is returned together with the (still new) session.
func (s *CookieStore) New(r *http.Request, name string) (*Session, error) {
	session := NewSession(s, name)
	optsCopy := *s.Options
	session.Options = &optsCopy
	session.IsNew = true

	cookie, cookieErr := r.Cookie(name)
	if cookieErr != nil {
		// No cookie on the request: a brand new session, not an error.
		return session, nil
	}
	err := securecookie.DecodeMulti(name, cookie.Value, &session.Values,
		s.Codecs...)
	if err == nil {
		session.IsNew = false
	}
	return session, err
}
// Save adds a single session to the response by encoding its values with
// the store's codecs and setting the result as a cookie built from the
// session's options. An encoding failure is returned and no cookie is set.
func (s *CookieStore) Save(r *http.Request, w http.ResponseWriter,
	session *Session) error {
	name := session.Name()
	encoded, err := securecookie.EncodeMulti(name, session.Values, s.Codecs...)
	if err != nil {
		return err
	}
	cookie := NewCookie(name, encoded, session.Options)
	http.SetCookie(w, cookie)
	return nil
}
// MaxAge sets the maximum age for the store and the underlying cookie
// implementation. Individual sessions can be deleted by setting Options.MaxAge
// = -1 for that session.
//
// Only codecs that are *securecookie.SecureCookie are updated; other codec
// implementations are left untouched.
func (s *CookieStore) MaxAge(age int) {
	s.Options.MaxAge = age

	for i := range s.Codecs {
		sc, isSecureCookie := s.Codecs[i].(*securecookie.SecureCookie)
		if !isSecureCookie {
			continue
		}
		sc.MaxAge(age)
	}
}
// FilesystemStore ------------------------------------------------------------
// fileMutex is a single package-level lock shared by every FilesystemStore;
// the save, load and erase helpers acquire it around their file operations.
var fileMutex sync.RWMutex
// NewFilesystemStore returns a new FilesystemStore.
//
// The path argument is the directory where sessions will be saved. If empty
// it will use os.TempDir().
//
// See NewCookieStore() for a description of the other parameters.
func NewFilesystemStore(path string, keyPairs ...[]byte) *FilesystemStore {
	dir := path
	if dir == "" {
		dir = os.TempDir()
	}
	defaults := &Options{
		Path:   "/",
		MaxAge: 86400 * 30,
	}
	store := &FilesystemStore{
		Codecs:  securecookie.CodecsFromPairs(keyPairs...),
		Options: defaults,
		path:    dir,
	}
	store.MaxAge(defaults.MaxAge)
	return store
}
// FilesystemStore stores sessions in the filesystem: session values live in
// a file named "session_<ID>" under path, and only the ID is put in the cookie.
//
// It also serves as a reference for custom stores.
//
// This store is still experimental and not well tested. Feedback is welcome.
type FilesystemStore struct {
Codecs []securecookie.Codec // codecs used for both the cookie and the file contents
Options *Options // default configuration, copied into each new session
path string // directory holding the session files
}
// MaxLength restricts the maximum length of new sessions to l.
// If l is 0 there is no limit to the size of a session, use with caution.
// The default for a new FilesystemStore is 4096.
//
// Only codecs that are *securecookie.SecureCookie are affected.
func (s *FilesystemStore) MaxLength(l int) {
	for i := range s.Codecs {
		sc, isSecureCookie := s.Codecs[i].(*securecookie.SecureCookie)
		if !isSecureCookie {
			continue
		}
		sc.MaxLength(l)
	}
}
// Get returns a session for the given name after adding it to the registry
// of the request.
//
// See CookieStore.Get().
func (s *FilesystemStore) Get(r *http.Request, name string) (*Session, error) {
return GetRegistry(r).Get(s, name)
}
// New returns a session for the given name without adding it to the registry.
//
// See CookieStore.New().
//
// The cookie only carries the session ID; once it decodes, the values are
// loaded from the matching session file. The session stays flagged IsNew
// unless both steps succeed, and the first failure is returned with it.
func (s *FilesystemStore) New(r *http.Request, name string) (*Session, error) {
	session := NewSession(s, name)
	optsCopy := *s.Options
	session.Options = &optsCopy
	session.IsNew = true

	cookie, cookieErr := r.Cookie(name)
	if cookieErr != nil {
		// No cookie on the request: a brand new session, not an error.
		return session, nil
	}
	if err := securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...); err != nil {
		return session, err
	}
	if err := s.load(session); err != nil {
		return session, err
	}
	session.IsNew = false
	return session, nil
}
// base32RawStdEncoding is the standard base32 encoding without padding
// characters; FilesystemStore.Save uses it to turn random bytes into a
// session ID.
var base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// Save adds a single session to the response.
//
// If the Options.MaxAge of the session is <= 0 then the session file will be
// deleted from the store path and an empty cookie is set, so expiry does not
// depend on the cookie management of the web browser.
//
// Otherwise a session without an ID is assigned a random one, the values are
// written to the session file, and a cookie carrying the encoded ID is set.
func (s *FilesystemStore) Save(r *http.Request, w http.ResponseWriter,
	session *Session) error {
	// Delete if max-age is <= 0
	if session.Options.MaxAge <= 0 {
		err := s.erase(session)
		if err != nil && !os.IsNotExist(err) {
			return err
		}
		http.SetCookie(w, NewCookie(session.Name(), "", session.Options))
		return nil
	}

	if session.ID == "" {
		// Because the ID is used in the filename, encode it to
		// use alphanumeric characters only.
		randomKey := securecookie.GenerateRandomKey(32)
		session.ID = base32RawStdEncoding.EncodeToString(randomKey)
	}
	if err := s.save(session); err != nil {
		return err
	}

	encodedID, err := securecookie.EncodeMulti(session.Name(), session.ID,
		s.Codecs...)
	if err != nil {
		return err
	}
	http.SetCookie(w, NewCookie(session.Name(), encodedID, session.Options))
	return nil
}
// MaxAge sets the maximum age for the store and the underlying cookie
// implementation. Individual sessions can be deleted by setting Options.MaxAge
// = -1 for that session.
//
// Only codecs that are *securecookie.SecureCookie are updated.
func (s *FilesystemStore) MaxAge(age int) {
	s.Options.MaxAge = age

	for i := range s.Codecs {
		sc, isSecureCookie := s.Codecs[i].(*securecookie.SecureCookie)
		if !isSecureCookie {
			continue
		}
		sc.MaxAge(age)
	}
}
// save encodes session.Values with the store's codecs and writes the result
// to the session's file ("session_<ID>" under the store path) with mode
// 0600, holding the package-wide write lock while writing.
func (s *FilesystemStore) save(session *Session) error {
	encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
		s.Codecs...)
	if err != nil {
		return err
	}

	target := filepath.Join(s.path, "session_"+session.ID)
	payload := []byte(encoded)

	fileMutex.Lock()
	defer fileMutex.Unlock()
	return os.WriteFile(target, payload, 0600)
}
// load reads the session's file ("session_<ID>" under the store path) while
// holding the package-wide read lock and decodes its content into
// session.Values. Read and decode errors are returned unchanged.
func (s *FilesystemStore) load(session *Session) error {
	source := filepath.Join(s.path, "session_"+session.ID)

	fileMutex.RLock()
	defer fileMutex.RUnlock()

	raw, err := os.ReadFile(filepath.Clean(source))
	if err != nil {
		return err
	}
	return securecookie.DecodeMulti(session.Name(), string(raw),
		&session.Values, s.Codecs...)
}
// erase deletes the session's file ("session_<ID>" under the store path).
//
// Removing the file is a mutation, so it takes the exclusive lock: a read
// lock would let the removal run concurrently with load (and with other
// erase calls), which only hold the read lock themselves.
//
// The os.Remove error is returned unchanged; callers that treat a missing
// file as success check it with os.IsNotExist.
func (s *FilesystemStore) erase(session *Session) error {
	filename := filepath.Join(s.path, "session_"+session.ID)

	fileMutex.Lock()
	defer fileMutex.Unlock()

	return os.Remove(filename)
}
================================================
FILE: gossh/gin/sessions/sessions.go
================================================
package sessions
import (
"log"
"net/http"
"gossh/gin"
"gossh/gin/sessions/context"
"gossh/gin/sessions/sessions"
)
const (
// DefaultKey is the gin.Context key under which the session middleware
// stores the session (Sessions) or the map of sessions (SessionsMany).
DefaultKey = "gin-sessions"
// errorFormat is the log format used when loading a session fails.
errorFormat = "[sessions] ERROR! %s\n"
)
// Store is a gorilla-style session store that additionally accepts
// configuration through Options.
type Store interface {
sessions.Store
// Options applies the given configuration to the store.
Options(Options)
}
// Session is a thin wrapper around the gorilla-session methods.
// It stores the values and optional configuration for a session.
type Session interface {
// ID of the session, generated by stores. It should not be used for user data.
ID() string
// Get returns the session value associated to the given key.
Get(key any) any
// Set sets the session value associated to the given key.
Set(key any, val any)
// Delete removes the session value associated to the given key.
Delete(key any)
// Clear deletes all values in the session.
Clear()
// AddFlash adds a flash message to the session.
// A single variadic argument is accepted, and it is optional: it defines the flash key.
// If not defined "_flash" is used by default.
AddFlash(value any, vars ...string)
// Flashes returns a slice of flash messages from the session.
// A single variadic argument is accepted, and it is optional: it defines the flash key.
// If not defined "_flash" is used by default.
Flashes(vars ...string) []any
// Options sets configuration for a session.
Options(Options)
// Save saves all sessions used during the current request.
Save() error
}
// Sessions returns a middleware that attaches a lazily-loaded session named
// name, backed by store, to every request under DefaultKey. After the rest
// of the handler chain has run, context.Clear is called for the request.
func Sessions(name string, store Store) gin.HandlerFunc {
	return func(c *gin.Context) {
		current := &session{
			name:    name,
			request: c.Request,
			store:   store,
			writer:  c.Writer,
		}
		c.Set(DefaultKey, current)
		defer context.Clear(c.Request)
		c.Next()
	}
}
// SessionsMany returns a middleware that attaches one lazily-loaded session
// per entry of names, all backed by store. The sessions are stored under
// DefaultKey as a map[string]Session keyed by name. After the rest of the
// handler chain has run, context.Clear is called for the request.
func SessionsMany(names []string, store Store) gin.HandlerFunc {
	return func(c *gin.Context) {
		byName := make(map[string]Session, len(names))
		for _, name := range names {
			byName[name] = &session{
				name:    name,
				request: c.Request,
				store:   store,
				writer:  c.Writer,
			}
		}
		c.Set(DefaultKey, byName)
		defer context.Clear(c.Request)
		c.Next()
	}
}
// session is the Session implementation handed out by the middleware. The
// underlying gorilla session is loaded from the store on first use.
type session struct {
name string // session (cookie) name
request *http.Request // request the session belongs to
store Store // store used to load the session
session *sessions.Session // underlying session; nil until first access
written bool // true when there are changes that Save should persist
writer http.ResponseWriter // response the session is saved to
}
// ID returns the ID field of the underlying session.
func (s *session) ID() string {
return s.Session().ID
}
// Get returns the session value stored under key, or nil if there is none.
func (s *session) Get(key any) any {
return s.Session().Values[key]
}
// Set stores val under key and marks the session as needing a save.
func (s *session) Set(key any, val any) {
s.Session().Values[key] = val
s.written = true
}
// Delete removes the value stored under key and marks the session as
// needing a save, whether or not the key was present.
func (s *session) Delete(key any) {
delete(s.Session().Values, key)
s.written = true
}
// Clear removes every value from the session by deleting each key in turn.
// A session that is already empty is left unmarked (nothing to save).
func (s *session) Clear() {
	values := s.Session().Values
	for key := range values {
		s.Delete(key)
	}
}
// AddFlash adds a flash message to the underlying session (vars optionally
// names the flash key) and marks the session as needing a save.
func (s *session) AddFlash(value any, vars ...string) {
s.Session().AddFlash(value, vars...)
s.written = true
}
// Flashes returns the flash messages of the underlying session (vars
// optionally names the flash key). The session is always marked as needing
// a save, even when no messages are returned.
func (s *session) Flashes(vars ...string) []any {
s.written = true
return s.Session().Flashes(vars...)
}
// Options replaces the options of the underlying session with the gorilla
// form of options and marks the session as needing a save.
func (s *session) Options(options Options) {
s.written = true
s.Session().Options = options.ToGorillaOptions()
}
// Save persists the underlying session to the response, but only when the
// session has been marked as written. A successful save clears that mark;
// a failed one keeps it and returns the error.
func (s *session) Save() error {
	if !s.Written() {
		return nil
	}
	err := s.Session().Save(s.request, s.writer)
	if err == nil {
		s.written = false
	}
	return err
}
// Session returns the underlying gorilla session, loading it from the store
// on first use. A load error is only logged; whatever session the store
// returned is kept and handed back.
func (s *session) Session() *sessions.Session {
	if s.session != nil {
		return s.session
	}

	loaded, err := s.store.Get(s.request, s.name)
	s.session = loaded
	if err != nil {
		log.Printf(errorFormat, err)
	}
	return s.session
}
// Written reports whether the session has been marked as changed since the
// last successful Save.
func (s *session) Written() bool {
return s.written
}
// Default is a shortcut that returns the session stored under DefaultKey.
// It relies on c.MustGet and an unchecked type assertion to Session.
func Default(c *gin.Context) Session {
return c.MustGet(DefaultKey).(Session)
}
// DefaultMany is a shortcut that returns the session with the given name
// from the map stored under DefaultKey. It relies on c.MustGet and an
// unchecked type assertion to map[string]Session; an unknown name yields a
// nil Session.
func DefaultMany(c *gin.Context, name string) Session {
return c.MustGet(DefaultKey).(map[string]Session)[name]
}
================================================
FILE: gossh/gin/sse/sse-decoder.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package sse
import (
	"bytes"
	"io"
	"strconv"
)
// decoder accumulates the events parsed from an event stream.
type decoder struct {
events []Event // events dispatched so far, in stream order
}
// Decode reads r to the end and returns the events parsed from it, using a
// fresh decoder.
func Decode(r io.Reader) ([]Event, error) {
var dec decoder
return dec.decode(r)
}
// dispatchEvent finalizes event with the accumulated data and appends it to
// the decoder's events.
//
// The last character of a non-empty data buffer is dropped (the decoder
// terminates every data line with a line feed). An event that ends up with
// neither data nor an event name is discarded; an unnamed event is given
// the name "message".
func (d *decoder) dispatchEvent(event Event, data string) {
	if n := len(data); n > 0 {
		// Remove the trailing U+000A LINE FEED (LF) character.
		data = data[:n-1]
	}
	if data == "" && event.Event == "" {
		return
	}
	if event.Event == "" {
		event.Event = "message"
	}
	event.Data = data
	d.events = append(d.events, event)
}
// decode reads r to the end, parses it as a server-sent event stream and
// returns every event dispatched, in stream order. A read error is returned
// with a nil slice.
//
// Lines are separated by LF or CRLF. A blank line dispatches the event built
// so far; lines starting with ':' are comments. The fields "event", "id",
// "retry" and "data" are recognized and any other field is ignored. A
// "retry" value is stored in Event.Retry only when it consists solely of
// decimal digits that fit in a uint; otherwise the field is ignored.
//
// TODO: a lone U+000D CARRIAGE RETURN (CR) is also a valid line terminator
// in the event stream format but is not recognized here.
func (d *decoder) decode(r io.Reader) ([]Event, error) {
	buf, err := io.ReadAll(r)
	if err != nil {
		return nil, err
	}

	var currentEvent Event
	var dataBuffer bytes.Buffer

	lines := bytes.Split(buf, []byte{'\n'})
	for _, line := range lines {
		// Treat CRLF as a line terminator: without this a blank line in a
		// CRLF stream arrives as "\r" and would never dispatch an event.
		if n := len(line); n > 0 && line[n-1] == '\r' {
			line = line[:n-1]
		}
		if len(line) == 0 {
			// If the line is empty (a blank line). Dispatch the event.
			d.dispatchEvent(currentEvent, dataBuffer.String())

			// reset current event and data buffer
			currentEvent = Event{}
			dataBuffer.Reset()
			continue
		}
		if line[0] == byte(':') {
			// If the line starts with a U+003A COLON character (:), ignore the line.
			continue
		}

		var field, value []byte
		colonIndex := bytes.IndexRune(line, ':')
		if colonIndex != -1 {
			// Field is everything before the first colon, value everything after it.
			field = line[:colonIndex]
			value = line[colonIndex+1:]
			// If value starts with a single U+0020 SPACE character, remove it from value.
			if len(value) > 0 && value[0] == ' ' {
				value = value[1:]
			}
		} else {
			// No colon: use the whole line as the field name and the empty
			// string as the field value.
			field = line
			value = []byte{}
		}

		// Field names must be compared literally, with no case folding performed.
		switch string(field) {
		case "event":
			// Set the event name buffer to field value.
			currentEvent.Event = string(value)
		case "id":
			// Set the event stream's last event ID to the field value.
			currentEvent.Id = string(value)
		case "retry":
			// Only an all-digit value is a valid reconnection time.
			// ParseUint in base 10 rejects signs, spaces, the empty string
			// and values that overflow, so any failure means "ignore".
			if ms, perr := strconv.ParseUint(string(value), 10, strconv.IntSize); perr == nil {
				currentEvent.Retry = uint(ms)
			}
		case "data":
			// Append the field value to the data buffer,
			dataBuffer.Write(value)
			// then append a single U+000A LINE FEED (LF) character to the data buffer.
			dataBuffer.WriteString("\n")
		default:
			// Otherwise. The field is ignored.
			continue
		}
	}
	// Once the end of the file is reached, the user agent must dispatch the event one final time.
	d.dispatchEvent(currentEvent, dataBuffer.String())

	return d.events, nil
}
================================================
FILE: gossh/gin/sse/sse-encoder.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package sse
import (
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"strings"
)
// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
// ContentType is the Content-Type header value of an event stream.
const ContentType = "text/event-stream"
// contentType and noCache are the header values assigned by WriteContentType.
var contentType = []string{ContentType}
var noCache = []string{"no-cache"}
// fieldReplacer escapes line breaks in the id and event fields so that each
// stays on a single line.
var fieldReplacer = strings.NewReplacer(
"\n", "\\n",
"\r", "\\r")
// dataReplacer starts a new "data:" line after every line feed in the data
// and escapes carriage returns.
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")
// Event is a single server-sent event.
type Event struct {
Event string // event name; omitted from the encoding when empty
Id string // event ID; omitted from the encoding when empty
Retry uint // retry value; omitted from the encoding when zero
Data any // payload; structs, slices and maps are encoded as JSON
}
// Encode writes event to writer in event stream format: the id, event and
// retry fields (each only when set), followed by the data field.
// Only the error from writing the data field is returned; the results of
// the other field writes are not checked.
func Encode(writer io.Writer, event Event) error {
w := checkWriter(writer)
writeId(w, event.Id)
writeEvent(w, event.Event)
writeRetry(w, event.Retry)
return writeData(w, event.Data)
}
// writeId writes an "id:" line for a non-empty id, with line breaks in the
// id escaped by fieldReplacer. An empty id writes nothing.
func writeId(w stringWriter, id string) {
	if id == "" {
		return
	}
	w.WriteString("id:")
	fieldReplacer.WriteString(w, id)
	w.WriteString("\n")
}
// writeEvent writes an "event:" line for a non-empty event name, with line
// breaks in the name escaped by fieldReplacer. An empty name writes nothing.
func writeEvent(w stringWriter, event string) {
	if event == "" {
		return
	}
	w.WriteString("event:")
	fieldReplacer.WriteString(w, event)
	w.WriteString("\n")
}
// writeRetry writes a "retry:" line holding retry in decimal. A zero retry
// writes nothing.
func writeRetry(w stringWriter, retry uint) {
	if retry == 0 {
		return
	}
	w.WriteString("retry:")
	w.WriteString(strconv.FormatUint(uint64(retry), 10))
	w.WriteString("\n")
}
// writeData writes the "data:" field and the blank line that ends the event.
//
// Structs, slices and maps (or pointers to them) are JSON-encoded; the JSON
// encoder's own trailing newline plus one more form the terminator. Any
// other value is formatted with fmt.Sprint, passed through dataReplacer and
// followed by two newlines. Only a JSON encoding error is returned.
func writeData(w stringWriter, data any) error {
	w.WriteString("data:")

	kind := kindOfData(data)
	if kind == reflect.Struct || kind == reflect.Slice || kind == reflect.Map {
		if err := json.NewEncoder(w).Encode(data); err != nil {
			return err
		}
		w.WriteString("\n")
		return nil
	}

	dataReplacer.WriteString(w, fmt.Sprint(data))
	w.WriteString("\n\n")
	return nil
}
// Render sets the event stream response headers on w and then writes the
// event to it, returning the error from Encode.
func (r Event) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return Encode(w, r)
}
// WriteContentType sets the Content-Type header of w to the event stream
// type, overwriting any previous value, and sets Cache-Control to no-cache
// unless a Cache-Control header is already present.
func (r Event) WriteContentType(w http.ResponseWriter) {
	h := w.Header()
	h["Content-Type"] = contentType
	if _, exist := h["Cache-Control"]; exist {
		return
	}
	h["Cache-Control"] = noCache
}
// kindOfData returns the reflect.Kind of data, looking through one level of
// pointer: for a pointer it reports the kind of the pointed-to value
// (reflect.Invalid for a nil pointer). A nil interface is reflect.Invalid.
func kindOfData(data any) reflect.Kind {
	v := reflect.ValueOf(data)
	if v.Kind() != reflect.Ptr {
		return v.Kind()
	}
	return v.Elem().Kind()
}
================================================
FILE: gossh/gin/sse/writer.go
================================================
package sse
import "io"
// stringWriter is an io.Writer that can also write strings directly.
type stringWriter interface {
io.Writer
WriteString(string) (int, error)
}
// stringWrapper adapts a plain io.Writer to the stringWriter interface.
type stringWrapper struct {
io.Writer
}
// WriteString writes str to the wrapped writer by converting it to a byte
// slice and calling Write.
func (w stringWrapper) WriteString(str string) (int, error) {
return w.Writer.Write([]byte(str))
}
// checkWriter returns writer itself when it already implements
// stringWriter, and otherwise wraps it in a stringWrapper.
func checkWriter(writer io.Writer) stringWriter {
	native, ok := writer.(stringWriter)
	if !ok {
		return stringWrapper{writer}
	}
	return native
}
================================================
FILE: gossh/gin/tree.go
================================================
// Copyright 2013 Julien Schmidt. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be found
// at https://github.com/julienschmidt/httprouter/blob/master/LICENSE
package gin
import (
"bytes"
"net/url"
"strings"
"unicode"
"unicode/utf8"
"gossh/gin/internal/bytesconv"
)
// Single-byte separators as byte slices, used with bytes.Count by
// countParams and countSections.
var (
strColon = []byte(":")
strStar = []byte("*")
strSlash = []byte("/")
)
// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
	Key   string
	Value string
}

// Params is a Param-slice, as returned by the router.
// The slice is ordered, the first URL parameter is also the first slice value.
// It is therefore safe to read values by the index.
type Params []Param

// Get returns the value of the first Param whose key matches the given name
// and a boolean true.
// If no matching Param is found, an empty string and false are returned.
func (ps Params) Get(name string) (string, bool) {
	for i := range ps {
		if ps[i].Key != name {
			continue
		}
		return ps[i].Value, true
	}
	return "", false
}

// ByName returns the value of the first Param whose key matches the given
// name. If no matching Param is found, an empty string is returned.
func (ps Params) ByName(name string) (va string) {
	va, _ = ps.Get(name)
	return va
}
// methodTree pairs an HTTP method name with the root node of its route tree.
type methodTree struct {
method string
root *node
}
// methodTrees is the list of per-method route trees.
type methodTrees []methodTree
// get returns the root node of the first tree registered for method, or nil
// when no tree has that method. The comparison is an exact string match.
func (trees methodTrees) get(method string) *node {
	for i := range trees {
		if trees[i].method == method {
			return trees[i].root
		}
	}
	return nil
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// longestCommonPrefix returns the length in bytes of the longest prefix
// shared by a and b.
func longestCommonPrefix(a, b string) int {
	limit := min(len(a), len(b))
	for i := 0; i < limit; i++ {
		if a[i] != b[i] {
			return i
		}
	}
	return limit
}
// addChild appends child to n's children. When n has a wildcard child, the
// new child is inserted just before the last element so that the wildcard
// child stays at the end of the slice.
func (n *node) addChild(child *node) {
	last := len(n.children) - 1
	if !n.wildChild || last < 0 {
		n.children = append(n.children, child)
		return
	}
	wildcardChild := n.children[last]
	n.children = append(n.children[:last], child, wildcardChild)
}
// countParams returns the number of ':' and '*' bytes in path, i.e. an
// upper bound on the number of wildcard parameters the path declares.
//
// The result saturates at the largest uint16 value: converting a larger
// count directly would silently wrap around to a small number.
func countParams(path string) uint16 {
	const maxCount = 1<<16 - 1

	s := bytesconv.StringToBytes(path)
	total := bytes.Count(s, strColon) + bytes.Count(s, strStar)
	if total > maxCount {
		return maxCount
	}
	return uint16(total)
}
// countSections returns the number of '/' bytes in path.
//
// The result saturates at the largest uint16 value: converting a larger
// count directly would silently wrap around to a small number.
func countSections(path string) uint16 {
	const maxCount = 1<<16 - 1

	s := bytesconv.StringToBytes(path)
	total := bytes.Count(s, strSlash)
	if total > maxCount {
		return maxCount
	}
	return uint16(total)
}
// nodeType classifies a node of the route tree.
type nodeType uint8
const (
static nodeType = iota // plain path segment (the zero value)
root // set on the tree's root node when the first route is added
param // ':name' parameter segment
catchAll // '*name' catch-all segment
)
// node is one node of the radix tree that stores the routes of one method.
type node struct {
path string // path fragment covered by this node
indices string // one byte per non-wildcard child: the first byte of that child's path, in children order
wildChild bool // whether the last child is a wildcard (param or catch-all) node
nType nodeType // kind of node (static, root, param, catchAll)
priority uint32 // incremented as routes are added through this node; children are kept sorted by it
children []*node // child nodes, at most 1 :param style node at the end of the array
handlers HandlersChain // handlers registered for the route ending at this node, if any
fullPath string // route path associated with this node
}
// incrementChildPrio increments the priority of the child at index pos and
// moves it towards the front of n.children until no child before it has a
// lower priority. n.indices is rearranged the same way so that it keeps
// matching the order of the children. It returns the child's new index.
func (n *node) incrementChildPrio(pos int) int {
cs := n.children
cs[pos].priority++
prio := cs[pos].priority
// Adjust position (move to front)
newPos := pos
for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- {
// Swap node positions
cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1]
}
// Build new index char string
if newPos != pos {
n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty
n.indices[pos:pos+1] + // The index char we move
n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos'
}
return newPos
}
// addRoute adds a node with the given handlers to the path.
// Not concurrency-safe!
//
// Starting at n, the tree is walked along the longest common prefix of the
// remaining path and each node's path. A node whose path is only partially
// matched is split; a new child is inserted where no existing child
// continues the path. It panics when the new path conflicts with an existing
// wildcard or when handlers are already registered for the path.
func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path
n.priority++
// Empty tree
if len(n.path) == 0 && len(n.children) == 0 {
n.insertChild(path, fullPath, handlers)
n.nType = root
return
}
// Length of the part of fullPath consumed by the nodes above n.
parentFullPathIndex := 0
walk:
for {
// Find the longest common prefix.
// This also implies that the common prefix contains no ':' or '*'
// since the existing key can't contain those chars.
i := longestCommonPrefix(path, n.path)
// Split edge: n keeps the common prefix, and a new child takes over
// the rest of n's path together with n's children and handlers.
if i < len(n.path) {
child := node{
path: n.path[i:],
wildChild: n.wildChild,
nType: static,
indices: n.indices,
children: n.children,
handlers: n.handlers,
priority: n.priority - 1,
fullPath: n.fullPath,
}
n.children = []*node{&child}
// []byte for proper unicode char conversion, see #65
n.indices = bytesconv.BytesToString([]byte{n.path[i]})
n.path = path[:i]
n.handlers = nil
n.wildChild = false
n.fullPath = fullPath[:parentFullPathIndex+i]
}
// Make new node a child of this node
if i < len(path) {
path = path[i:]
c := path[0]
// '/' after param
if n.nType == param && c == '/' && len(n.children) == 1 {
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
continue walk
}
// Check if a child with the next path byte exists
for i, max := 0, len(n.indices); i < max; i++ {
if c == n.indices[i] {
parentFullPathIndex += len(n.path)
i = n.incrementChildPrio(i)
n = n.children[i]
continue walk
}
}
// Otherwise insert it
if c != ':' && c != '*' && n.nType != catchAll {
// []byte for proper unicode char conversion, see #65
n.indices += bytesconv.BytesToString([]byte{c})
child := &node{
fullPath: fullPath,
}
n.addChild(child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
} else if n.wildChild {
// inserting a wildcard node, need to check if it conflicts with the existing wildcard
n = n.children[len(n.children)-1]
n.priority++
// Check if the wildcard matches
if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
// Adding a child to a catchAll is not possible
n.nType != catchAll &&
// Check for longer wildcard, e.g. :name and :names
(len(n.path) >= len(path) || path[len(n.path)] == '/') {
continue walk
}
// Wildcard conflict
pathSeg := path
if n.nType != catchAll {
pathSeg = strings.SplitN(pathSeg, "/", 2)[0]
}
prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
panic("'" + pathSeg +
"' in new path '" + fullPath +
"' conflicts with existing wildcard '" + n.path +
"' in existing prefix '" + prefix +
"'")
}
n.insertChild(path, fullPath, handlers)
return
}
// Otherwise add handle to current node
if n.handlers != nil {
panic("handlers are already registered for path '" + fullPath + "'")
}
n.handlers = handlers
n.fullPath = fullPath
return
}
}
// findWildcard searches path for the first wildcard segment, which starts
// at a ':' (param) or '*' (catch-all) and runs up to the next '/' or the
// end of the path.
//
// It returns the segment including its leading character, the index at
// which it starts, and whether it is valid, i.e. contains no further ':'
// or '*'. If no wildcard is found it returns "", -1, false.
func findWildcard(path string) (wildcard string, i int, valid bool) {
	start := strings.IndexAny(path, ":*")
	if start < 0 {
		return "", -1, false
	}

	// Find end and check for invalid characters
	valid = true
	rest := path[start+1:]
	for end := 0; end < len(rest); end++ {
		switch rest[end] {
		case '/':
			return path[start : start+1+end], start, valid
		case ':', '*':
			valid = false
		}
	}
	return path[start:], start, valid
}
// insertChild inserts the remaining path below n, creating one node per
// wildcard segment it contains, and stores handlers on the final node.
// fullPath is the complete route and is used for node.fullPath and in panic
// messages.
//
// It panics when a segment holds more than one wildcard, when a wildcard has
// no name, when a catch-all is not at the end of the path or is not preceded
// by '/', or when a catch-all conflicts with an existing path segment.
func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) {
for {
// Find prefix until first wildcard
wildcard, i, valid := findWildcard(path)
if i < 0 { // No wildcard found
break
}
// The wildcard name must only contain one ':' or '*' character
if !valid {
panic("only one wildcard per path segment is allowed, has: '" +
wildcard + "' in path '" + fullPath + "'")
}
// check if the wildcard has a name
if len(wildcard) < 2 {
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
if wildcard[0] == ':' { // param
if i > 0 {
// Insert prefix before the current wildcard
n.path = path[:i]
path = path[i:]
}
child := &node{
nType: param,
path: wildcard,
fullPath: fullPath,
}
n.addChild(child)
n.wildChild = true
n = child
n.priority++
// if the path doesn't end with the wildcard, then there
// will be another subpath starting with '/'
if len(wildcard) < len(path) {
path = path[len(wildcard):]
child := &node{
priority: 1,
fullPath: fullPath,
}
n.addChild(child)
n = child
continue
}
// Otherwise we're done. Insert the handle in the new leaf
n.handlers = handlers
return
}
// catchAll
if i+len(wildcard) != len(path) {
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
pathSeg := ""
if len(n.children) != 0 {
pathSeg = strings.SplitN(n.children[0].path, "/", 2)[0]
}
panic("catch-all wildcard '" + path +
"' in new path '" + fullPath +
"' conflicts with existing path segment '" + pathSeg +
"' in existing prefix '" + n.path + pathSeg +
"'")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
panic("no / before catch-all in path '" + fullPath + "'")
}
n.path = path[:i]
// First node: catchAll node with empty path
child := &node{
wildChild: true,
nType: catchAll,
fullPath: fullPath,
}
n.addChild(child)
n.indices = string('/')
n = child
n.priority++
// second node: node holding the variable
child = &node{
path: path[i:],
nType: catchAll,
handlers: handlers,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
return
}
// If no wildcard was found, simply insert the path and handle
n.path = path
n.handlers = handlers
n.fullPath = fullPath
}
// nodeValue holds the return values of the (*node).getValue method.
type nodeValue struct {
handlers HandlersChain // handlers of the matched route, nil when nothing matched
params *Params // parameters captured from wildcard segments
tsr bool // trailing slash redirect recommendation
fullPath string // full path of the matched route
}
// skippedNode records a node with a wildcard child that getValue passed
// over in favor of a static child, so the lookup can roll back to it when
// the static branch does not match.
type skippedNode struct {
path string // path that remained to be matched at that node
node *node // copy of the node to resume from
paramsCount int16 // number of params captured when the node was skipped
}
// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
var globalParamsCount int16
walk: // Outer loop for walking the tree
for {
prefix := n.path
if len(path) > len(prefix) {
if path[:len(prefix)] == prefix {
path = path[len(prefix):]
// Try all the non-wildcard children first by matching the indices
idxc := path[0]
for i, c := range []byte(n.indices) {
if c == idxc {
// strings.HasPrefix(n.children[len(n.children)-1].path, ":") == n.wildChild
if n.wildChild {
index := len(*skippedNodes)
*skippedNodes = (*skippedNodes)[:index+1]
(*skippedNodes)[index] = skippedNode{
path: prefix + path,
node: &node{
path: n.path,
wildChild: n.wildChild,
nType: n.nType,
priority: n.priority,
children: n.children,
handlers: n.handlers,
fullPath: n.fullPath,
},
paramsCount: globalParamsCount,
}
}
n = n.children[i]
continue walk
}
}
if !n.wildChild {
// If the path at the end of the loop is not equal to '/' and the current node has no child nodes
// the current node needs to roll back to last valid skippedNode
if path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
}
// Nothing found.
// We can recommend to redirect to the same URL without a
// trailing slash if a leaf exists for that path.
value.tsr = path == "/" && n.handlers != nil
return value
}
// Handle wildcard child, which is always at the end of the array
n = n.children[len(n.children)-1]
globalParamsCount++
switch n.nType {
case param:
// fix truncate the parameter
// tree_test.go line: 204
// Find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
if value.params == nil {
value.params = params
}
// Expand slice within preallocated capacity
i := len(*value.params)
*value.params = (*value.params)[:i+1]
val := path[:end]
if unescape {
if v, err := url.QueryUnescape(val); err == nil {
val = v
}
}
(*value.params)[i] = Param{
Key: n.path[1:],
Value: val,
}
}
// we need to go deeper!
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
value.tsr = len(path) == end+1
return value
}
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return value
}
if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists for TSR recommendation
n = n.children[0]
value.tsr = (n.path == "/" && n.handlers != nil) || (n.path == "" && n.indices == "/")
}
return value
case catchAll:
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
if value.params == nil {
value.params = params
}
// Expand slice within preallocated capacity
i := len(*value.params)
*value.params = (*value.params)[:i+1]
val := path
if unescape {
if v, err := url.QueryUnescape(path); err == nil {
val = v
}
}
(*value.params)[i] = Param{
Key: n.path[2:],
Value: val,
}
}
value.handlers = n.handlers
value.fullPath = n.fullPath
return value
default:
panic("invalid node type")
}
}
}
if path == prefix {
// If the current path does not equal '/' and the node does not have a registered handle and the most recently matched node has a child node
// the current node needs to roll back to last valid skippedNode
if n.handlers == nil && path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
// n = latestNode.children[len(latestNode.children)-1]
}
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return value
}
// If there is no handle for this route, but this route has a
// wildcard child, there must be a handle for this path with an
// additional trailing slash
if path == "/" && n.wildChild && n.nType != root {
value.tsr = true
return value
}
if path == "/" && n.nType == static {
value.tsr = true
return value
}
// No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation
for i, c := range []byte(n.indices) {
if c == '/' {
n = n.children[i]
value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
(n.nType == catchAll && n.children[0].handlers != nil)
return value
}
}
return value
}
// Nothing found. We can recommend to redirect to the same URL with an
// extra trailing slash if a leaf exists for that path
value.tsr = path == "/" ||
(len(prefix) == len(path)+1 && prefix[len(path)] == '/' &&
path == prefix[:len(prefix)-1] && n.handlers != nil)
// roll back to last valid skippedNode
if !value.tsr && path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
}
return value
}
}
// findCaseInsensitivePath looks up path in the tree rooted at n while ignoring
// letter case and, when fixTrailingSlash is set, also tolerating a trailing
// slash difference. It returns the path as the recursive lookup spells it
// together with true, or nil and false when the lookup yields nothing.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
	const stackBufSize = 128

	// A constant-capacity buffer serves the common case; longer paths get a
	// buffer with room for the path plus one extra byte.
	var scratch []byte
	if need := len(path) + 1; need > stackBufSize {
		scratch = make([]byte, 0, need)
	} else {
		scratch = make([]byte, 0, stackBufSize)
	}

	var pendingRune [4]byte // empty rune buffer: nothing carried over yet
	fixed := n.findCaseInsensitivePathRec(path, scratch, pendingRune, fixTrailingSlash)
	return fixed, fixed != nil
}
// shiftNRuneBytes returns rb shifted left by n bytes, filling the vacated
// tail with zeros. Any n outside 0..3 yields an all-zero array.
func shiftNRuneBytes(rb [4]byte, n int) [4]byte {
	var shifted [4]byte
	if n < 0 || n > 3 {
		return shifted
	}
	copy(shifted[:], rb[n:])
	return shifted
}
// findCaseInsensitivePathRec is the recursive worker used by
// n.findCaseInsensitivePath.
//
// path is the part of the request path that still has to be matched, ciPath
// accumulates the path as it is spelled in the tree, and rb holds the bytes of
// the current lower/upper-cased rune that have not been matched against node
// indices yet. When fixTrailingSlash is set, a path that differs from a
// registered one only by a trailing slash is accepted as well. The result is
// nil when nothing matches.
func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte {
	npLen := len(n.path)

walk: // Outer loop for walking the tree
	for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) {
		// Add common prefix to result
		oldPath := path
		path = path[npLen:]
		ciPath = append(ciPath, n.path...)

		if len(path) == 0 {
			// We should have reached the node containing the handle.
			// Check if this node has a handle registered.
			if n.handlers != nil {
				return ciPath
			}

			// No handle found.
			// Try to fix the path by adding a trailing slash
			if fixTrailingSlash {
				for i, c := range []byte(n.indices) {
					if c == '/' {
						n = n.children[i]
						if (len(n.path) == 1 && n.handlers != nil) ||
							(n.nType == catchAll && n.children[0].handlers != nil) {
							return append(ciPath, '/')
						}
						return nil
					}
				}
			}
			return nil
		}

		// If this node does not have a wildcard (param or catchAll) child,
		// we can just look up the next child node and continue to walk down
		// the tree
		if !n.wildChild {
			// Skip rune bytes already processed
			rb = shiftNRuneBytes(rb, npLen)

			if rb[0] != 0 {
				// Old rune not finished
				idxc := rb[0]
				for i, c := range []byte(n.indices) {
					if c == idxc {
						// continue with child node
						n = n.children[i]
						npLen = len(n.path)
						continue walk
					}
				}
			} else {
				// Process a new rune
				var rv rune

				// Find rune start.
				// Runes are up to 4 byte long,
				// -4 would definitely be another rune.
				var off int
				for max := min(npLen, 3); off < max; off++ {
					if i := npLen - off; utf8.RuneStart(oldPath[i]) {
						// read rune from cached path
						rv, _ = utf8.DecodeRuneInString(oldPath[i:])
						break
					}
				}

				// Calculate lowercase bytes of current rune
				lo := unicode.ToLower(rv)
				utf8.EncodeRune(rb[:], lo)

				// Skip already processed bytes
				rb = shiftNRuneBytes(rb, off)

				idxc := rb[0]
				for i, c := range []byte(n.indices) {
					// Lowercase matches
					if c == idxc {
						// must use a recursive approach since both the
						// uppercase byte and the lowercase byte might exist
						// as an index
						if out := n.children[i].findCaseInsensitivePathRec(
							path, ciPath, rb, fixTrailingSlash,
						); out != nil {
							return out
						}
						break
					}
				}

				// If we found no match, the same for the uppercase rune,
				// if it differs
				if up := unicode.ToUpper(rv); up != lo {
					utf8.EncodeRune(rb[:], up)
					rb = shiftNRuneBytes(rb, off)

					idxc := rb[0]
					for i, c := range []byte(n.indices) {
						// Uppercase matches
						if c == idxc {
							// Continue with child node
							n = n.children[i]
							npLen = len(n.path)
							continue walk
						}
					}
				}
			}

			// Nothing found. We can recommend to redirect to the same URL
			// without a trailing slash if a leaf exists for that path
			if fixTrailingSlash && path == "/" && n.handlers != nil {
				return ciPath
			}
			return nil
		}

		// The wildcard child is kept at the end of the children slice, after
		// any static siblings. Taking index 0 would select a static node in
		// that situation and end in the "invalid node type" panic below.
		n = n.children[len(n.children)-1]
		switch n.nType {
		case param:
			// Find param end (either '/' or path end)
			end := 0
			for end < len(path) && path[end] != '/' {
				end++
			}

			// Add param value to case insensitive path
			ciPath = append(ciPath, path[:end]...)

			// We need to go deeper!
			if end < len(path) {
				if len(n.children) > 0 {
					// Continue with child node
					n = n.children[0]
					npLen = len(n.path)
					path = path[end:]
					continue
				}

				// ... but we can't
				if fixTrailingSlash && len(path) == end+1 {
					return ciPath
				}
				return nil
			}

			if n.handlers != nil {
				return ciPath
			}

			if fixTrailingSlash && len(n.children) == 1 {
				// No handle found. Check if a handle for this path + a
				// trailing slash exists
				n = n.children[0]
				if n.path == "/" && n.handlers != nil {
					return append(ciPath, '/')
				}
			}

			return nil

		case catchAll:
			return append(ciPath, path...)

		default:
			panic("invalid node type")
		}
	}

	// Nothing found.
	// Try to fix the path by adding / removing a trailing slash
	if fixTrailingSlash {
		if path == "/" {
			return ciPath
		}
		if len(path)+1 == npLen && n.path[len(path)] == '/' &&
			strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil {
			return append(ciPath, n.path...)
		}
	}
	return nil
}
================================================
FILE: gossh/gin/utils.go
================================================
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
import (
"encoding/xml"
"net/http"
"os"
"path"
"reflect"
"runtime"
"strings"
"unicode"
)
// BindKey is the default bind key: the context key under which the
// middleware returned by Bind stores the bound object.
const BindKey = "_gin-gonic/gin/bindkey"
// Bind returns a Gin middleware that, for every request, allocates a new
// value of val's type, binds the request into it via Context.Bind and, when
// binding reports no error, stores the resulting pointer in the context under
// BindKey. val must not be a pointer; passing one panics.
func Bind(val any) HandlerFunc {
	rv := reflect.ValueOf(val)
	if rv.Kind() == reflect.Ptr {
		panic(`Bind struct can not be a pointer. Example:
Use: gin.Bind(Struct{}) instead of gin.Bind(&Struct{})
`)
	}
	target := rv.Type()

	bindRequest := func(c *Context) {
		holder := reflect.New(target).Interface()
		if err := c.Bind(holder); err != nil {
			return
		}
		c.Set(BindKey, holder)
	}
	return bindRequest
}
// WrapF adapts an http.HandlerFunc into a Gin middleware that invokes it with
// the context's Writer and Request.
func WrapF(f http.HandlerFunc) HandlerFunc {
	adapter := func(ctx *Context) {
		f(ctx.Writer, ctx.Request)
	}
	return adapter
}
// WrapH adapts an http.Handler into a Gin middleware that calls its ServeHTTP
// method with the context's Writer and Request.
func WrapH(h http.Handler) HandlerFunc {
	adapter := func(ctx *Context) {
		h.ServeHTTP(ctx.Writer, ctx.Request)
	}
	return adapter
}
// H is a shortcut for map[string]any
type H map[string]any
// MarshalXML allows type H to be used with xml.Marshal.
func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
start.Name = xml.Name{
Space: "",
Local: "map",
}
if err := e.EncodeToken(start); err != nil {
return err
}
for key, value := range h {
elem := xml.StartElement{
Name: xml.Name{Space: "", Local: key},
Attr: []xml.Attr{},
}
if err := e.EncodeElement(value, elem); err != nil {
return err
}
}
return e.EncodeToken(xml.EndElement{Name: start.Name})
}
// assert1 panics with text unless guard is true.
func assert1(guard bool, text string) {
	if guard {
		return
	}
	panic(text)
}
// filterFlags returns the part of content that precedes the first space or
// semicolon, or content unchanged when it contains neither.
func filterFlags(content string) string {
	if cut := strings.IndexAny(content, " ;"); cut >= 0 {
		return content[:cut]
	}
	return content
}
// chooseData returns custom when it is non-nil, otherwise wildcard when that
// is non-nil. It panics when both are nil.
func chooseData(custom, wildcard any) any {
	switch {
	case custom != nil:
		return custom
	case wildcard != nil:
		return wildcard
	default:
		panic("negotiation config is invalid")
	}
}
// parseAccept splits acceptHeader on commas and returns the non-empty
// entries with surrounding whitespace removed. Everything from the first
// semicolon on is dropped from an entry, unless the semicolon is the entry's
// very first character.
func parseAccept(acceptHeader string) []string {
	segments := strings.Split(acceptHeader, ",")
	mediaTypes := make([]string, 0, len(segments))
	for _, seg := range segments {
		if semi := strings.IndexByte(seg, ';'); semi > 0 {
			seg = seg[:semi]
		}
		seg = strings.TrimSpace(seg)
		if seg == "" {
			continue
		}
		mediaTypes = append(mediaTypes, seg)
	}
	return mediaTypes
}
// lastChar returns the final byte of str and panics when str is empty.
func lastChar(str string) uint8 {
	size := len(str)
	if size == 0 {
		panic("The length of the string can't be 0")
	}
	return str[size-1]
}
// nameOfFunction returns the name the runtime reports for the function f.
func nameOfFunction(f any) string {
	entry := reflect.ValueOf(f).Pointer()
	fn := runtime.FuncForPC(entry)
	return fn.Name()
}
// joinPaths joins absolutePath and relativePath with path.Join. A trailing
// slash on relativePath is kept in the result. An empty relativePath returns
// absolutePath untouched.
func joinPaths(absolutePath, relativePath string) string {
	if relativePath == "" {
		return absolutePath
	}

	joined := path.Join(absolutePath, relativePath)
	// path.Join cleans the result, which strips a trailing slash; put it back.
	wantsSlash := lastChar(relativePath) == '/'
	if wantsSlash && lastChar(joined) != '/' {
		joined += "/"
	}
	return joined
}
// resolveAddress picks the listen address. With no argument it uses the PORT
// environment variable (as ":<PORT>") and falls back to ":8080" when that is
// unset or empty. With one argument it returns it as is. More than one
// argument panics.
func resolveAddress(addr []string) string {
	if len(addr) > 1 {
		panic("too many parameters")
	}
	if len(addr) == 1 {
		return addr[0]
	}

	port := os.Getenv("PORT")
	if port == "" {
		debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
		return ":8080"
	}
	debugPrint("Environment variable PORT=\"%s\"", port)
	return ":" + port
}
// isASCII reports whether every byte of s is at most unicode.MaxASCII.
// https://stackoverflow.com/questions/53069040/checking-a-string-contains-only-ascii-characters
func isASCII(s string) bool {
	for _, b := range []byte(s) {
		if b > unicode.MaxASCII {
			return false
		}
	}
	return true
}
================================================
FILE: gossh/gin/validator/baked_in.go
================================================
package validator
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"sync"
"syscall"
"time"
"unicode/utf8"
)
// Func is the signature of a validation function. It accepts a FieldLevel
// interface for all validation needs. The return value should be true when
// validation succeeds.
type Func func(fl FieldLevel) bool
// FuncCtx is the context-aware signature of a validation function. It accepts
// a context.Context and a FieldLevel interface for all validation needs. The
// return value should be true when validation succeeds.
type FuncCtx func(ctx context.Context, fl FieldLevel) bool
// wrapFunc adapts a plain Func to the FuncCtx signature; the returned
// function ignores its context argument. A nil fn yields a nil FuncCtx so
// that a missing function is never wrapped.
func wrapFunc(fn Func) FuncCtx {
	if fn != nil {
		return func(_ context.Context, fl FieldLevel) bool {
			return fn(fl)
		}
	}
	return nil
}
var (
// restrictedTags is a set of tag names, stored as a map with empty-struct
// values.
// NOTE(review): presumably these names are reserved and may not be used for
// custom validations or aliases; confirm against the registration code.
restrictedTags = map[string]struct{}{
diveTag: {},
keysTag: {},
endKeysTag: {},
structOnlyTag: {},
omitempty: {},
omitnil: {},
skipValidationTag: {},
utf8HexComma: {},
utf8Pipe: {},
noStructLevelTag: {},
requiredTag: {},
isdefault: {},
}
// bakedInAliases is a default mapping of a single validation tag that
// defines a common or complex set of validation(s) to simplify
// adding validation to structs. The key is the alias and the value is the
// tag expression it stands for.
bakedInAliases = map[string]string{
"iscolor": "hexcolor|rgb|rgba|hsl|hsla",
"country_code": "iso3166_1_alpha2|iso3166_1_alpha3|iso3166_1_alpha_numeric",
}
// bakedInValidators is the default map from validation tag name to the Func
// implementing it.
// You can add, remove or even replace items to suit your needs,
// or even disregard it and use your own map if so desired.
bakedInValidators = map[string]Func{
"required": hasValue,
"required_if": requiredIf,
"required_unless": requiredUnless,
"skip_unless": skipUnless,
"required_with": requiredWith,
"required_with_all": requiredWithAll,
"required_without": requiredWithout,
"required_without_all": requiredWithoutAll,
"excluded_if": excludedIf,
"excluded_unless": excludedUnless,
"excluded_with": excludedWith,
"excluded_with_all": excludedWithAll,
"excluded_without": excludedWithout,
"excluded_without_all": excludedWithoutAll,
"isdefault": isDefault,
"len": hasLengthOf,
"min": hasMinOf,
"max": hasMaxOf,
"eq": isEq,
"eq_ignore_case": isEqIgnoreCase,
"ne": isNe,
"ne_ignore_case": isNeIgnoreCase,
"lt": isLt,
"lte": isLte,
"gt": isGt,
"gte": isGte,
"eqfield": isEqField,
"eqcsfield": isEqCrossStructField,
"necsfield": isNeCrossStructField,
"gtcsfield": isGtCrossStructField,
"gtecsfield": isGteCrossStructField,
"ltcsfield": isLtCrossStructField,
"ltecsfield": isLteCrossStructField,
"nefield": isNeField,
"gtefield": isGteField,
"gtfield": isGtField,
"ltefield": isLteField,
"ltfield": isLtField,
"fieldcontains": fieldContains,
"fieldexcludes": fieldExcludes,
"alpha": isAlpha,
"alphanum": isAlphanum,
"alphaunicode": isAlphaUnicode,
"alphanumunicode": isAlphanumUnicode,
"boolean": isBoolean,
"numeric": isNumeric,
"number": isNumber,
"hexadecimal": isHexadecimal,
"hexcolor": isHEXColor,
"rgb": isRGB,
"rgba": isRGBA,
"hsl": isHSL,
"hsla": isHSLA,
"e164": isE164,
"email": isEmail,
"url": isURL,
"http_url": isHttpURL,
"uri": isURI,
"file": isFile,
"filepath": isFilePath,
"base64": isBase64,
"base64url": isBase64URL,
"base64rawurl": isBase64RawURL,
"contains": contains,
"containsany": containsAny,
"containsrune": containsRune,
"excludes": excludes,
"excludesall": excludesAll,
"excludesrune": excludesRune,
"startswith": startsWith,
"endswith": endsWith,
"startsnotwith": startsNotWith,
"endsnotwith": endsNotWith,
"isbn": isISBN,
"isbn10": isISBN10,
"isbn13": isISBN13,
"issn": isISSN,
"eth_addr": isEthereumAddress,
"btc_addr": isBitcoinAddress,
"btc_addr_bech32": isBitcoinBech32Address,
"uuid": isUUID,
"uuid3": isUUID3,
"uuid4": isUUID4,
"uuid5": isUUID5,
"uuid_rfc4122": isUUIDRFC4122,
"uuid3_rfc4122": isUUID3RFC4122,
"uuid4_rfc4122": isUUID4RFC4122,
"uuid5_rfc4122": isUUID5RFC4122,
"ulid": isULID,
"md4": isMD4,
"md5": isMD5,
"sha256": isSHA256,
"sha384": isSHA384,
"sha512": isSHA512,
"ripemd128": isRIPEMD128,
"ripemd160": isRIPEMD160,
"tiger128": isTIGER128,
"tiger160": isTIGER160,
"tiger192": isTIGER192,
"ascii": isASCII,
"printascii": isPrintableASCII,
"multibyte": hasMultiByteCharacter,
"datauri": isDataURI,
"latitude": isLatitude,
"longitude": isLongitude,
"ssn": isSSN,
"ipv4": isIPv4,
"ipv6": isIPv6,
"ip": isIP,
"cidrv4": isCIDRv4,
"cidrv6": isCIDRv6,
"cidr": isCIDR,
"tcp4_addr": isTCP4AddrResolvable,
"tcp6_addr": isTCP6AddrResolvable,
"tcp_addr": isTCPAddrResolvable,
"udp4_addr": isUDP4AddrResolvable,
"udp6_addr": isUDP6AddrResolvable,
"udp_addr": isUDPAddrResolvable,
"ip4_addr": isIP4AddrResolvable,
"ip6_addr": isIP6AddrResolvable,
"ip_addr": isIPAddrResolvable,
"unix_addr": isUnixAddrResolvable,
"mac": isMAC,
"hostname": isHostnameRFC952, // RFC 952
"hostname_rfc1123": isHostnameRFC1123, // RFC 1123
"fqdn": isFQDN,
"unique": isUnique,
"oneof": isOneOf,
"html": isHTML,
"html_encoded": isHTMLEncoded,
"url_encoded": isURLEncoded,
"dir": isDir,
"dirpath": isDirPath,
"json": isJSON,
"jwt": isJWT,
"hostname_port": isHostnamePort,
"lowercase": isLowercase,
"uppercase": isUppercase,
"datetime": isDatetime,
"timezone": isTimeZone,
"iso3166_1_alpha2": isIso3166Alpha2,
"iso3166_1_alpha3": isIso3166Alpha3,
"iso3166_1_alpha_numeric": isIso3166AlphaNumeric,
"iso3166_2": isIso31662,
"iso4217": isIso4217,
"iso4217_numeric": isIso4217Numeric,
"postcode_iso3166_alpha2": isPostcodeByIso3166Alpha2,
"postcode_iso3166_alpha2_field": isPostcodeByIso3166Alpha2Field,
"bic": isIsoBicFormat,
"semver": isSemverFormat,
"dns_rfc1035_label": isDnsRFC1035LabelFormat,
"credit_card": isCreditCard,
"cve": isCveFormat,
"luhn_checksum": hasLuhnChecksum,
"mongodb": isMongoDB,
"cron": isCron,
"spicedb": isSpiceDB,
}
)
var (
// oneofValsCache maps a raw "oneof" parameter string to the values parsed
// from it by parseOneOfParam2.
oneofValsCache = map[string][]string{}
// oneofValsCacheRWLock guards access to oneofValsCache.
oneofValsCacheRWLock = sync.RWMutex{}
)
// parseOneOfParam2 splits the parameter string s into its individual values
// using splitParamsRegex and strips every single quote from each value.
// Results are memoized in oneofValsCache, keyed by s.
func parseOneOfParam2(s string) []string {
	oneofValsCacheRWLock.RLock()
	cached, found := oneofValsCache[s]
	oneofValsCacheRWLock.RUnlock()
	if found {
		return cached
	}

	oneofValsCacheRWLock.Lock()
	parsed := splitParamsRegex.FindAllString(s, -1)
	for idx, raw := range parsed {
		parsed[idx] = strings.ReplaceAll(raw, "'", "")
	}
	oneofValsCache[s] = parsed
	oneofValsCacheRWLock.Unlock()
	return parsed
}
// isURLEncoded is the validation function for validating if the field's value matches uRLEncodedRegex.
func isURLEncoded(fl FieldLevel) bool {
	value := fl.Field().String()
	return uRLEncodedRegex.MatchString(value)
}
// isHTMLEncoded is the validation function for validating if the field's value matches hTMLEncodedRegex.
func isHTMLEncoded(fl FieldLevel) bool {
	value := fl.Field().String()
	return hTMLEncodedRegex.MatchString(value)
}
// isHTML is the validation function for validating if the field's value matches hTMLRegex.
func isHTML(fl FieldLevel) bool {
	value := fl.Field().String()
	return hTMLRegex.MatchString(value)
}
// isOneOf is the validation function for validating if the field's value,
// rendered as a string, equals one of the values listed in the param.
// String, signed and unsigned integer fields are supported; any other kind
// panics.
func isOneOf(fl FieldLevel) bool {
	allowed := parseOneOfParam2(fl.Param())
	fv := fl.Field()

	var text string
	switch fv.Kind() {
	case reflect.String:
		text = fv.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		text = strconv.FormatInt(fv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		text = strconv.FormatUint(fv.Uint(), 10)
	default:
		panic(fmt.Sprintf("Bad field type %T", fv.Interface()))
	}

	for _, candidate := range allowed {
		if candidate == text {
			return true
		}
	}
	return false
}
// isUnique is the validation function for validating if each array|slice|map value is unique.
//
// Slices and arrays: without a param the (pointer-dereferenced) elements must
// be distinct; with a param, the param names a struct field of the element
// type and the values of that field must be distinct.
// Maps: the (pointer-dereferenced) map values must be distinct.
// Any other kind: the param names a sibling field in the parent struct and
// the two field values must differ.
//
// It panics when the param does not name an existing field, when the sibling
// field has a different kind, or when the field type is unsupported.
func isUnique(fl FieldLevel) bool {
	field := fl.Field()
	param := fl.Param()
	// v is the empty-struct value stored for every key of the scratch maps
	// that are used as sets below.
	v := reflect.ValueOf(struct{}{})

	switch field.Kind() {
	case reflect.Slice, reflect.Array:
		elem := field.Type().Elem()
		if elem.Kind() == reflect.Ptr {
			elem = elem.Elem()
		}

		if param == "" {
			m := reflect.MakeMap(reflect.MapOf(elem, v.Type()))

			for i := 0; i < field.Len(); i++ {
				m.SetMapIndex(reflect.Indirect(field.Index(i)), v)
			}
			// Duplicates collapse onto the same key, shrinking the set.
			return field.Len() == m.Len()
		}

		sf, ok := elem.FieldByName(param)
		if !ok {
			panic(fmt.Sprintf("Bad field name %s", param))
		}

		sfTyp := sf.Type
		if sfTyp.Kind() == reflect.Ptr {
			sfTyp = sfTyp.Elem()
		}

		m := reflect.MakeMap(reflect.MapOf(sfTyp, v.Type()))
		var fieldlen int
		for i := 0; i < field.Len(); i++ {
			key := reflect.Indirect(reflect.Indirect(field.Index(i)).FieldByName(param))
			// Only valid keys take part in the comparison.
			if key.IsValid() {
				fieldlen++
				m.SetMapIndex(key, v)
			}
		}
		return fieldlen == m.Len()

	case reflect.Map:
		var m reflect.Value
		if field.Type().Elem().Kind() == reflect.Ptr {
			m = reflect.MakeMap(reflect.MapOf(field.Type().Elem().Elem(), v.Type()))
		} else {
			m = reflect.MakeMap(reflect.MapOf(field.Type().Elem(), v.Type()))
		}

		for _, k := range field.MapKeys() {
			m.SetMapIndex(reflect.Indirect(field.MapIndex(k)), v)
		}
		return field.Len() == m.Len()

	default:
		if parent := fl.Parent(); parent.Kind() == reflect.Struct {
			uniqueField := parent.FieldByName(param)
			// FieldByName returns the zero Value when the field does not
			// exist; IsValid is the supported way to detect that, since
			// reflect.Values should not be compared with ==.
			if !uniqueField.IsValid() {
				panic(fmt.Sprintf("Bad field name provided %s", param))
			}

			if uniqueField.Kind() != field.Kind() {
				panic(fmt.Sprintf("Bad field type %T:%T", field.Interface(), uniqueField.Interface()))
			}

			return field.Interface() != uniqueField.Interface()
		}

		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}
}
// isMAC is the validation function for validating if the field's value can be parsed by net.ParseMAC.
func isMAC(fl FieldLevel) bool {
	if _, err := net.ParseMAC(fl.Field().String()); err != nil {
		return false
	}
	return true
}
// isCIDRv4 is the validation function for validating if the field's value is a valid v4 CIDR address.
// The address part must be IPv4 and must equal the network address of the parsed CIDR.
func isCIDRv4(fl FieldLevel) bool {
	addr, network, err := net.ParseCIDR(fl.Field().String())
	if err != nil {
		return false
	}
	if addr.To4() == nil {
		return false
	}
	return network.IP.Equal(addr)
}
// isCIDRv6 is the validation function for validating if the field's value is a valid v6 CIDR address,
// i.e. it parses as a CIDR and its address part has no IPv4 form.
func isCIDRv6(fl FieldLevel) bool {
	addr, _, err := net.ParseCIDR(fl.Field().String())
	if err != nil {
		return false
	}
	return addr.To4() == nil
}
// isCIDR is the validation function for validating if the field's value parses as a v4 or v6 CIDR address.
func isCIDR(fl FieldLevel) bool {
	if _, _, err := net.ParseCIDR(fl.Field().String()); err != nil {
		return false
	}
	return true
}
// isIPv4 is the validation function for validating if a value is a valid v4 IP address.
func isIPv4(fl FieldLevel) bool {
	addr := net.ParseIP(fl.Field().String())
	if addr == nil {
		return false
	}
	return addr.To4() != nil
}
// isIPv6 is the validation function for validating if the field's value is a valid v6 IP address,
// i.e. it parses as an IP and has no IPv4 form.
func isIPv6(fl FieldLevel) bool {
	addr := net.ParseIP(fl.Field().String())
	if addr == nil {
		return false
	}
	return addr.To4() == nil
}
// isIP is the validation function for validating if the field's value parses as a v4 or v6 IP address.
func isIP(fl FieldLevel) bool {
	return net.ParseIP(fl.Field().String()) != nil
}
// isSSN is the validation function for validating if the field's value is a valid SSN:
// it must be exactly 11 long and match sSNRegex.
func isSSN(fl FieldLevel) bool {
	fv := fl.Field()
	return fv.Len() == 11 && sSNRegex.MatchString(fv.String())
}
// isLongitude is the validation function for validating if the field's value is a valid longitude coordinate.
// String, integer and float fields are rendered as text and matched against
// longitudeRegex; any other kind panics.
func isLongitude(fl FieldLevel) bool {
	fv := fl.Field()

	var text string
	switch fv.Kind() {
	case reflect.String:
		text = fv.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		text = strconv.FormatInt(fv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		text = strconv.FormatUint(fv.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		// Bits is 32 or 64 here, matching the float's own precision.
		text = strconv.FormatFloat(fv.Float(), 'f', -1, fv.Type().Bits())
	default:
		panic(fmt.Sprintf("Bad field type %T", fv.Interface()))
	}

	return longitudeRegex.MatchString(text)
}
// isLatitude is the validation function for validating if the field's value is a valid latitude coordinate.
// String, integer and float fields are rendered as text and matched against
// latitudeRegex; any other kind panics.
func isLatitude(fl FieldLevel) bool {
	fv := fl.Field()

	var text string
	switch fv.Kind() {
	case reflect.String:
		text = fv.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		text = strconv.FormatInt(fv.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		text = strconv.FormatUint(fv.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		// Bits is 32 or 64 here, matching the float's own precision.
		text = strconv.FormatFloat(fv.Float(), 'f', -1, fv.Type().Bits())
	default:
		panic(fmt.Sprintf("Bad field type %T", fv.Interface()))
	}

	return latitudeRegex.MatchString(text)
}
// isDataURI is the validation function for validating if the field's value is a valid data URI.
// The value is split at the first comma: the part before it must match
// dataURIRegex and the part after it must match base64Regex.
func isDataURI(fl FieldLevel) bool {
	header, payload, found := strings.Cut(fl.Field().String(), ",")
	if !found {
		return false
	}
	return dataURIRegex.MatchString(header) && base64Regex.MatchString(payload)
}
// hasMultiByteCharacter is the validation function for validating if the field's value has a multi byte character.
// An empty value is accepted.
func hasMultiByteCharacter(fl FieldLevel) bool {
	fv := fl.Field()
	return fv.Len() == 0 || multibyteRegex.MatchString(fv.String())
}
// isPrintableASCII is the validation function for validating if the field's value matches printableASCIIRegex.
func isPrintableASCII(fl FieldLevel) bool {
	value := fl.Field().String()
	return printableASCIIRegex.MatchString(value)
}
// isASCII is the validation function for validating if the field's value matches aSCIIRegex.
func isASCII(fl FieldLevel) bool {
	value := fl.Field().String()
	return aSCIIRegex.MatchString(value)
}
// isUUID5 is the validation function for validating if the field's value is a valid v5 UUID.
func isUUID5(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID5Regex, fl)
	return matched
}
// isUUID4 is the validation function for validating if the field's value is a valid v4 UUID.
func isUUID4(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID4Regex, fl)
	return matched
}
// isUUID3 is the validation function for validating if the field's value is a valid v3 UUID.
func isUUID3(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID3Regex, fl)
	return matched
}
// isUUID is the validation function for validating if the field's value is a valid UUID of any version.
func isUUID(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUIDRegex, fl)
	return matched
}
// isUUID5RFC4122 is the validation function for validating if the field's value is a valid RFC4122 v5 UUID.
func isUUID5RFC4122(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID5RFC4122Regex, fl)
	return matched
}
// isUUID4RFC4122 is the validation function for validating if the field's value is a valid RFC4122 v4 UUID.
func isUUID4RFC4122(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID4RFC4122Regex, fl)
	return matched
}
// isUUID3RFC4122 is the validation function for validating if the field's value is a valid RFC4122 v3 UUID.
func isUUID3RFC4122(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUID3RFC4122Regex, fl)
	return matched
}
// isUUIDRFC4122 is the validation function for validating if the field's value is a valid RFC4122 UUID of any version.
func isUUIDRFC4122(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uUIDRFC4122Regex, fl)
	return matched
}
// isULID is the validation function for validating if the field's value is a valid ULID.
func isULID(fl FieldLevel) bool {
	matched := fieldMatchesRegexByStringerValOrString(uLIDRegex, fl)
	return matched
}
// isMD4 is the validation function for validating if the field's value is a valid MD4.
func isMD4(fl FieldLevel) bool {
	value := fl.Field().String()
	return md4Regex.MatchString(value)
}
// isMD5 is the validation function for validating if the field's value is a valid MD5.
func isMD5(fl FieldLevel) bool {
	value := fl.Field().String()
	return md5Regex.MatchString(value)
}
// isSHA256 is the validation function for validating if the field's value is a valid SHA256.
func isSHA256(fl FieldLevel) bool {
	value := fl.Field().String()
	return sha256Regex.MatchString(value)
}
// isSHA384 is the validation function for validating if the field's value is a valid SHA384.
func isSHA384(fl FieldLevel) bool {
	value := fl.Field().String()
	return sha384Regex.MatchString(value)
}
// isSHA512 is the validation function for validating if the field's value is a valid SHA512.
func isSHA512(fl FieldLevel) bool {
	value := fl.Field().String()
	return sha512Regex.MatchString(value)
}
// isRIPEMD128 is the validation function for validating if the field's value is a valid RIPEMD128.
func isRIPEMD128(fl FieldLevel) bool {
	value := fl.Field().String()
	return ripemd128Regex.MatchString(value)
}
// isRIPEMD160 is the validation function for validating if the field's value is a valid RIPEMD160.
func isRIPEMD160(fl FieldLevel) bool {
	value := fl.Field().String()
	return ripemd160Regex.MatchString(value)
}
// isTIGER128 is the validation function for validating if the field's value is a valid TIGER128.
func isTIGER128(fl FieldLevel) bool {
	value := fl.Field().String()
	return tiger128Regex.MatchString(value)
}
// isTIGER160 is the validation function for validating if the field's value is a valid TIGER160.
func isTIGER160(fl FieldLevel) bool {
	value := fl.Field().String()
	return tiger160Regex.MatchString(value)
}
// isTIGER192 is the validation function for validating if the field's value is a valid TIGER192.
func isTIGER192(fl FieldLevel) bool {
	value := fl.Field().String()
	return tiger192Regex.MatchString(value)
}
// isISBN is the validation function for validating if the field's value is a valid v10 or v13 ISBN.
func isISBN(fl FieldLevel) bool {
	if isISBN10(fl) {
		return true
	}
	return isISBN13(fl)
}
// isISBN13 is the validation function for validating if the field's value is a valid v13 ISBN.
// Up to four hyphens and up to four spaces are removed before the value is
// matched against iSBN13Regex and its check digit is verified.
func isISBN13(fl FieldLevel) bool {
	digits := strings.Replace(fl.Field().String(), "-", "", 4)
	digits = strings.Replace(digits, " ", "", 4)
	if !iSBN13Regex.MatchString(digits) {
		return false
	}

	// The first twelve digits are weighted 1, 3, 1, 3, ...
	var sum int32
	for pos := 0; pos < 12; pos++ {
		weight := int32(1)
		if pos%2 == 1 {
			weight = 3
		}
		sum += weight * int32(digits[pos]-'0')
	}

	expected := (10 - sum%10) % 10
	return int32(digits[12]-'0') == expected
}
// isISBN10 is the validation function for validating if the field's value is a valid v10 ISBN.
// Up to three hyphens and up to three spaces are removed before the value is
// matched against iSBN10Regex and its weighted sum is checked modulo 11.
func isISBN10(fl FieldLevel) bool {
	digits := strings.Replace(fl.Field().String(), "-", "", 3)
	digits = strings.Replace(digits, " ", "", 3)
	if !iSBN10Regex.MatchString(digits) {
		return false
	}

	// Digit k (1-based) carries weight k.
	var sum int32
	for pos := int32(0); pos < 9; pos++ {
		sum += (pos + 1) * int32(digits[pos]-'0')
	}

	// A trailing 'X' stands for the value 10.
	last := int32(10)
	if digits[9] != 'X' {
		last = int32(digits[9] - '0')
	}
	sum += 10 * last

	return sum%11 == 0
}
// isISSN is the validation function for validating if the field's value is a valid ISSN.
// The value must match iSSNRegex; hyphens are then removed and the weighted
// sum of the eight characters is checked modulo 11.
func isISSN(fl FieldLevel) bool {
	raw := fl.Field().String()
	if !iSSNRegex.MatchString(raw) {
		return false
	}
	digits := strings.ReplaceAll(raw, "-", "")

	// The first seven digits carry the weights 8 down to 2.
	sum := 0
	for idx := 0; idx < 7; idx++ {
		sum += (8 - idx) * int(digits[idx]-'0')
	}

	// A trailing 'X' stands for the value 10.
	if digits[7] == 'X' {
		sum += 10
	} else {
		sum += int(digits[7] - '0')
	}

	return sum%11 == 0
}
// isEthereumAddress is the validation function for validating if the field's value is a valid Ethereum address,
// i.e. it matches ethAddressRegex.
func isEthereumAddress(fl FieldLevel) bool {
	return ethAddressRegex.MatchString(fl.Field().String())
}
// isBitcoinAddress is the validation function for validating if the field's value is a valid btc address.
// The value must match btcAddressRegex. It is then decoded from base 58 into
// 25 bytes, and the last four bytes must equal the first four bytes of the
// double SHA-256 of the leading 21 bytes.
func isBitcoinAddress(fl FieldLevel) bool {
	address := fl.Field().String()
	if !btcAddressRegex.MatchString(address) {
		return false
	}

	alphabet := []byte("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")

	// Base-58 decode: multiply the big-endian accumulator by 58 and add the
	// next digit, carrying from the least significant byte upwards.
	var decoded [25]byte
	for _, ch := range []byte(address) {
		carry := bytes.IndexByte(alphabet, ch)
		for pos := len(decoded) - 1; pos >= 0; pos-- {
			carry += 58 * int(decoded[pos])
			decoded[pos] = byte(carry % 256)
			carry /= 256
		}
	}

	first := sha256.Sum256(decoded[:21])
	second := sha256.Sum256(first[:])

	return bytes.Equal(decoded[21:], second[:4])
}
// isBitcoinBech32Address is the validation function for validating if the field's value is a valid bech32 btc address
//
// The checks run in this order: pattern match, length rule, version value,
// checksum, and finally the size of the data regrouped into 8-bit values.
func isBitcoinBech32Address(fl FieldLevel) bool {
address := fl.Field().String()
// The value has to match one of the two address patterns.
if !btcLowerAddressRegexBech32.MatchString(address) && !btcUpperAddressRegexBech32.MatchString(address) {
return false
}
// Lengths whose remainder modulo 8 is 0, 3 or 5 are rejected.
am := len(address) % 8
if am == 0 || am == 3 || am == 5 {
return false
}
address = strings.ToLower(address)
alphabet := "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
hr := []int{3, 3, 0, 2, 3} // the human readable part will always be bc
// Skip the first three characters.
// NOTE(review): presumably the "bc1" prefix guaranteed by the regexes - confirm.
addr := address[3:]
// dp holds, for every remaining character, its index in alphabet.
dp := make([]int, 0, len(addr))
for _, c := range addr {
dp = append(dp, strings.IndexRune(alphabet, c))
}
// The first value is treated as the version and must lie in 0..16.
ver := dp[0]
if ver < 0 || ver > 16 {
return false
}
// Version 0 addresses must be exactly 42 or 62 characters long.
if ver == 0 {
if len(address) != 42 && len(address) != 62 {
return false
}
}
// Checksum over the human readable values followed by the data values;
// the running value p has to end up as 1.
values := append(hr, dp...)
GEN := []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}
p := 1
for _, v := range values {
b := p >> 25
p = (p&0x1ffffff)<<5 ^ v
for i := 0; i < 5; i++ {
if (b>>uint(i))&1 == 1 {
p ^= GEN[i]
}
}
}
if p != 1 {
return false
}
// Regroup the 5-bit values between the version and the last six values
// into 8-bit chunks; b counts the bits currently buffered in acc.
b := uint(0)
acc := 0
mv := (1 << 5) - 1
var sw []int
for _, v := range dp[1 : len(dp)-6] {
acc = (acc << 5) | v
b += 5
for b >= 8 {
b -= 8
sw = append(sw, (acc>>b)&mv)
}
}
// Only the number of regrouped chunks is checked: it must be 2..40.
if len(sw) < 2 || len(sw) > 40 {
return false
}
return true
}
// excludesRune is the validation function for validating that the field's value does not contain the rune specified within the param.
// It is the negation of containsRune.
func excludesRune(fl FieldLevel) bool {
	present := containsRune(fl)
	return !present
}
// excludesAll is the validation function for validating that the field's value does not contain any of the characters specified within the param.
// It is the negation of containsAny.
func excludesAll(fl FieldLevel) bool {
	present := containsAny(fl)
	return !present
}
// excludes is the validation function for validating that the field's value does not contain the text specified within the param.
// It is the negation of contains.
func excludes(fl FieldLevel) bool {
	present := contains(fl)
	return !present
}
// containsRune is the validation function for validating that the field's value contains the rune specified within the param.
// Only the first rune decoded from the param is used.
func containsRune(fl FieldLevel) bool {
	wanted, _ := utf8.DecodeRuneInString(fl.Param())
	value := fl.Field().String()
	return strings.IndexRune(value, wanted) >= 0
}
// containsAny is the validation function for validating that the field's value contains any of the characters specified within the param.
func containsAny(fl FieldLevel) bool {
	value := fl.Field().String()
	chars := fl.Param()
	return strings.IndexAny(value, chars) >= 0
}
// contains is the validation function for validating that the field's value contains the text specified within the param.
func contains(fl FieldLevel) bool {
	value := fl.Field().String()
	needle := fl.Param()
	return strings.Index(value, needle) >= 0
}
// startsWith is the validation function for validating that the field's value starts with the text specified within the param.
func startsWith(fl FieldLevel) bool {
	value := fl.Field().String()
	prefix := fl.Param()
	return strings.HasPrefix(value, prefix)
}
// endsWith is the validation function for validating that the field's value ends with the text specified within the param.
func endsWith(fl FieldLevel) bool {
	value := fl.Field().String()
	suffix := fl.Param()
	return strings.HasSuffix(value, suffix)
}
// startsNotWith is the validation function for validating that the field's value does not start with the text specified within the param.
// It is the negation of startsWith.
func startsNotWith(fl FieldLevel) bool {
	hasPrefix := startsWith(fl)
	return !hasPrefix
}
// endsNotWith is the validation function for validating that the field's value does not end with the text specified within the param.
// It is the negation of endsWith.
func endsNotWith(fl FieldLevel) bool {
	hasSuffix := endsWith(fl)
	return !hasSuffix
}
// fieldContains is the validation function for validating if the current field's value contains the field specified by the param's value.
// It returns false when the other field cannot be resolved.
func fieldContains(fl FieldLevel) bool {
	value := fl.Field()

	if other, _, found := fl.GetStructFieldOK(); found {
		return strings.Contains(value.String(), other.String())
	}
	return false
}
// fieldExcludes is the validation function for validating if the current field's value excludes the field specified by the param's value.
// It returns true when the other field cannot be resolved.
func fieldExcludes(fl FieldLevel) bool {
	value := fl.Field()

	if other, _, found := fl.GetStructFieldOK(); found {
		return !strings.Contains(value.String(), other.String())
	}
	return true
}
// isNeField is the validation function for validating if the current field's value is not equal to the field specified by the param's value.
//
// It returns true when the other field cannot be resolved or its kind differs
// from the current field's kind. Otherwise the comparison depends on the kind:
// numeric and bool kinds compare values, slices/maps/arrays compare lengths,
// time-convertible structs compare instants, structs of differing types are
// considered not equal, and everything else compares the reflect string form.
func isNeField(fl FieldLevel) bool {
	field := fl.Field()
	kind := field.Kind()

	currentField, currentKind, ok := fl.GetStructFieldOK()
	if !ok || currentKind != kind {
		return true
	}

	switch kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return field.Int() != currentField.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return field.Uint() != currentField.Uint()
	case reflect.Float32, reflect.Float64:
		return field.Float() != currentField.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return int64(field.Len()) != int64(currentField.Len())
	case reflect.Bool:
		return field.Bool() != currentField.Bool()
	case reflect.Struct:
		fieldType := field.Type()
		if fieldType.ConvertibleTo(timeType) && currentField.Type().ConvertibleTo(timeType) {
			// Convert before asserting: the guard only proves the types are
			// convertible to time.Time, not that they are time.Time, so a
			// direct Interface().(time.Time) assertion could panic.
			t := currentField.Convert(timeType).Interface().(time.Time)
			fieldTime := field.Convert(timeType).Interface().(time.Time)
			return !fieldTime.Equal(t)
		}
		// Not Same underlying type i.e. struct and time
		if fieldType != currentField.Type() {
			return true
		}
	}

	// default reflect.String:
	return field.String() != currentField.String()
}
// isNe is the validation function for validating that the field's value does not equal the provided param value.
// It is the negation of isEq.
func isNe(fl FieldLevel) bool {
	equal := isEq(fl)
	return !equal
}
// isNeIgnoreCase is the validation function for validating that the field's string value does not equal the
// provided param value. The comparison is case-insensitive; it is the negation of isEqIgnoreCase.
func isNeIgnoreCase(fl FieldLevel) bool {
	equal := isEqIgnoreCase(fl)
	return !equal
}
// isLteCrossStructField is the validation function for validating if the current field's value is less than or equal to the field, within a separate struct, specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric kinds compare values, slices/maps/arrays compare lengths,
// time-convertible structs compare instants, structs of differing types fail,
// and everything else compares the reflect string form.
func isLteCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() >= cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() >= cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() >= cur.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return other.Len() >= cur.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return curTime.Equal(otherTime) || curTime.Before(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return other.String() >= cur.String()
}
// isLtCrossStructField is the validation function for validating if the current field's value is less than the field, within a separate struct, specified by the param's value.
// NOTE: This is exposed for use within your own custom functions and not intended to be called directly.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric kinds compare values, slices/maps/arrays compare lengths,
// time-convertible structs compare instants, structs of differing types fail,
// and everything else compares the reflect string form.
func isLtCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() > cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() > cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() > cur.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return other.Len() > cur.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return curTime.Before(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return other.String() > cur.String()
}
// isGteCrossStructField is the validation function for validating if the current field's value is greater than or equal to the field, within a separate struct, specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric kinds compare values, slices/maps/arrays compare lengths,
// time-convertible structs compare instants, structs of differing types fail,
// and everything else compares the reflect string form.
func isGteCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() <= cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() <= cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() <= cur.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return other.Len() <= cur.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return curTime.Equal(otherTime) || curTime.After(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return other.String() <= cur.String()
}
// isGtCrossStructField is the validation function for validating if the current field's value is greater than the field, within a separate struct, specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric kinds compare values, slices/maps/arrays compare lengths,
// time-convertible structs compare instants, structs of differing types fail,
// and everything else compares the reflect string form.
func isGtCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() < cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() < cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() < cur.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return other.Len() < cur.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return curTime.After(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return other.String() < cur.String()
}
// isNeCrossStructField is the validation function for validating that the current field's value is not equal to the field, within a separate struct, specified by the param's value.
//
// It returns true when the other field cannot be resolved or has a different
// kind. Numeric and bool kinds compare values, slices/maps/arrays compare
// lengths, time-convertible structs compare instants, structs of differing
// types count as not equal, and everything else compares the reflect string form.
func isNeCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return true
	}
	if otherKind != curKind {
		return true
	}

	switch curKind {
	case reflect.Bool:
		return cur.Bool() != other.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return cur.Int() != other.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return cur.Uint() != other.Uint()
	case reflect.Float32, reflect.Float64:
		return cur.Float() != other.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return cur.Len() != other.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return !otherTime.Equal(curTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return true
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return cur.String() != other.String()
}
// isEqCrossStructField is the validation function for validating that the current field's value is equal to the field, within a separate struct, specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric and bool kinds compare values, slices/maps/arrays compare
// lengths, time-convertible structs compare instants, structs of differing
// types count as not equal, and everything else compares the reflect string form.
func isEqCrossStructField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Bool:
		return cur.Bool() == other.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return cur.Int() == other.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return cur.Uint() == other.Uint()
	case reflect.Float32, reflect.Float64:
		return cur.Float() == other.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return cur.Len() == other.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			curTime := cur.Convert(timeType).Interface().(time.Time)
			otherTime := other.Convert(timeType).Interface().(time.Time)
			return otherTime.Equal(curTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return cur.String() == other.String()
}
// isEqField is the validation function for validating if the current field's value is equal to the field specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric and bool kinds compare values, slices/maps/arrays compare
// lengths, time-convertible structs compare instants, structs of differing
// types count as not equal, and everything else compares the reflect string form.
func isEqField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Bool:
		return other.Bool() == cur.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() == cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() == cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() == cur.Float()
	case reflect.Slice, reflect.Map, reflect.Array:
		return other.Len() == cur.Len()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			otherTime := other.Convert(timeType).Interface().(time.Time)
			curTime := cur.Convert(timeType).Interface().(time.Time)
			return curTime.Equal(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// Remaining kinds (notably reflect.String) compare by string form.
	return other.String() == cur.String()
}
// isEq is the validation function for validating if the current field's value is equal to the param's value.
//
// Strings are compared verbatim; slices, maps and arrays compare their length
// against the param parsed as an integer; numeric and bool kinds compare their
// value against the param parsed with the matching as* helper. Any other kind
// panics with "Bad field type".
func isEq(fl FieldLevel) bool {
	value := fl.Field()
	want := fl.Param()

	switch value.Kind() {
	case reflect.Bool:
		return value.Bool() == asBool(want)
	case reflect.String:
		return value.String() == want
	case reflect.Float32:
		return value.Float() == asFloat32(want)
	case reflect.Float64:
		return value.Float() == asFloat64(want)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return value.Int() == asIntFromType(value.Type(), want)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return value.Uint() == asUint(want)
	case reflect.Slice, reflect.Map, reflect.Array:
		return int64(value.Len()) == asInt(want)
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// isEqIgnoreCase is the validation function for validating if the current field's string value is
// equal to the param's value.
// The comparison is case-insensitive. Non-string fields cause a panic.
func isEqIgnoreCase(fl FieldLevel) bool {
	value := fl.Field()
	want := fl.Param()

	if value.Kind() == reflect.String {
		return strings.EqualFold(value.String(), want)
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// isPostcodeByIso3166Alpha2 validates by value which is country code in iso 3166 alpha 2
// example: `postcode_iso3166_alpha2=US`
// The param is looked up in postCodeRegexDict; an unknown code fails validation.
func isPostcodeByIso3166Alpha2(fl FieldLevel) bool {
	value := fl.Field()
	country := fl.Param()

	if pattern, known := postCodeRegexDict[country]; known {
		return pattern.MatchString(value.String())
	}
	return false
}
// isPostcodeByIso3166Alpha2Field validates by field which represents for a value of country code in iso 3166 alpha 2
// example: `postcode_iso3166_alpha2_field=CountryCode`
// The param must name exactly one field on the parent; that field must be a
// string (otherwise this panics) whose value is a key of postCodeRegexDict.
func isPostcodeByIso3166Alpha2Field(fl FieldLevel) bool {
	value := fl.Field()

	names := parseOneOfParam2(fl.Param())
	if len(names) != 1 {
		return false
	}

	countryField, countryKind, _, ok := fl.GetStructFieldOKAdvanced2(fl.Parent(), names[0])
	if !ok {
		return false
	}
	if countryKind != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", countryField.Interface()))
	}

	if pattern, known := postCodeRegexDict[countryField.String()]; known {
		return pattern.MatchString(value.String())
	}
	return false
}
// isBase64 is the validation function for validating if the current field's value is a valid base 64.
// The field's string form is matched against base64Regex.
func isBase64(fl FieldLevel) bool {
	value := fl.Field().String()
	return base64Regex.MatchString(value)
}
// isBase64URL is the validation function for validating if the current field's value is a valid base64 URL safe string.
// The field's string form is matched against base64URLRegex.
func isBase64URL(fl FieldLevel) bool {
	value := fl.Field().String()
	return base64URLRegex.MatchString(value)
}
// isBase64RawURL is the validation function for validating if the current field's value is a valid base64 URL safe string without '=' padding.
// The field's string form is matched against base64RawURLRegex.
func isBase64RawURL(fl FieldLevel) bool {
	value := fl.Field().String()
	return base64RawURLRegex.MatchString(value)
}
// isURI is the validation function for validating if the current field's value is a valid URI.
// Anything from the first '#' onward is dropped before parsing; an empty
// remainder is invalid. Non-string fields cause a panic.
func isURI(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	raw := field.String()

	// checks needed as of Go 1.6 because of change https://github.com/golang/go/commit/617c93ce740c3c3cc28cdd1a0d712be183d0b328#diff-6c2d018290e298803c0c9419d8739885L195
	// emulate browser and strip the '#' suffix prior to validation. see issue-#237
	if hash := strings.IndexByte(raw, '#'); hash >= 0 {
		raw = raw[:hash]
	}
	if raw == "" {
		return false
	}

	_, err := url.ParseRequestURI(raw)
	return err == nil
}
// isFileURL is the helper function for validating if the `path` valid file URL as per RFC8089.
// The path must start with "file:/" and be accepted by url.ParseRequestURI.
func isFileURL(path string) bool {
	if strings.HasPrefix(path, "file:/") {
		if _, err := url.ParseRequestURI(path); err == nil {
			return true
		}
	}
	return false
}
// isURL is the validation function for validating if the current field's value is a valid URL.
// The lower-cased value passes if isFileURL accepts it, or if it parses with a
// non-empty scheme and at least one of host, fragment or opaque part.
// Non-string fields cause a panic.
func isURL(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	lowered := strings.ToLower(field.String())
	if lowered == "" {
		return false
	}
	if isFileURL(lowered) {
		return true
	}

	parsed, err := url.Parse(lowered)
	if err != nil || parsed.Scheme == "" {
		return false
	}

	return parsed.Host != "" || parsed.Fragment != "" || parsed.Opaque != ""
}
// isHttpURL is the validation function for validating if the current field's value is a valid HTTP(s) URL.
// The value must first pass isURL, then parse with a non-empty host and an
// "http" or "https" scheme.
func isHttpURL(fl FieldLevel) bool {
	if !isURL(fl) {
		return false
	}

	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	parsed, err := url.Parse(strings.ToLower(field.String()))
	if err != nil || parsed.Host == "" {
		return false
	}

	switch parsed.Scheme {
	case "http", "https":
		return true
	}
	return false
}
// isFile is the validation function for validating if the current field's value is a valid existing file path.
// The path must stat successfully and must not be a directory. Non-string
// fields cause a panic.
func isFile(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	info, err := os.Stat(field.String())
	return err == nil && !info.IsDir()
}
// isFilePath is the validation function for validating if the current field's value is a valid file path.
//
// Existing directories are rejected and existing files accepted. For a path
// that does not exist, the value is rejected when it is blank, ends in the OS
// path separator, or when os.Stat reports EINVAL; any other *fs.PathError
// (missing file, permission denied, ...) is treated as a plausible path.
// Non-string fields cause a panic.
func isFilePath(fl FieldLevel) bool {
	field := fl.Field()

	// Not valid if it is a directory.
	if isDir(fl) {
		return false
	}
	// If it exists, it obviously is valid.
	// This is done first to avoid code duplication and unnecessary additional logic.
	if isFile(fl) {
		return true
	}

	// It does not exist but may still be a valid filepath.
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}
	path := field.String()

	// Every OS allows for whitespace, but none
	// let you use a file with no filename (to my knowledge).
	// Unless you're dealing with raw inodes, but I digress.
	if strings.TrimSpace(path) == "" {
		return false
	}
	// We make sure it isn't a directory.
	if strings.HasSuffix(path, string(os.PathSeparator)) {
		return false
	}

	info, err := os.Stat(path)
	if err == nil {
		// The path came into existence after the isDir/isFile probes above.
		// Classify it here instead of falling through to the misleading
		// "Bad field type" panic.
		return !info.IsDir()
	}

	switch t := err.(type) {
	case *fs.PathError:
		if t.Err == syscall.EINVAL {
			// It's definitely an invalid character in the filepath.
			return false
		}
		// It could be a permission error, a does-not-exist error, etc.
		// Out-of-scope for this validation, though.
		return true
	default:
		// Something went *seriously* wrong.
		/*
			Per https://pkg.go.dev/os#Stat:
				"If there is an error, it will be of type *PathError."
		*/
		panic(err)
	}
}
// isE164 is the validation function for validating if the current field's value is a valid e.164 formatted phone number.
// The field's string form is matched against e164Regex.
func isE164(fl FieldLevel) bool {
	value := fl.Field().String()
	return e164Regex.MatchString(value)
}
// isEmail is the validation function for validating if the current field's value is a valid email address.
// The field's string form is matched against emailRegex.
func isEmail(fl FieldLevel) bool {
	value := fl.Field().String()
	return emailRegex.MatchString(value)
}
// isHSLA is the validation function for validating if the current field's value is a valid HSLA color.
// The field's string form is matched against hslaRegex.
func isHSLA(fl FieldLevel) bool {
	value := fl.Field().String()
	return hslaRegex.MatchString(value)
}
// isHSL is the validation function for validating if the current field's value is a valid HSL color.
// The field's string form is matched against hslRegex.
func isHSL(fl FieldLevel) bool {
	value := fl.Field().String()
	return hslRegex.MatchString(value)
}
// isRGBA is the validation function for validating if the current field's value is a valid RGBA color.
// The field's string form is matched against rgbaRegex.
func isRGBA(fl FieldLevel) bool {
	value := fl.Field().String()
	return rgbaRegex.MatchString(value)
}
// isRGB is the validation function for validating if the current field's value is a valid RGB color.
// The field's string form is matched against rgbRegex.
func isRGB(fl FieldLevel) bool {
	value := fl.Field().String()
	return rgbRegex.MatchString(value)
}
// isHEXColor is the validation function for validating if the current field's value is a valid HEX color.
// The field's string form is matched against hexColorRegex.
func isHEXColor(fl FieldLevel) bool {
	value := fl.Field().String()
	return hexColorRegex.MatchString(value)
}
// isHexadecimal is the validation function for validating if the current field's value is a valid hexadecimal.
// The field's string form is matched against hexadecimalRegex.
func isHexadecimal(fl FieldLevel) bool {
	value := fl.Field().String()
	return hexadecimalRegex.MatchString(value)
}
// isNumber is the validation function for validating if the current field's value is a valid number.
// Integer, unsigned and float kinds always pass; any other kind is matched as a
// string against numberRegex.
func isNumber(fl FieldLevel) bool {
	field := fl.Field()

	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		return true
	}

	return numberRegex.MatchString(field.String())
}
// isNumeric is the validation function for validating if the current field's value is a valid numeric value.
// Integer, unsigned and float kinds always pass; any other kind is matched as a
// string against numericRegex.
func isNumeric(fl FieldLevel) bool {
	field := fl.Field()

	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		return true
	}

	return numericRegex.MatchString(field.String())
}
// isAlphanum is the validation function for validating if the current field's value is a valid alphanumeric value.
// The field's string form is matched against alphaNumericRegex.
func isAlphanum(fl FieldLevel) bool {
	value := fl.Field().String()
	return alphaNumericRegex.MatchString(value)
}
// isAlpha is the validation function for validating if the current field's value is a valid alpha value.
// The field's string form is matched against alphaRegex.
func isAlpha(fl FieldLevel) bool {
	value := fl.Field().String()
	return alphaRegex.MatchString(value)
}
// isAlphanumUnicode is the validation function for validating if the current field's value is a valid alphanumeric unicode value.
// The field's string form is matched against alphaUnicodeNumericRegex.
func isAlphanumUnicode(fl FieldLevel) bool {
	value := fl.Field().String()
	return alphaUnicodeNumericRegex.MatchString(value)
}
// isAlphaUnicode is the validation function for validating if the current field's value is a valid alpha unicode value.
// The field's string form is matched against alphaUnicodeRegex.
func isAlphaUnicode(fl FieldLevel) bool {
	value := fl.Field().String()
	return alphaUnicodeRegex.MatchString(value)
}
// isBoolean is the validation function for validating if the current field's value is a valid boolean value or can be safely converted to a boolean value.
// Bool fields always pass; any other kind passes when its string form is
// accepted by strconv.ParseBool.
func isBoolean(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() == reflect.Bool {
		return true
	}

	if _, err := strconv.ParseBool(field.String()); err != nil {
		return false
	}
	return true
}
// isDefault is the opposite of required aka hasValue
func isDefault(fl FieldLevel) bool {
	populated := hasValue(fl)
	return !populated
}
// hasValue is the validation function for validating if the current field's value is not the default static value.
//
// Nil-able kinds (slice, map, pointer, interface, chan, func) pass when they are
// non-nil. For every other kind the field passes when the validator flagged it
// as a pointer field with a non-nil interface value, or when it is valid and
// not the zero value. fl must be a *validate; any other implementation panics
// on the type assertion.
func hasValue(fl FieldLevel) bool {
	value := fl.Field()

	switch value.Kind() {
	case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
		return !value.IsNil()
	}

	if v := fl.(*validate); v.fldIsPointer && value.Interface() != nil {
		return true
	}
	return value.IsValid() && !value.IsZero()
}
// requireCheckFieldKind is a func for check field kind
//
// It reports whether the inspected field is "empty": nil for nil-able kinds,
// the zero value otherwise. With an empty param the validated field itself is
// inspected; otherwise the field named by param is looked up on the parent.
// defaultNotFoundValue is returned when that lookup fails or the kind is
// reflect.Invalid.
func requireCheckFieldKind(fl FieldLevel, param string, defaultNotFoundValue bool) bool {
	target := fl.Field()
	targetKind := target.Kind()
	nullable := false

	if param != "" {
		var found bool
		target, targetKind, nullable, found = fl.GetStructFieldOKAdvanced2(fl.Parent(), param)
		if !found {
			return defaultNotFoundValue
		}
	}

	switch targetKind {
	case reflect.Invalid:
		return defaultNotFoundValue
	case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
		return target.IsNil()
	}

	// A nullable field that holds a non-nil interface value counts as present.
	if nullable && target.Interface() != nil {
		return false
	}
	return target.IsValid() && target.IsZero()
}
// requireCheckFieldValue is a func for check field value
//
// It looks up the field named by param on the parent and reports whether it
// equals value, parsing value according to the field's kind (length for
// slices, maps and arrays; string comparison for any unlisted kind).
// defaultNotFoundValue is returned when the lookup fails.
func requireCheckFieldValue(
	fl FieldLevel, param string, value string, defaultNotFoundValue bool,
) bool {
	other, otherKind, _, found := fl.GetStructFieldOKAdvanced2(fl.Parent(), param)
	if !found {
		return defaultNotFoundValue
	}

	switch otherKind {
	case reflect.Bool:
		return other.Bool() == asBool(value)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() == asInt(value)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() == asUint(value)
	case reflect.Float32:
		return other.Float() == asFloat32(value)
	case reflect.Float64:
		return other.Float() == asFloat64(value)
	case reflect.Slice, reflect.Map, reflect.Array:
		return int64(other.Len()) == asInt(value)
	default:
		// reflect.String and anything else: compare the string form.
		return other.String() == value
	}
}
// requiredIf is the validation function
// The field under validation must be present and not empty only if all the other specified fields are equal to the value following with the specified field.
//
// The param is parsed into (field, value) pairs; an odd number of entries
// panics. As soon as one pair does not match, validation passes.
func requiredIf(fl FieldLevel) bool {
	params := parseOneOfParam2(fl.Param())
	if len(params)%2 != 0 {
		panic(fmt.Sprintf("Bad param number for required_if %s", fl.FieldName()))
	}

	for pairs := params; len(pairs) > 0; pairs = pairs[2:] {
		name, want := pairs[0], pairs[1]
		if !requireCheckFieldValue(fl, name, want, false) {
			return true
		}
	}

	return hasValue(fl)
}
// excludedIf is the validation function
// The field under validation must not be present or is empty only if all the other specified fields are equal to the value following with the specified field.
//
// The param is parsed into (field, value) pairs; an odd number of entries
// panics. As soon as one pair does not match, validation passes.
func excludedIf(fl FieldLevel) bool {
	params := parseOneOfParam2(fl.Param())
	if len(params)%2 != 0 {
		panic(fmt.Sprintf("Bad param number for excluded_if %s", fl.FieldName()))
	}

	for pairs := params; len(pairs) > 0; pairs = pairs[2:] {
		name, want := pairs[0], pairs[1]
		if !requireCheckFieldValue(fl, name, want, false) {
			return true
		}
	}

	return !hasValue(fl)
}
// requiredUnless is the validation function
// The field under validation must be present and not empty unless one of the other specified fields is equal to the value following the specified field.
//
// The param is parsed into (field, value) pairs; an odd number of entries
// panics. As soon as one pair matches, validation passes; if none match, the
// field must have a value.
func requiredUnless(fl FieldLevel) bool {
	params := parseOneOfParam2(fl.Param())
	if len(params)%2 != 0 {
		panic(fmt.Sprintf("Bad param number for required_unless %s", fl.FieldName()))
	}

	for pairs := params; len(pairs) > 0; pairs = pairs[2:] {
		name, want := pairs[0], pairs[1]
		if requireCheckFieldValue(fl, name, want, false) {
			return true
		}
	}

	return hasValue(fl)
}
// skipUnless is the validation function
// The check on the field under validation is skipped (passes) unless all the other specified fields are equal to the value following the specified field; when they all match, the field must be present and not empty.
//
// The param is parsed into (field, value) pairs; an odd number of entries
// panics.
func skipUnless(fl FieldLevel) bool {
	params := parseOneOfParam2(fl.Param())
	if len(params)%2 != 0 {
		panic(fmt.Sprintf("Bad param number for skip_unless %s", fl.FieldName()))
	}

	for pairs := params; len(pairs) > 0; pairs = pairs[2:] {
		name, want := pairs[0], pairs[1]
		if !requireCheckFieldValue(fl, name, want, false) {
			return true
		}
	}

	return hasValue(fl)
}
// excludedUnless is the validation function
// The field under validation must not be present or is empty unless all the other specified fields are equal to the value following with the specified field.
//
// The param is parsed into (field, value) pairs; an odd number of entries
// panics. At the first pair that does not match, the field must be empty; if
// every pair matches, validation passes.
func excludedUnless(fl FieldLevel) bool {
	params := parseOneOfParam2(fl.Param())
	if len(params)%2 != 0 {
		panic(fmt.Sprintf("Bad param number for excluded_unless %s", fl.FieldName()))
	}

	for pairs := params; len(pairs) > 0; pairs = pairs[2:] {
		name, want := pairs[0], pairs[1]
		if requireCheckFieldValue(fl, name, want, false) {
			continue
		}
		return !hasValue(fl)
	}

	return true
}
// excludedWith is the validation function
// The field under validation must not be present or is empty if any of the other specified fields are present.
//
// A named field counts as present when requireCheckFieldKind reports it as
// non-empty; fields that cannot be found count as empty.
func excludedWith(fl FieldLevel) bool {
	for _, name := range parseOneOfParam2(fl.Param()) {
		if requireCheckFieldKind(fl, name, true) {
			continue
		}
		return !hasValue(fl)
	}
	return true
}
// requiredWith is the validation function
// The field under validation must be present and not empty only if any of the other specified fields are present.
//
// A named field counts as present when requireCheckFieldKind reports it as
// non-empty; fields that cannot be found count as empty.
func requiredWith(fl FieldLevel) bool {
	for _, name := range parseOneOfParam2(fl.Param()) {
		if requireCheckFieldKind(fl, name, true) {
			continue
		}
		return hasValue(fl)
	}
	return true
}
// excludedWithAll is the validation function
// The field under validation must not be present or is empty if all of the other specified fields are present.
//
// Validation passes as soon as one named field is empty (or cannot be found).
func excludedWithAll(fl FieldLevel) bool {
	names := parseOneOfParam2(fl.Param())
	for i := 0; i < len(names); i++ {
		if requireCheckFieldKind(fl, names[i], true) {
			return true
		}
	}
	return !hasValue(fl)
}
// requiredWithAll is the validation function
// The field under validation must be present and not empty only if all of the other specified fields are present.
//
// Validation passes as soon as one named field is empty (or cannot be found).
func requiredWithAll(fl FieldLevel) bool {
	names := parseOneOfParam2(fl.Param())
	for i := 0; i < len(names); i++ {
		if requireCheckFieldKind(fl, names[i], true) {
			return true
		}
	}
	return hasValue(fl)
}
// excludedWithout is the validation function
// The field under validation must not be present or is empty when the other specified field is not present.
//
// The whitespace-trimmed param is passed to requireCheckFieldKind as a single
// field name; a field that cannot be found counts as not present.
func excludedWithout(fl FieldLevel) bool {
	otherMissing := requireCheckFieldKind(fl, strings.TrimSpace(fl.Param()), true)
	if !otherMissing {
		return true
	}
	return !hasValue(fl)
}
// requiredWithout is the validation function
// The field under validation must be present and not empty only when the other specified field is not present.
//
// The whitespace-trimmed param is passed to requireCheckFieldKind as a single
// field name; a field that cannot be found counts as not present.
func requiredWithout(fl FieldLevel) bool {
	otherMissing := requireCheckFieldKind(fl, strings.TrimSpace(fl.Param()), true)
	if !otherMissing {
		return true
	}
	return hasValue(fl)
}
// excludedWithoutAll is the validation function
// The field under validation must not be present or is empty when all of the other specified fields are not present.
//
// Validation passes as soon as one named field is found to be non-empty.
func excludedWithoutAll(fl FieldLevel) bool {
	names := parseOneOfParam2(fl.Param())
	for i := 0; i < len(names); i++ {
		if otherMissing := requireCheckFieldKind(fl, names[i], true); !otherMissing {
			return true
		}
	}
	return !hasValue(fl)
}
// requiredWithoutAll is the validation function
// The field under validation must be present and not empty only when all of the other specified fields are not present.
//
// Validation passes as soon as one named field is found to be non-empty.
func requiredWithoutAll(fl FieldLevel) bool {
	names := parseOneOfParam2(fl.Param())
	for i := 0; i < len(names); i++ {
		if otherMissing := requireCheckFieldKind(fl, names[i], true); !otherMissing {
			return true
		}
	}
	return hasValue(fl)
}
// isGteField is the validation function for validating if the current field's value is greater than or equal to the field specified by the param's value.
//
// It returns false when the other field cannot be resolved or has a different
// kind. Numeric kinds compare values, time-convertible structs compare
// instants, structs of differing types fail, and every remaining kind
// (notably strings) compares the length of the reflect string form.
func isGteField(fl FieldLevel) bool {
	cur := fl.Field()
	curKind := cur.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found {
		return false
	}
	if otherKind != curKind {
		return false
	}

	switch curKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return other.Int() <= cur.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return other.Uint() <= cur.Uint()
	case reflect.Float32, reflect.Float64:
		return other.Float() <= cur.Float()
	case reflect.Struct:
		curType := cur.Type()
		if curType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			otherTime := other.Convert(timeType).Interface().(time.Time)
			curTime := cur.Convert(timeType).Interface().(time.Time)
			return curTime.Equal(otherTime) || curTime.After(otherTime)
		}
		// Not the same underlying type, i.e. struct and time.
		if curType != other.Type() {
			return false
		}
	}

	// default reflect.String: lengths are compared, not contents.
	return len(other.String()) <= len(cur.String())
}
// isGtField is the validation function for validating if the current field's value is greater than the field specified by the param's value.
//
// It returns false when the other field cannot be resolved or when its kind
// differs from the current field's kind. Integers, unsigned integers and floats
// are compared numerically; structs that are both convertible to time.Time are
// compared as instants. Structs of differing types yield false. Everything
// else is compared by the length of reflect.Value.String().
func isGtField(fl FieldLevel) bool {
	self := fl.Field()
	selfKind := self.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found || otherKind != selfKind {
		return false
	}

	switch selfKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return self.Int() > other.Int()

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return self.Uint() > other.Uint()

	case reflect.Float32, reflect.Float64:
		return self.Float() > other.Float()

	case reflect.Struct:
		selfType := self.Type()
		if selfType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			otherTime := other.Convert(timeType).Interface().(time.Time)
			selfTime := self.Convert(timeType).Interface().(time.Time)
			return selfTime.After(otherTime)
		}
		// A struct compared with a different struct type can never match.
		if selfType != other.Type() {
			return false
		}
	}

	// Fallback (strings and anything not handled above): compare lengths.
	return len(self.String()) > len(other.String())
}
// isGte is the validation function for validating if the current field's value is greater than or equal to the param's value.
//
// Strings are measured in runes and slices, maps and arrays by Len; numeric
// kinds are compared by value. A struct convertible to time.Time is compared
// against time.Now().UTC() and the param is not consulted for it. Any other
// kind panics with "Bad field type".
func isGte(fl FieldLevel) bool {
	value, limit := fl.Field(), fl.Param()

	switch value.Kind() {
	case reflect.String:
		want := asInt(limit)
		return int64(utf8.RuneCountInString(value.String())) >= want

	case reflect.Slice, reflect.Map, reflect.Array:
		want := asInt(limit)
		return int64(value.Len()) >= want

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		want := asIntFromType(value.Type(), limit)
		return value.Int() >= want

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		want := asUint(limit)
		return value.Uint() >= want

	case reflect.Float32:
		want := asFloat32(limit)
		return value.Float() >= want

	case reflect.Float64:
		want := asFloat64(limit)
		return value.Float() >= want

	case reflect.Struct:
		if value.Type().ConvertibleTo(timeType) {
			now := time.Now().UTC()
			when := value.Convert(timeType).Interface().(time.Time)
			return when.After(now) || when.Equal(now)
		}
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// isGt is the validation function for validating if the current field's value is greater than the param's value.
//
// Strings are measured in runes and slices, maps and arrays by Len; numeric
// kinds are compared by value. A struct convertible to time.Time must be after
// time.Now().UTC() and the param is not consulted for it. Any other kind
// panics with "Bad field type".
func isGt(fl FieldLevel) bool {
	value, limit := fl.Field(), fl.Param()

	switch value.Kind() {
	case reflect.String:
		want := asInt(limit)
		return int64(utf8.RuneCountInString(value.String())) > want

	case reflect.Slice, reflect.Map, reflect.Array:
		want := asInt(limit)
		return int64(value.Len()) > want

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		want := asIntFromType(value.Type(), limit)
		return value.Int() > want

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		want := asUint(limit)
		return value.Uint() > want

	case reflect.Float32:
		want := asFloat32(limit)
		return value.Float() > want

	case reflect.Float64:
		want := asFloat64(limit)
		return value.Float() > want

	case reflect.Struct:
		if value.Type().ConvertibleTo(timeType) {
			when := value.Convert(timeType).Interface().(time.Time)
			return when.After(time.Now().UTC())
		}
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// hasLengthOf is the validation function for validating if the current field's value is equal to the param's value.
//
// Strings are measured in runes and slices, maps and arrays by Len; numeric
// kinds are compared by value. Any other kind panics with "Bad field type".
func hasLengthOf(fl FieldLevel) bool {
	value, limit := fl.Field(), fl.Param()

	switch value.Kind() {
	case reflect.String:
		want := asInt(limit)
		return int64(utf8.RuneCountInString(value.String())) == want

	case reflect.Slice, reflect.Map, reflect.Array:
		want := asInt(limit)
		return int64(value.Len()) == want

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		want := asIntFromType(value.Type(), limit)
		return value.Int() == want

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		want := asUint(limit)
		return value.Uint() == want

	case reflect.Float32:
		want := asFloat32(limit)
		return value.Float() == want

	case reflect.Float64:
		want := asFloat64(limit)
		return value.Float() == want
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// hasMinOf is the validation function for validating if the current field's value is greater than or equal to the param's value.
// It delegates directly to isGte, so it shares that function's per-kind behaviour and panics.
func hasMinOf(fl FieldLevel) bool {
return isGte(fl)
}
// isLteField is the validation function for validating if the current field's value is less than or equal to the field specified by the param's value.
//
// It returns false when the other field cannot be resolved or when its kind
// differs from the current field's kind. Integers, unsigned integers and floats
// are compared numerically; structs that are both convertible to time.Time are
// compared as instants. Structs of differing types yield false. Everything
// else is compared by the length of reflect.Value.String().
func isLteField(fl FieldLevel) bool {
	self := fl.Field()
	selfKind := self.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found || otherKind != selfKind {
		return false
	}

	switch selfKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return self.Int() <= other.Int()

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return self.Uint() <= other.Uint()

	case reflect.Float32, reflect.Float64:
		return self.Float() <= other.Float()

	case reflect.Struct:
		selfType := self.Type()
		if selfType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			otherTime := other.Convert(timeType).Interface().(time.Time)
			selfTime := self.Convert(timeType).Interface().(time.Time)
			return selfTime.Before(otherTime) || selfTime.Equal(otherTime)
		}
		// A struct compared with a different struct type can never match.
		if selfType != other.Type() {
			return false
		}
	}

	// Fallback (strings and anything not handled above): compare lengths.
	return len(self.String()) <= len(other.String())
}
// isLtField is the validation function for validating if the current field's value is less than the field specified by the param's value.
//
// It returns false when the other field cannot be resolved or when its kind
// differs from the current field's kind. Integers, unsigned integers and floats
// are compared numerically; structs that are both convertible to time.Time are
// compared as instants. Structs of differing types yield false. Everything
// else is compared by the length of reflect.Value.String().
func isLtField(fl FieldLevel) bool {
	self := fl.Field()
	selfKind := self.Kind()

	other, otherKind, found := fl.GetStructFieldOK()
	if !found || otherKind != selfKind {
		return false
	}

	switch selfKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return self.Int() < other.Int()

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return self.Uint() < other.Uint()

	case reflect.Float32, reflect.Float64:
		return self.Float() < other.Float()

	case reflect.Struct:
		selfType := self.Type()
		if selfType.ConvertibleTo(timeType) && other.Type().ConvertibleTo(timeType) {
			otherTime := other.Convert(timeType).Interface().(time.Time)
			selfTime := self.Convert(timeType).Interface().(time.Time)
			return selfTime.Before(otherTime)
		}
		// A struct compared with a different struct type can never match.
		if selfType != other.Type() {
			return false
		}
	}

	// Fallback (strings and anything not handled above): compare lengths.
	return len(self.String()) < len(other.String())
}
// isLte is the validation function for validating if the current field's value is less than or equal to the param's value.
//
// Strings are measured in runes and slices, maps and arrays by Len; numeric
// kinds are compared by value. A struct convertible to time.Time is compared
// against time.Now().UTC() and the param is not consulted for it. Any other
// kind panics with "Bad field type".
func isLte(fl FieldLevel) bool {
	value, limit := fl.Field(), fl.Param()

	switch value.Kind() {
	case reflect.String:
		want := asInt(limit)
		return int64(utf8.RuneCountInString(value.String())) <= want

	case reflect.Slice, reflect.Map, reflect.Array:
		want := asInt(limit)
		return int64(value.Len()) <= want

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		want := asIntFromType(value.Type(), limit)
		return value.Int() <= want

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		want := asUint(limit)
		return value.Uint() <= want

	case reflect.Float32:
		want := asFloat32(limit)
		return value.Float() <= want

	case reflect.Float64:
		want := asFloat64(limit)
		return value.Float() <= want

	case reflect.Struct:
		if value.Type().ConvertibleTo(timeType) {
			now := time.Now().UTC()
			when := value.Convert(timeType).Interface().(time.Time)
			return when.Before(now) || when.Equal(now)
		}
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// isLt is the validation function for validating if the current field's value is less than the param's value.
//
// Strings are measured in runes and slices, maps and arrays by Len; numeric
// kinds are compared by value. A struct convertible to time.Time must be before
// time.Now().UTC() and the param is not consulted for it. Any other kind
// panics with "Bad field type".
func isLt(fl FieldLevel) bool {
	value, limit := fl.Field(), fl.Param()

	switch value.Kind() {
	case reflect.String:
		want := asInt(limit)
		return int64(utf8.RuneCountInString(value.String())) < want

	case reflect.Slice, reflect.Map, reflect.Array:
		want := asInt(limit)
		return int64(value.Len()) < want

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		want := asIntFromType(value.Type(), limit)
		return value.Int() < want

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		want := asUint(limit)
		return value.Uint() < want

	case reflect.Float32:
		want := asFloat32(limit)
		return value.Float() < want

	case reflect.Float64:
		want := asFloat64(limit)
		return value.Float() < want

	case reflect.Struct:
		if value.Type().ConvertibleTo(timeType) {
			when := value.Convert(timeType).Interface().(time.Time)
			return when.Before(time.Now().UTC())
		}
	}

	panic(fmt.Sprintf("Bad field type %T", value.Interface()))
}
// hasMaxOf is the validation function for validating if the current field's value is less than or equal to the param's value.
// It delegates directly to isLte, so it shares that function's per-kind behaviour and panics.
func hasMaxOf(fl FieldLevel) bool {
return isLte(fl)
}
// isTCP4AddrResolvable is the validation function for validating if the field's value is a resolvable tcp4 address.
// The value must first pass isIP4Addr; it is then handed to net.ResolveTCPAddr with network "tcp4".
func isTCP4AddrResolvable(fl FieldLevel) bool {
	if isIP4Addr(fl) {
		_, err := net.ResolveTCPAddr("tcp4", fl.Field().String())
		return err == nil
	}
	return false
}
// isTCP6AddrResolvable is the validation function for validating if the field's value is a resolvable tcp6 address.
// The value must first pass isIP6Addr; it is then handed to net.ResolveTCPAddr with network "tcp6".
func isTCP6AddrResolvable(fl FieldLevel) bool {
	if isIP6Addr(fl) {
		_, err := net.ResolveTCPAddr("tcp6", fl.Field().String())
		return err == nil
	}
	return false
}
// isTCPAddrResolvable is the validation function for validating if the field's value is a resolvable tcp address.
// The value must first pass isIP4Addr or isIP6Addr; it is then handed to net.ResolveTCPAddr with network "tcp".
func isTCPAddrResolvable(fl FieldLevel) bool {
	if isIP4Addr(fl) || isIP6Addr(fl) {
		_, err := net.ResolveTCPAddr("tcp", fl.Field().String())
		return err == nil
	}
	return false
}
// isUDP4AddrResolvable is the validation function for validating if the field's value is a resolvable udp4 address.
// The value must first pass isIP4Addr; it is then handed to net.ResolveUDPAddr with network "udp4".
func isUDP4AddrResolvable(fl FieldLevel) bool {
	if isIP4Addr(fl) {
		_, err := net.ResolveUDPAddr("udp4", fl.Field().String())
		return err == nil
	}
	return false
}
// isUDP6AddrResolvable is the validation function for validating if the field's value is a resolvable udp6 address.
// The value must first pass isIP6Addr; it is then handed to net.ResolveUDPAddr with network "udp6".
func isUDP6AddrResolvable(fl FieldLevel) bool {
	if isIP6Addr(fl) {
		_, err := net.ResolveUDPAddr("udp6", fl.Field().String())
		return err == nil
	}
	return false
}
// isUDPAddrResolvable is the validation function for validating if the field's value is a resolvable udp address.
// The value must first pass isIP4Addr or isIP6Addr; it is then handed to net.ResolveUDPAddr with network "udp".
func isUDPAddrResolvable(fl FieldLevel) bool {
	if isIP4Addr(fl) || isIP6Addr(fl) {
		_, err := net.ResolveUDPAddr("udp", fl.Field().String())
		return err == nil
	}
	return false
}
// isIP4AddrResolvable is the validation function for validating if the field's value is a resolvable ip4 address.
// The value must first pass isIPv4; it is then handed to net.ResolveIPAddr with network "ip4".
func isIP4AddrResolvable(fl FieldLevel) bool {
	if isIPv4(fl) {
		_, err := net.ResolveIPAddr("ip4", fl.Field().String())
		return err == nil
	}
	return false
}
// isIP6AddrResolvable is the validation function for validating if the field's value is a resolvable ip6 address.
// The value must first pass isIPv6; it is then handed to net.ResolveIPAddr with network "ip6".
func isIP6AddrResolvable(fl FieldLevel) bool {
	if isIPv6(fl) {
		_, err := net.ResolveIPAddr("ip6", fl.Field().String())
		return err == nil
	}
	return false
}
// isIPAddrResolvable is the validation function for validating if the field's value is a resolvable ip address.
// The value must first pass isIP; it is then handed to net.ResolveIPAddr with network "ip".
func isIPAddrResolvable(fl FieldLevel) bool {
	if isIP(fl) {
		_, err := net.ResolveIPAddr("ip", fl.Field().String())
		return err == nil
	}
	return false
}
// isUnixAddrResolvable is the validation function for validating if the field's value is a resolvable unix address.
// It reports whether net.ResolveUnixAddr accepts the value for network "unix".
func isUnixAddrResolvable(fl FieldLevel) bool {
	if _, err := net.ResolveUnixAddr("unix", fl.Field().String()); err != nil {
		return false
	}
	return true
}
// isIP4Addr reports whether the field's value, with everything from the last
// ':' onwards removed, parses as an IP address that has a 4-byte form.
func isIP4Addr(fl FieldLevel) bool {
	host := fl.Field().String()

	if colon := strings.LastIndex(host, ":"); colon >= 0 {
		host = host[:colon]
	}

	parsed := net.ParseIP(host)
	if parsed == nil {
		return false
	}
	return parsed.To4() != nil
}
// isIP6Addr reports whether the field's value is an IPv6 address, optionally
// written in the bracketed "[addr]:port" form.
//
// When the character before the last ':' is ']', the first character and the
// "]:port" suffix are stripped before parsing. The remaining text must parse
// as an IP address that has no 4-byte form.
func isIP6Addr(fl FieldLevel) bool {
	val := fl.Field().String()

	// idx > 1 (rather than idx != 0) guarantees that val[1:idx-1] is a valid
	// slice expression. With idx == 1 an input such as "]:80" would evaluate
	// val[1:0] and panic with "slice bounds out of range".
	if idx := strings.LastIndex(val, ":"); idx > 1 && val[idx-1] == ']' {
		val = val[1 : idx-1]
	}

	ip := net.ParseIP(val)
	return ip != nil && ip.To4() == nil
}
// isHostnameRFC952 reports whether the field's string value matches hostnameRegexRFC952.
func isHostnameRFC952(fl FieldLevel) bool {
	host := fl.Field().String()
	return hostnameRegexRFC952.MatchString(host)
}
// isHostnameRFC1123 reports whether the field's string value matches hostnameRegexRFC1123.
func isHostnameRFC1123(fl FieldLevel) bool {
	host := fl.Field().String()
	return hostnameRegexRFC1123.MatchString(host)
}
// isFQDN reports whether the field's string value is non-empty and matches fqdnRegexRFC1123.
func isFQDN(fl FieldLevel) bool {
	name := fl.Field().String()
	return name != "" && fqdnRegexRFC1123.MatchString(name)
}
// isDir is the validation function for validating if the current field's value is a valid existing directory.
//
// The field must be a string, otherwise the function panics with
// "Bad field type". Any os.Stat error is reported as false.
func isDir(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	info, err := os.Stat(field.String())
	return err == nil && info.IsDir()
}
// isDirPath is the validation function for validating if the current field's value is a valid directory.
//
// An existing directory (per isDir) is always valid. Otherwise the value must
// be a non-blank string ending in the OS path separator, and os.Stat must not
// fail with EINVAL. A stat error that is not an *fs.PathError causes a panic,
// as does a non-string field.
func isDirPath(fl FieldLevel) bool {
	field := fl.Field()

	// An existing directory is trivially a valid directory path.
	if isDir(fl) {
		return true
	}

	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	path := field.String()

	// A directory cannot have an empty or whitespace-only name.
	if strings.TrimSpace(path) == "" {
		return false
	}

	// Only a trailing separator marks the path as explicitly a directory.
	explicitDir := strings.HasSuffix(path, string(os.PathSeparator))

	_, err := os.Stat(path)
	if err == nil {
		return explicitDir
	}

	// Per https://pkg.go.dev/os#Stat: "If there is an error, it will be of
	// type *PathError." Anything else means something went seriously wrong.
	pathErr, ok := err.(*fs.PathError)
	if !ok {
		panic(err)
	}

	// EINVAL means the path itself is invalid (e.g. an illegal character).
	if pathErr.Err == syscall.EINVAL {
		return false
	}

	// Permission errors, does-not-exist errors, etc. are out of scope for
	// this validation; the path shape alone decides.
	return explicitDir
}
// isJSON is the validation function for validating if the current field's value is a valid json string.
//
// Strings and slices convertible to byteSliceType are checked with
// json.Valid; every other field type panics with "Bad field type".
func isJSON(fl FieldLevel) bool {
	field := fl.Field()
	kind := field.Kind()

	if kind == reflect.String {
		return json.Valid([]byte(field.String()))
	}

	if kind == reflect.Slice && field.Type().ConvertibleTo(byteSliceType) {
		raw := field.Convert(byteSliceType).Interface().([]byte)
		return json.Valid(raw)
	}

	panic(fmt.Sprintf("Bad field type %T", field.Interface()))
}
// isJWT is the validation function for validating if the current field's value is a valid JWT string.
// The check is purely syntactic: the value must match jWTRegex.
func isJWT(fl FieldLevel) bool {
	token := fl.Field().String()
	return jWTRegex.MatchString(token)
}
// isHostnamePort validates a host:port combination for fields typically used for socket address.
//
// The value must split cleanly with net.SplitHostPort, the port must be an
// integer in the range 1..65535, and a non-empty host must match
// hostnameRegexRFC1123. An empty host is accepted.
func isHostnamePort(fl FieldLevel) bool {
	host, port, err := net.SplitHostPort(fl.Field().String())
	if err != nil {
		return false
	}

	// Port must be an int in [1, 65535].
	n, err := strconv.ParseInt(port, 10, 32)
	if err != nil || n > 65535 || n < 1 {
		return false
	}

	// If host is specified, it should match a DNS name.
	return host == "" || hostnameRegexRFC1123.MatchString(host)
}
// isLowercase is the validation function for validating if the current field's value is a lowercase string.
//
// An empty string is rejected. A non-string field panics with "Bad field type".
func isLowercase(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	s := field.String()
	return s != "" && s == strings.ToLower(s)
}
// isUppercase is the validation function for validating if the current field's value is an uppercase string.
//
// An empty string is rejected. A non-string field panics with "Bad field type".
func isUppercase(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	s := field.String()
	return s != "" && s == strings.ToUpper(s)
}
// isDatetime is the validation function for validating if the current field's value is a valid datetime string.
//
// The param is used as the time.Parse layout. A non-string field panics with
// "Bad field type".
func isDatetime(fl FieldLevel) bool {
	field := fl.Field()
	layout := fl.Param()

	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	if _, err := time.Parse(layout, field.String()); err != nil {
		return false
	}
	return true
}
// isTimeZone is the validation function for validating if the current field's value is a valid time zone string.
//
// The name must be accepted by time.LoadLocation. The empty string and any
// casing of "local" are rejected up front because LoadLocation would map them
// to UTC and the system zone respectively, and neither is a real zone name.
// A non-string field panics with "Bad field type".
func isTimeZone(fl FieldLevel) bool {
	field := fl.Field()
	if field.Kind() != reflect.String {
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	name := field.String()
	if name == "" || strings.ToLower(name) == "local" {
		return false
	}

	_, err := time.LoadLocation(name)
	return err == nil
}
// isIso3166Alpha2 is the validation function for validating if the current field's value is a valid iso3166-1 alpha-2 country code.
// The value is looked up verbatim in the iso3166_1_alpha2 table.
func isIso3166Alpha2(fl FieldLevel) bool {
	return iso3166_1_alpha2[fl.Field().String()]
}
// isIso3166Alpha3 is the validation function for validating if the current field's value is a valid iso3166-1 alpha-3 country code.
// The value is looked up verbatim in the iso3166_1_alpha3 table.
func isIso3166Alpha3(fl FieldLevel) bool {
	return iso3166_1_alpha3[fl.Field().String()]
}
// isIso3166AlphaNumeric is the validation function for validating if the current field's value is a valid iso3166-1 alpha-numeric country code.
//
// Strings are parsed with strconv.Atoi (unparsable strings are invalid).
// Signed and unsigned integers are used directly. In every case the value is
// reduced modulo 1000 before being looked up in iso3166_1_alpha_numeric.
// Any other field kind panics with "Bad field type".
func isIso3166AlphaNumeric(fl FieldLevel) bool {
	field := fl.Field()

	switch field.Kind() {
	case reflect.String:
		n, err := strconv.Atoi(field.String())
		if err != nil {
			return false
		}
		return iso3166_1_alpha_numeric[n%1000]

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return iso3166_1_alpha_numeric[int(field.Int()%1000)]

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return iso3166_1_alpha_numeric[int(field.Uint()%1000)]
	}

	panic(fmt.Sprintf("Bad field type %T", field.Interface()))
}
// isIso31662 is the validation function for validating if the current field's value is a valid iso3166-2 code.
// The value is looked up verbatim in the iso3166_2 table.
func isIso31662(fl FieldLevel) bool {
	return iso3166_2[fl.Field().String()]
}
// isIso4217 is the validation function for validating if the current field's value is a valid iso4217 currency code.
// The value is looked up verbatim in the iso4217 table.
func isIso4217(fl FieldLevel) bool {
	return iso4217[fl.Field().String()]
}
// isIso4217Numeric is the validation function for validating if the current field's value is a valid iso4217 numeric currency code.
//
// Only signed and unsigned integer fields are supported; the value is looked
// up in iso4217_numeric. Any other field kind panics with "Bad field type".
func isIso4217Numeric(fl FieldLevel) bool {
	field := fl.Field()

	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return iso4217_numeric[int(field.Int())]

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return iso4217_numeric[int(field.Uint())]
	}

	panic(fmt.Sprintf("Bad field type %T", field.Interface()))
}
// isIsoBicFormat is the validation function for validating if the current field's value is a valid Business Identifier Code (SWIFT code), defined in ISO 9362.
// The check is a match against bicRegex.
func isIsoBicFormat(fl FieldLevel) bool {
	return bicRegex.MatchString(fl.Field().String())
}
// isSemverFormat is the validation function for validating if the current field's value is a valid semver version, defined in Semantic Versioning 2.0.0.
// The check is a match against semverRegex.
func isSemverFormat(fl FieldLevel) bool {
	return semverRegex.MatchString(fl.Field().String())
}
// isCveFormat is the validation function for validating if the current field's value is a valid cve id, defined in CVE mitre org.
// The check is a match against cveRegex.
func isCveFormat(fl FieldLevel) bool {
	return cveRegex.MatchString(fl.Field().String())
}
// isDnsRFC1035LabelFormat is the validation function
// for validating if the current field's value is
// a valid dns RFC 1035 label, defined in RFC 1035.
// The check is a match against dnsRegexRFC1035Label.
func isDnsRFC1035LabelFormat(fl FieldLevel) bool {
	return dnsRegexRFC1035Label.MatchString(fl.Field().String())
}
// digitsHaveLuhnChecksum returns true if and only if the last element of the given digits slice is the Luhn checksum of the previous elements.
//
// Each element is parsed with strconv.Atoi; an element that does not parse
// makes the result false. Elements whose index parity equals the parity of
// the slice length are doubled (which leaves the final element undoubled).
// A doubled value of 10 or more contributes 1 plus its last digit. The sum
// must be a multiple of 10.
func digitsHaveLuhnChecksum(digits []string) bool {
	doubledParity := len(digits) % 2
	total := 0

	for idx, text := range digits {
		n, err := strconv.Atoi(text)
		if err != nil {
			return false
		}

		if idx%2 == doubledParity {
			n *= 2
			if n >= 10 {
				n = 1 + n%10
			}
		}
		total += n
	}

	return total%10 == 0
}
// isMongoDB is the validation function for validating if the current field's value is valid mongoDB objectID.
// The check is a match against mongodbRegex.
func isMongoDB(fl FieldLevel) bool {
	return mongodbRegex.MatchString(fl.Field().String())
}
// isSpiceDB is the validation function for validating if the current field's value is valid for use with Authzed SpiceDB in the indicated way.
//
// The param selects the pattern: "permission", "type", or "id" (also the
// default when the param is empty). Any other param panics with
// "Unrecognized parameter: <param>".
func isSpiceDB(fl FieldLevel) bool {
	val := fl.Field().String()
	param := fl.Param()

	if param == "permission" {
		return spicedbPermissionRegex.MatchString(val)
	}
	if param == "type" {
		return spicedbTypeRegex.MatchString(val)
	}
	if param == "id" || param == "" {
		return spicedbIDRegex.MatchString(val)
	}

	panic("Unrecognized parameter: " + param)
}
// isCreditCard is the validation function for validating if the current field's value is a valid credit card number.
//
// The value is split on single spaces; every group must be at least 3 bytes
// long. The groups are concatenated and split into individual characters, of
// which there must be between 12 and 19, and the result must pass
// digitsHaveLuhnChecksum.
func isCreditCard(fl FieldLevel) bool {
	var number bytes.Buffer

	for _, group := range strings.Split(fl.Field().String(), " ") {
		if len(group) < 3 {
			return false
		}
		number.WriteString(group)
	}

	digits := strings.Split(number.String(), "")
	if count := len(digits); count < 12 || count > 19 {
		return false
	}

	return digitsHaveLuhnChecksum(digits)
}
// hasLuhnChecksum is the validation for validating if the current field's value has a valid Luhn checksum.
//
// Strings are used as-is; signed and unsigned integers are rendered in base
// 10 first. Working on the decimal text is simpler than extracting digits
// arithmetically. Values shorter than two characters are rejected, since at
// least one payload digit plus the check digit is needed. Any other field
// kind panics with "Bad field type".
func hasLuhnChecksum(fl FieldLevel) bool {
	field := fl.Field()

	var text string
	switch field.Kind() {
	case reflect.String:
		text = field.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		text = strconv.FormatInt(field.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		text = strconv.FormatUint(field.Uint(), 10)
	default:
		panic(fmt.Sprintf("Bad field type %T", field.Interface()))
	}

	if len(text) < 2 {
		return false
	}
	return digitsHaveLuhnChecksum(strings.Split(text, ""))
}
// isCron is the validation function for validating if the current field's value is a valid cron expression.
// The check is a match against cronRegex.
func isCron(fl FieldLevel) bool {
	return cronRegex.MatchString(fl.Field().String())
}
================================================
FILE: gossh/gin/validator/cache.go
================================================
package validator
import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
)
// tagType classifies a parsed validation tag (see cTag.typeof). The values are
// assigned by parseFieldTagsRecursive according to the tag text it encounters.
type tagType uint8
const (
typeDefault tagType = iota // initial type given to the first tag in a list
typeOmitEmpty // set for the omitempty tag
typeIsDefault // set for the isdefault tag
typeNoStructLevel // set for noStructLevelTag
typeStructOnly // set for structOnlyTag
typeDive // set for diveTag
typeOr // set when a tag contains several alternatives joined by orSeparator
typeKeys // set for keysTag
typeEndKeys // set for endKeysTag
typeOmitNil // set for the omitnil tag
)
// Panic messages used while parsing validation tags.
const (
// invalidValidation is a format string taking the field name; used when a tag name is empty.
invalidValidation = "Invalid validation tag on field '%s'"
// undefinedValidation is a format string taking the tag name and the field name; used when
// no validation function is registered under the tag name.
undefinedValidation = "Undefined validation function '%s' on field '%s'"
// keysTagNotDefined is used when endKeysTag appears where no keysTag block is being parsed.
keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
)
// structCache maps a struct's reflect.Type to its parsed *cStruct.
//
// The map lives in an atomic.Value and is replaced wholesale on every Set, so
// Get never takes the lock. Set does not take the lock itself; in this file
// extractStructCache holds it around its Get/Set pair.
type structCache struct {
	lock sync.Mutex
	m    atomic.Value // map[reflect.Type]*cStruct
}

// Get looks key up in the currently published map.
// It panics if no map has been stored in m yet.
func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
	snapshot := sc.m.Load().(map[reflect.Type]*cStruct)
	c, found = snapshot[key]
	return c, found
}

// Set publishes a fresh copy of the current map with key bound to value.
// The previously published map is left untouched.
func (sc *structCache) Set(key reflect.Type, value *cStruct) {
	previous := sc.m.Load().(map[reflect.Type]*cStruct)

	updated := make(map[reflect.Type]*cStruct, len(previous)+1)
	for typ, parsed := range previous {
		updated[typ] = parsed
	}
	updated[key] = value

	sc.m.Store(updated)
}
// tagCache maps a raw tag string to its parsed *cTag chain.
//
// The map lives in an atomic.Value and is replaced wholesale on every Set, so
// Get never takes the lock. Set does not take the lock itself; in this file
// fetchCacheTag holds it around its second Get and the Set.
type tagCache struct {
	lock sync.Mutex
	m    atomic.Value // map[string]*cTag
}

// Get looks key up in the currently published map.
// It panics if no map has been stored in m yet.
func (tc *tagCache) Get(key string) (c *cTag, found bool) {
	snapshot := tc.m.Load().(map[string]*cTag)
	c, found = snapshot[key]
	return c, found
}

// Set publishes a fresh copy of the current map with key bound to value.
// The previously published map is left untouched.
func (tc *tagCache) Set(key string, value *cTag) {
	previous := tc.m.Load().(map[string]*cTag)

	updated := make(map[string]*cTag, len(previous)+1)
	for tag, parsed := range previous {
		updated[tag] = parsed
	}
	updated[key] = value

	tc.m.Store(updated)
}
// cStruct is the cached, parsed description of one struct type, built by
// extractStructCache.
type cStruct struct {
name string // struct name passed to extractStructCache
fields []*cField // one entry per field that was not skipped
fn StructLevelFuncCtx // struct-level function registered for the type, if any
}
// cField is the cached description of a single struct field, built by
// extractStructCache.
type cField struct {
idx int // index of the field within the struct type
name string // Go field name
altName string // name returned by the tag-name function when non-empty, else the Go field name
namesEqual bool // true when name == altName
cTags *cTag // parsed tag chain; an empty cTag when the field has no validation tag
}
// cTag is one node in the singly linked list of parsed validation tags
// produced by parseFieldTagsRecursive.
type cTag struct {
tag string // validation name (the part before tagKeySeparator)
aliasTag string // alias the tag was expanded from, or the tag's own name when no alias is in effect
actualAliasTag string // full tag text, recorded only while expanding an alias
param string // parameter text with utf8HexComma / utf8Pipe escapes replaced by "," / "|"
keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
next *cTag // following node in the chain, nil at the end
fn FuncCtx // validation function looked up by tag name
typeof tagType // classification of this node
hasTag bool // set to true on every node the parser creates
hasAlias bool // true when the node came from expanding an alias
hasParam bool // true if parameter used eg. eq= where the equal sign has been set
isBlockEnd bool // indicates the current tag represents the last validation in the block
runValidationWhenNil bool // copied from the registered validation's wrapper
}
// extractStructCache returns the cached *cStruct for current's type, parsing
// and caching it first when it is not cached yet.
//
// The struct-cache lock is held for the whole call so the same type is not
// parsed twice concurrently. Unexported non-embedded fields and fields whose
// tag equals skipValidationTag are left out. A tag registered in v.rules for
// the field name takes precedence over the struct tag. Tag parsing may panic
// on invalid tags.
func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
	v.structCache.lock.Lock()
	defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!

	typ := current.Type()

	// Another goroutine may have parsed this type while we waited for the
	// lock; if so, reuse its result.
	if cached, ok := v.structCache.Get(typ); ok {
		return cached
	}

	cs := &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
	rules := v.rules[typ]

	for i, total := 0, current.NumField(); i < total; i++ {
		fld := typ.Field(i)

		// Skip unexported fields unless they are embedded.
		if !fld.Anonymous && len(fld.PkgPath) > 0 {
			continue
		}

		tag, overridden := rules[fld.Name]
		if !overridden {
			tag = fld.Tag.Get(v.tagName)
		}
		if tag == skipValidationTag {
			continue
		}

		customName := fld.Name
		if v.hasTagNameFunc {
			if name := v.tagNameFunc(fld); len(name) > 0 {
				customName = name
			}
		}

		// NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
		// and so only struct level caching can be used instead of combined with Field tag caching
		var ctag *cTag
		if len(tag) > 0 {
			ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
		} else {
			// Even without validations a cTag is needed so that nested
			// elements of the field can still be traversed.
			ctag = new(cTag)
		}

		cs.fields = append(cs.fields, &cField{
			idx:        i,
			name:       fld.Name,
			altName:    customName,
			cTags:      ctag,
			namesEqual: fld.Name == customName,
		})
	}

	v.structCache.Set(typ, cs)
	return cs
}
// parseFieldTagsRecursive parses a validation tag string into a singly linked
// list of cTag nodes and returns the first and the last node of that list.
//
// tag is split on tagSeparator. fieldName is only used in panic messages.
// alias and hasAlias are non-empty/true when the call is expanding an alias
// registered in v.aliases; in that case every produced node carries the alias
// name. The function panics on malformed tags (empty tag name, unknown
// validation, misplaced keys/endkeys).
func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
var t string
// noAlias is fixed for the whole call: when true, each tag acts as its own alias.
noAlias := len(alias) == 0
tags := strings.Split(tag, tagSeparator)
for i := 0; i < len(tags); i++ {
t = tags[i]
if noAlias {
alias = t
}
// check map for alias and process new tags, otherwise process as usual
if tagsVal, found := v.aliases[t]; found {
if i == 0 {
firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
} else {
// splice the expanded alias chain after the current node and continue from its tail
next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
current.next, current = next, curr
}
continue
}
// prevTag remembers the type of the preceding node; it stays zero for the first tag.
var prevTag tagType
if i == 0 {
current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
firstCtag = current
} else {
prevTag = current.typeof
current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
current = current.next
}
switch t {
case diveTag:
current.typeof = typeDive
continue
case keysTag:
current.typeof = typeKeys
if i == 0 || prevTag != typeDive {
panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
}
current.typeof = typeKeys
// need to pass along only keys tag
// need to increment i to skip over the keys tags
// The tags following keysTag, up to and including endKeysTag, are re-joined with ','
// into b and parsed as their own list.
// NOTE(review): when keysTag is the last tag, b stays empty and b[:len(b)-1] below
// panics with a slice-bounds error rather than a descriptive message.
b := make([]byte, 0, 64)
i++
for ; i < len(tags); i++ {
b = append(b, tags[i]...)
b = append(b, ',')
if tags[i] == endKeysTag {
break
}
}
// b[:len(b)-1] drops the trailing ',' appended by the loop.
current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
continue
case endKeysTag:
current.typeof = typeEndKeys
// if there are more in tags then there was no keysTag defined
// and an error should be thrown
if i != len(tags)-1 {
panic(keysTagNotDefined)
}
return
case omitempty:
current.typeof = typeOmitEmpty
continue
case omitnil:
current.typeof = typeOmitNil
continue
case structOnlyTag:
current.typeof = typeStructOnly
continue
case noStructLevelTag:
current.typeof = typeNoStructLevel
continue
default:
if t == isdefault {
current.typeof = typeIsDefault
}
// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
orVals := strings.Split(t, orSeparator)
for j := 0; j < len(orVals); j++ {
// split each alternative into its name and optional parameter
vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
if noAlias {
alias = vals[0]
current.aliasTag = alias
} else {
current.actualAliasTag = t
}
// every alternative after the first gets its own node
if j > 0 {
current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
current = current.next
}
current.hasParam = len(vals) > 1
current.tag = vals[0]
if len(current.tag) == 0 {
panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
}
if wrapper, ok := v.validations[current.tag]; ok {
current.fn = wrapper.fn
current.runValidationWhenNil = wrapper.runValidatinOnNil
} else {
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
}
if len(orVals) > 1 {
current.typeof = typeOr
}
if len(vals) > 1 {
// un-escape the comma and pipe placeholders inside the parameter
current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
}
}
// the last node produced for this tag closes the block
current.isBlockEnd = true
}
}
return
}
// fetchCacheTag returns the parsed cTag chain for tag, parsing and caching it
// on first use.
//
// The first lookup is done without the lock. On a miss the tag-cache lock is
// taken and the lookup repeated, so that a tag parsed by another goroutine in
// the meantime is not parsed again. Parsing may panic on invalid tags; the
// deferred unlock covers that case.
func (v *Validate) fetchCacheTag(tag string) *cTag {
	if cached, ok := v.tagCache.Get(tag); ok {
		return cached
	}

	v.tagCache.lock.Lock()
	defer v.tagCache.lock.Unlock()

	// Re-check now that the lock is held.
	if cached, ok := v.tagCache.Get(tag); ok {
		return cached
	}

	parsed, _ := v.parseFieldTagsRecursive(tag, "", "", false)
	v.tagCache.Set(tag, parsed)
	return parsed
}
================================================
FILE: gossh/gin/validator/country_codes.go
================================================
package validator
// iso3166_1_alpha2 is a lookup set of two-letter, upper-case country codes
// (see the ISO 3166 reference linked below). Every listed key maps to true,
// so indexing the map with a code that is not listed yields false, the bool
// zero value.
//
// NOTE(review): "XK" (last entry) looks like the user-assigned code commonly
// used for Kosovo rather than an officially assigned ISO entry — confirm
// before relying on strict ISO conformance.
var iso3166_1_alpha2 = map[string]bool{
// see: https://www.iso.org/iso-3166-country-codes.html
"AF": true, "AX": true, "AL": true, "DZ": true, "AS": true,
"AD": true, "AO": true, "AI": true, "AQ": true, "AG": true,
"AR": true, "AM": true, "AW": true, "AU": true, "AT": true,
"AZ": true, "BS": true, "BH": true, "BD": true, "BB": true,
"BY": true, "BE": true, "BZ": true, "BJ": true, "BM": true,
"BT": true, "BO": true, "BQ": true, "BA": true, "BW": true,
"BV": true, "BR": true, "IO": true, "BN": true, "BG": true,
"BF": true, "BI": true, "KH": true, "CM": true, "CA": true,
"CV": true, "KY": true, "CF": true, "TD": true, "CL": true,
"CN": true, "CX": true, "CC": true, "CO": true, "KM": true,
"CG": true, "CD": true, "CK": true, "CR": true, "CI": true,
"HR": true, "CU": true, "CW": true, "CY": true, "CZ": true,
"DK": true, "DJ": true, "DM": true, "DO": true, "EC": true,
"EG": true, "SV": true, "GQ": true, "ER": true, "EE": true,
"ET": true, "FK": true, "FO": true, "FJ": true, "FI": true,
"FR": true, "GF": true, "PF": true, "TF": true, "GA": true,
"GM": true, "GE": true, "DE": true, "GH": true, "GI": true,
"GR": true, "GL": true, "GD": true, "GP": true, "GU": true,
"GT": true, "GG": true, "GN": true, "GW": true, "GY": true,
"HT": true, "HM": true, "VA": true, "HN": true, "HK": true,
"HU": true, "IS": true, "IN": true, "ID": true, "IR": true,
"IQ": true, "IE": true, "IM": true, "IL": true, "IT": true,
"JM": true, "JP": true, "JE": true, "JO": true, "KZ": true,
"KE": true, "KI": true, "KP": true, "KR": true, "KW": true,
"KG": true, "LA": true, "LV": true, "LB": true, "LS": true,
"LR": true, "LY": true, "LI": true, "LT": true, "LU": true,
"MO": true, "MK": true, "MG": true, "MW": true, "MY": true,
"MV": true, "ML": true, "MT": true, "MH": true, "MQ": true,
"MR": true, "MU": true, "YT": true, "MX": true, "FM": true,
"MD": true, "MC": true, "MN": true, "ME": true, "MS": true,
"MA": true, "MZ": true, "MM": true, "NA": true, "NR": true,
"NP": true, "NL": true, "NC": true, "NZ": true, "NI": true,
"NE": true, "NG": true, "NU": true, "NF": true, "MP": true,
"NO": true, "OM": true, "PK": true, "PW": true, "PS": true,
"PA": true, "PG": true, "PY": true, "PE": true, "PH": true,
"PN": true, "PL": true, "PT": true, "PR": true, "QA": true,
"RE": true, "RO": true, "RU": true, "RW": true, "BL": true,
"SH": true, "KN": true, "LC": true, "MF": true, "PM": true,
"VC": true, "WS": true, "SM": true, "ST": true, "SA": true,
"SN": true, "RS": true, "SC": true, "SL": true, "SG": true,
"SX": true, "SK": true, "SI": true, "SB": true, "SO": true,
"ZA": true, "GS": true, "SS": true, "ES": true, "LK": true,
"SD": true, "SR": true, "SJ": true, "SZ": true, "SE": true,
"CH": true, "SY": true, "TW": true, "TJ": true, "TZ": true,
"TH": true, "TL": true, "TG": true, "TK": true, "TO": true,
"TT": true, "TN": true, "TR": true, "TM": true, "TC": true,
"TV": true, "UG": true, "UA": true, "AE": true, "GB": true,
"US": true, "UM": true, "UY": true, "UZ": true, "VU": true,
"VE": true, "VN": true, "VG": true, "VI": true, "WF": true,
"EH": true, "YE": true, "ZM": true, "ZW": true, "XK": true,
}
// iso3166_1_alpha3 is a lookup set of three-letter, upper-case country codes
// (see the ISO 3166 reference linked below). Every listed key maps to true,
// so indexing the map with a code that is not listed yields false, the bool
// zero value.
//
// NOTE(review): "UNK" (last entry) looks like a non-official code, presumably
// the three-letter counterpart of "XK" used for Kosovo — confirm before
// relying on strict ISO conformance.
var iso3166_1_alpha3 = map[string]bool{
// see: https://www.iso.org/iso-3166-country-codes.html
"AFG": true, "ALB": true, "DZA": true, "ASM": true, "AND": true,
"AGO": true, "AIA": true, "ATA": true, "ATG": true, "ARG": true,
"ARM": true, "ABW": true, "AUS": true, "AUT": true, "AZE": true,
"BHS": true, "BHR": true, "BGD": true, "BRB": true, "BLR": true,
"BEL": true, "BLZ": true, "BEN": true, "BMU": true, "BTN": true,
"BOL": true, "BES": true, "BIH": true, "BWA": true, "BVT": true,
"BRA": true, "IOT": true, "BRN": true, "BGR": true, "BFA": true,
"BDI": true, "CPV": true, "KHM": true, "CMR": true, "CAN": true,
"CYM": true, "CAF": true, "TCD": true, "CHL": true, "CHN": true,
"CXR": true, "CCK": true, "COL": true, "COM": true, "COD": true,
"COG": true, "COK": true, "CRI": true, "HRV": true, "CUB": true,
"CUW": true, "CYP": true, "CZE": true, "CIV": true, "DNK": true,
"DJI": true, "DMA": true, "DOM": true, "ECU": true, "EGY": true,
"SLV": true, "GNQ": true, "ERI": true, "EST": true, "SWZ": true,
"ETH": true, "FLK": true, "FRO": true, "FJI": true, "FIN": true,
"FRA": true, "GUF": true, "PYF": true, "ATF": true, "GAB": true,
"GMB": true, "GEO": true, "DEU": true, "GHA": true, "GIB": true,
"GRC": true, "GRL": true, "GRD": true, "GLP": true, "GUM": true,
"GTM": true, "GGY": true, "GIN": true, "GNB": true, "GUY": true,
"HTI": true, "HMD": true, "VAT": true, "HND": true, "HKG": true,
"HUN": true, "ISL": true, "IND": true, "IDN": true, "IRN": true,
"IRQ": true, "IRL": true, "IMN": true, "ISR": true, "ITA": true,
"JAM": true, "JPN": true, "JEY": true, "JOR": true, "KAZ": true,
"KEN": true, "KIR": true, "PRK": true, "KOR": true, "KWT": true,
"KGZ": true, "LAO": true, "LVA": true, "LBN": true, "LSO": true,
"LBR": true, "LBY": true, "LIE": true, "LTU": true, "LUX": true,
"MAC": true, "MDG": true, "MWI": true, "MYS": true, "MDV": true,
"MLI": true, "MLT": true, "MHL": true, "MTQ": true, "MRT": true,
"MUS": true, "MYT": true, "MEX": true, "FSM": true, "MDA": true,
"MCO": true, "MNG": true, "MNE": true, "MSR": true, "MAR": true,
"MOZ": true, "MMR": true, "NAM": true, "NRU": true, "NPL": true,
"NLD": true, "NCL": true, "NZL": true, "NIC": true, "NER": true,
"NGA": true, "NIU": true, "NFK": true, "MKD": true, "MNP": true,
"NOR": true, "OMN": true, "PAK": true, "PLW": true, "PSE": true,
"PAN": true, "PNG": true, "PRY": true, "PER": true, "PHL": true,
"PCN": true, "POL": true, "PRT": true, "PRI": true, "QAT": true,
"ROU": true, "RUS": true, "RWA": true, "REU": true, "BLM": true,
"SHN": true, "KNA": true, "LCA": true, "MAF": true, "SPM": true,
"VCT": true, "WSM": true, "SMR": true, "STP": true, "SAU": true,
"SEN": true, "SRB": true, "SYC": true, "SLE": true, "SGP": true,
"SXM": true, "SVK": true, "SVN": true, "SLB": true, "SOM": true,
"ZAF": true, "SGS": true, "SSD": true, "ESP": true, "LKA": true,
"SDN": true, "SUR": true, "SJM": true, "SWE": true, "CHE": true,
"SYR": true, "TWN": true, "TJK": true, "TZA": true, "THA": true,
"TLS": true, "TGO": true, "TKL": true, "TON": true, "TTO": true,
"TUN": true, "TUR": true, "TKM": true, "TCA": true, "TUV": true,
"UGA": true, "UKR": true, "ARE": true, "GBR": true, "UMI": true,
"USA": true, "URY": true, "UZB": true, "VUT": true, "VEN": true,
"VNM": true, "VGB": true, "VIR": true, "WLF": true, "ESH": true,
"YEM": true, "ZMB": true, "ZWE": true, "ALA": true, "UNK": true,
}
// iso3166_1_alpha_numeric is a lookup set of numeric country codes (see the
// ISO 3166 reference linked below). Keys are plain ints (4, 8, 12, ...), i.e.
// stored without any zero padding. Every listed key maps to true, so indexing
// the map with a number that is not listed yields false, the bool zero value.
//
// NOTE(review): 153 (last entry) does not look like an officially assigned
// numeric code; presumably it accompanies the non-official Kosovo entries in
// the alpha-2/alpha-3 tables — TODO confirm its origin.
var iso3166_1_alpha_numeric = map[int]bool{
// see: https://www.iso.org/iso-3166-country-codes.html
4: true, 8: true, 12: true, 16: true, 20: true,
24: true, 660: true, 10: true, 28: true, 32: true,
51: true, 533: true, 36: true, 40: true, 31: true,
44: true, 48: true, 50: true, 52: true, 112: true,
56: true, 84: true, 204: true, 60: true, 64: true,
68: true, 535: true, 70: true, 72: true, 74: true,
76: true, 86: true, 96: true, 100: true, 854: true,
108: true, 132: true, 116: true, 120: true, 124: true,
136: true, 140: true, 148: true, 152: true, 156: true,
162: true, 166: true, 170: true, 174: true, 180: true,
178: true, 184: true, 188: true, 191: true, 192: true,
531: true, 196: true, 203: true, 384: true, 208: true,
262: true, 212: true, 214: true, 218: true, 818: true,
222: true, 226: true, 232: true, 233: true, 748: true,
231: true, 238: true, 234: true, 242: true, 246: true,
250: true, 254: true, 258: true, 260: true, 266: true,
270: true, 268: true, 276: true, 288: true, 292: true,
300: true, 304: true, 308: true, 312: true, 316: true,
320: true, 831: true, 324: true, 624: true, 328: true,
332: true, 334: true, 336: true, 340: true, 344: true,
348: true, 352: true, 356: true, 360: true, 364: true,
368: true, 372: true, 833: true, 376: true, 380: true,
388: true, 392: true, 832: true, 400: true, 398: true,
404: true, 296: true, 408: true, 410: true, 414: true,
417: true, 418: true, 428: true, 422: true, 426: true,
430: true, 434: true, 438: true, 440: true, 442: true,
446: true, 450: true, 454: true, 458: true, 462: true,
466: true, 470: true, 584: true, 474: true, 478: true,
480: true, 175: true, 484: true, 583: true, 498: true,
492: true, 496: true, 499: true, 500: true, 504: true,
508: true, 104: true, 516: true, 520: true, 524: true,
528: true, 540: true, 554: true, 558: true, 562: true,
566: true, 570: true, 574: true, 807: true, 580: true,
578: true, 512: true, 586: true, 585: true, 275: true,
591: true, 598: true, 600: true, 604: true, 608: true,
612: true, 616: true, 620: true, 630: true, 634: true,
642: true, 643: true, 646: true, 638: true, 652: true,
654: true, 659: true, 662: true, 663: true, 666: true,
670: true, 882: true, 674: true, 678: true, 682: true,
686: true, 688: true, 690: true, 694: true, 702: true,
534: true, 703: true, 705: true, 90: true, 706: true,
710: true, 239: true, 728: true, 724: true, 144: true,
729: true, 740: true, 744: true, 752: true, 756: true,
760: true, 158: true, 762: true, 834: true, 764: true,
626: true, 768: true, 772: true, 776: true, 780: true,
788: true, 792: true, 795: true, 796: true, 798: true,
800: true, 804: true, 784: true, 826: true, 581: true,
840: true, 858: true, 860: true, 548: true, 862: true,
704: true, 92: true, 850: true, 876: true, 732: true,
887: true, 894: true, 716: true, 248: true, 153: true,
}
var iso3166_2 = map[string]bool{
"AD-02": true, "AD-03": true, "AD-04": true, "AD-05": true, "AD-06": true,
"AD-07": true, "AD-08": true, "AE-AJ": true, "AE-AZ": true, "AE-DU": true,
"AE-FU": true, "AE-RK": true, "AE-SH": true, "AE-UQ": true, "AF-BAL": true,
"AF-BAM": true, "AF-BDG": true, "AF-BDS": true, "AF-BGL": true, "AF-DAY": true,
"AF-FRA": true, "AF-FYB": true, "AF-GHA": true, "AF-GHO": true, "AF-HEL": true,
"AF-HER": true, "AF-JOW": true, "AF-KAB": true, "AF-KAN": true, "AF-KAP": true,
"AF-KDZ": true, "AF-KHO": true, "AF-KNR": true, "AF-LAG": true, "AF-LOG": true,
"AF-NAN": true, "AF-NIM": true, "AF-NUR": true, "AF-PAN": true, "AF-PAR": true,
"AF-PIA": true, "AF-PKA": true, "AF-SAM": true, "AF-SAR": true, "AF-TAK": true,
"AF-URU": true, "AF-WAR": true, "AF-ZAB": true, "AG-03": true, "AG-04": true,
"AG-05": true, "AG-06": true, "AG-07": true, "AG-08": true, "AG-10": true,
"AG-11": true, "AL-01": true, "AL-02": true, "AL-03": true, "AL-04": true,
"AL-05": true, "AL-06": true, "AL-07": true, "AL-08": true, "AL-09": true,
"AL-10": true, "AL-11": true, "AL-12": true, "AL-BR": true, "AL-BU": true,
"AL-DI": true, "AL-DL": true, "AL-DR": true, "AL-DV": true, "AL-EL": true,
"AL-ER": true, "AL-FR": true, "AL-GJ": true, "AL-GR": true, "AL-HA": true,
"AL-KA": true, "AL-KB": true, "AL-KC": true, "AL-KO": true, "AL-KR": true,
"AL-KU": true, "AL-LB": true, "AL-LE": true, "AL-LU": true, "AL-MK": true,
"AL-MM": true, "AL-MR": true, "AL-MT": true, "AL-PG": true, "AL-PQ": true,
"AL-PR": true, "AL-PU": true, "AL-SH": true, "AL-SK": true, "AL-SR": true,
"AL-TE": true, "AL-TP": true, "AL-TR": true, "AL-VL": true, "AM-AG": true,
"AM-AR": true, "AM-AV": true, "AM-ER": true, "AM-GR": true, "AM-KT": true,
"AM-LO": true, "AM-SH": true, "AM-SU": true, "AM-TV": true, "AM-VD": true,
"AO-BGO": true, "AO-BGU": true, "AO-BIE": true, "AO-CAB": true, "AO-CCU": true,
"AO-CNN": true, "AO-CNO": true, "AO-CUS": true, "AO-HUA": true, "AO-HUI": true,
"AO-LNO": true, "AO-LSU": true, "AO-LUA": true, "AO-MAL": true, "AO-MOX": true,
"AO-NAM": true, "AO-UIG": true, "AO-ZAI": true, "AR-A": true, "AR-B": true,
"AR-C": true, "AR-D": true, "AR-E": true, "AR-F": true, "AR-G": true, "AR-H": true,
"AR-J": true, "AR-K": true, "AR-L": true, "AR-M": true, "AR-N": true,
"AR-P": true, "AR-Q": true, "AR-R": true, "AR-S": true, "AR-T": true,
"AR-U": true, "AR-V": true, "AR-W": true, "AR-X": true, "AR-Y": true,
"AR-Z": true, "AT-1": true, "AT-2": true, "AT-3": true, "AT-4": true,
"AT-5": true, "AT-6": true, "AT-7": true, "AT-8": true, "AT-9": true,
"AU-ACT": true, "AU-NSW": true, "AU-NT": true, "AU-QLD": true, "AU-SA": true,
"AU-TAS": true, "AU-VIC": true, "AU-WA": true, "AZ-ABS": true, "AZ-AGA": true,
"AZ-AGC": true, "AZ-AGM": true, "AZ-AGS": true, "AZ-AGU": true, "AZ-AST": true,
"AZ-BA": true, "AZ-BAB": true, "AZ-BAL": true, "AZ-BAR": true, "AZ-BEY": true,
"AZ-BIL": true, "AZ-CAB": true, "AZ-CAL": true, "AZ-CUL": true, "AZ-DAS": true,
"AZ-FUZ": true, "AZ-GA": true, "AZ-GAD": true, "AZ-GOR": true, "AZ-GOY": true,
"AZ-GYG": true, "AZ-HAC": true, "AZ-IMI": true, "AZ-ISM": true, "AZ-KAL": true,
"AZ-KAN": true, "AZ-KUR": true, "AZ-LA": true, "AZ-LAC": true, "AZ-LAN": true,
"AZ-LER": true, "AZ-MAS": true, "AZ-MI": true, "AZ-NA": true, "AZ-NEF": true,
"AZ-NV": true, "AZ-NX": true, "AZ-OGU": true, "AZ-ORD": true, "AZ-QAB": true,
"AZ-QAX": true, "AZ-QAZ": true, "AZ-QBA": true, "AZ-QBI": true, "AZ-QOB": true,
"AZ-QUS": true, "AZ-SA": true, "AZ-SAB": true, "AZ-SAD": true, "AZ-SAH": true,
"AZ-SAK": true, "AZ-SAL": true, "AZ-SAR": true, "AZ-SAT": true, "AZ-SBN": true,
"AZ-SIY": true, "AZ-SKR": true, "AZ-SM": true, "AZ-SMI": true, "AZ-SMX": true,
"AZ-SR": true, "AZ-SUS": true, "AZ-TAR": true, "AZ-TOV": true, "AZ-UCA": true,
"AZ-XA": true, "AZ-XAC": true, "AZ-XCI": true, "AZ-XIZ": true, "AZ-XVD": true,
"AZ-YAR": true, "AZ-YE": true, "AZ-YEV": true, "AZ-ZAN": true, "AZ-ZAQ": true,
"AZ-ZAR": true, "BA-01": true, "BA-02": true, "BA-03": true, "BA-04": true,
"BA-05": true, "BA-06": true, "BA-07": true, "BA-08": true, "BA-09": true,
"BA-10": true, "BA-BIH": true, "BA-BRC": true, "BA-SRP": true, "BB-01": true,
"BB-02": true, "BB-03": true, "BB-04": true, "BB-05": true, "BB-06": true,
"BB-07": true, "BB-08": true, "BB-09": true, "BB-10": true, "BB-11": true,
"BD-01": true, "BD-02": true, "BD-03": true, "BD-04": true, "BD-05": true,
"BD-06": true, "BD-07": true, "BD-08": true, "BD-09": true, "BD-10": true,
"BD-11": true, "BD-12": true, "BD-13": true, "BD-14": true, "BD-15": true,
"BD-16": true, "BD-17": true, "BD-18": true, "BD-19": true, "BD-20": true,
"BD-21": true, "BD-22": true, "BD-23": true, "BD-24": true, "BD-25": true,
"BD-26": true, "BD-27": true, "BD-28": true, "BD-29": true, "BD-30": true,
"BD-31": true, "BD-32": true, "BD-33": true, "BD-34": true, "BD-35": true,
"BD-36": true, "BD-37": true, "BD-38": true, "BD-39": true, "BD-40": true,
"BD-41": true, "BD-42": true, "BD-43": true, "BD-44": true, "BD-45": true,
"BD-46": true, "BD-47": true, "BD-48": true, "BD-49": true, "BD-50": true,
"BD-51": true, "BD-52": true, "BD-53": true, "BD-54": true, "BD-55": true,
"BD-56": true, "BD-57": true, "BD-58": true, "BD-59": true, "BD-60": true,
"BD-61": true, "BD-62": true, "BD-63": true, "BD-64": true, "BD-A": true,
"BD-B": true, "BD-C": true, "BD-D": true, "BD-E": true, "BD-F": true,
"BD-G": true, "BE-BRU": true, "BE-VAN": true, "BE-VBR": true, "BE-VLG": true,
"BE-VLI": true, "BE-VOV": true, "BE-VWV": true, "BE-WAL": true, "BE-WBR": true,
"BE-WHT": true, "BE-WLG": true, "BE-WLX": true, "BE-WNA": true, "BF-01": true,
"BF-02": true, "BF-03": true, "BF-04": true, "BF-05": true, "BF-06": true,
"BF-07": true, "BF-08": true, "BF-09": true, "BF-10": true, "BF-11": true,
"BF-12": true, "BF-13": true, "BF-BAL": true, "BF-BAM": true, "BF-BAN": true,
"BF-BAZ": true, "BF-BGR": true, "BF-BLG": true, "BF-BLK": true, "BF-COM": true,
"BF-GAN": true, "BF-GNA": true, "BF-GOU": true, "BF-HOU": true, "BF-IOB": true,
"BF-KAD": true, "BF-KEN": true, "BF-KMD": true, "BF-KMP": true, "BF-KOP": true,
"BF-KOS": true, "BF-KOT": true, "BF-KOW": true, "BF-LER": true, "BF-LOR": true,
"BF-MOU": true, "BF-NAM": true, "BF-NAO": true, "BF-NAY": true, "BF-NOU": true,
"BF-OUB": true, "BF-OUD": true, "BF-PAS": true, "BF-PON": true, "BF-SEN": true,
"BF-SIS": true, "BF-SMT": true, "BF-SNG": true, "BF-SOM": true, "BF-SOR": true,
"BF-TAP": true, "BF-TUI": true, "BF-YAG": true, "BF-YAT": true, "BF-ZIR": true,
"BF-ZON": true, "BF-ZOU": true, "BG-01": true, "BG-02": true, "BG-03": true,
"BG-04": true, "BG-05": true, "BG-06": true, "BG-07": true, "BG-08": true,
"BG-09": true, "BG-10": true, "BG-11": true, "BG-12": true, "BG-13": true,
"BG-14": true, "BG-15": true, "BG-16": true, "BG-17": true, "BG-18": true,
"BG-19": true, "BG-20": true, "BG-21": true, "BG-22": true, "BG-23": true,
"BG-24": true, "BG-25": true, "BG-26": true, "BG-27": true, "BG-28": true,
"BH-13": true, "BH-14": true, "BH-15": true, "BH-16": true, "BH-17": true,
"BI-BB": true, "BI-BL": true, "BI-BM": true, "BI-BR": true, "BI-CA": true,
"BI-CI": true, "BI-GI": true, "BI-KI": true, "BI-KR": true, "BI-KY": true,
"BI-MA": true, "BI-MU": true, "BI-MW": true, "BI-NG": true, "BI-RM": true, "BI-RT": true,
"BI-RY": true, "BJ-AK": true, "BJ-AL": true, "BJ-AQ": true, "BJ-BO": true,
"BJ-CO": true, "BJ-DO": true, "BJ-KO": true, "BJ-LI": true, "BJ-MO": true,
"BJ-OU": true, "BJ-PL": true, "BJ-ZO": true, "BN-BE": true, "BN-BM": true,
"BN-TE": true, "BN-TU": true, "BO-B": true, "BO-C": true, "BO-H": true,
"BO-L": true, "BO-N": true, "BO-O": true, "BO-P": true, "BO-S": true,
"BO-T": true, "BQ-BO": true, "BQ-SA": true, "BQ-SE": true, "BR-AC": true,
"BR-AL": true, "BR-AM": true, "BR-AP": true, "BR-BA": true, "BR-CE": true,
"BR-DF": true, "BR-ES": true, "BR-FN": true, "BR-GO": true, "BR-MA": true,
"BR-MG": true, "BR-MS": true, "BR-MT": true, "BR-PA": true, "BR-PB": true,
"BR-PE": true, "BR-PI": true, "BR-PR": true, "BR-RJ": true, "BR-RN": true,
"BR-RO": true, "BR-RR": true, "BR-RS": true, "BR-SC": true, "BR-SE": true,
"BR-SP": true, "BR-TO": true, "BS-AK": true, "BS-BI": true, "BS-BP": true,
"BS-BY": true, "BS-CE": true, "BS-CI": true, "BS-CK": true, "BS-CO": true,
"BS-CS": true, "BS-EG": true, "BS-EX": true, "BS-FP": true, "BS-GC": true,
"BS-HI": true, "BS-HT": true, "BS-IN": true, "BS-LI": true, "BS-MC": true,
"BS-MG": true, "BS-MI": true, "BS-NE": true, "BS-NO": true, "BS-NP": true, "BS-NS": true,
"BS-RC": true, "BS-RI": true, "BS-SA": true, "BS-SE": true, "BS-SO": true,
"BS-SS": true, "BS-SW": true, "BS-WG": true, "BT-11": true, "BT-12": true,
"BT-13": true, "BT-14": true, "BT-15": true, "BT-21": true, "BT-22": true,
"BT-23": true, "BT-24": true, "BT-31": true, "BT-32": true, "BT-33": true,
"BT-34": true, "BT-41": true, "BT-42": true, "BT-43": true, "BT-44": true,
"BT-45": true, "BT-GA": true, "BT-TY": true, "BW-CE": true, "BW-CH": true, "BW-GH": true,
"BW-KG": true, "BW-KL": true, "BW-KW": true, "BW-NE": true, "BW-NW": true,
"BW-SE": true, "BW-SO": true, "BY-BR": true, "BY-HM": true, "BY-HO": true,
"BY-HR": true, "BY-MA": true, "BY-MI": true, "BY-VI": true, "BZ-BZ": true,
"BZ-CY": true, "BZ-CZL": true, "BZ-OW": true, "BZ-SC": true, "BZ-TOL": true,
"CA-AB": true, "CA-BC": true, "CA-MB": true, "CA-NB": true, "CA-NL": true,
"CA-NS": true, "CA-NT": true, "CA-NU": true, "CA-ON": true, "CA-PE": true,
"CA-QC": true, "CA-SK": true, "CA-YT": true, "CD-BC": true, "CD-BN": true,
"CD-EQ": true, "CD-HK": true, "CD-IT": true, "CD-KA": true, "CD-KC": true, "CD-KE": true, "CD-KG": true, "CD-KN": true,
"CD-KW": true, "CD-KS": true, "CD-LU": true, "CD-MA": true, "CD-NK": true, "CD-OR": true, "CD-SA": true, "CD-SK": true,
"CD-TA": true, "CD-TO": true, "CF-AC": true, "CF-BB": true, "CF-BGF": true, "CF-BK": true, "CF-HK": true, "CF-HM": true,
"CF-HS": true, "CF-KB": true, "CF-KG": true, "CF-LB": true, "CF-MB": true,
"CF-MP": true, "CF-NM": true, "CF-OP": true, "CF-SE": true, "CF-UK": true,
"CF-VK": true, "CG-11": true, "CG-12": true, "CG-13": true, "CG-14": true,
"CG-15": true, "CG-16": true, "CG-2": true, "CG-5": true, "CG-7": true, "CG-8": true,
"CG-9": true, "CG-BZV": true, "CH-AG": true, "CH-AI": true, "CH-AR": true,
"CH-BE": true, "CH-BL": true, "CH-BS": true, "CH-FR": true, "CH-GE": true,
"CH-GL": true, "CH-GR": true, "CH-JU": true, "CH-LU": true, "CH-NE": true,
"CH-NW": true, "CH-OW": true, "CH-SG": true, "CH-SH": true, "CH-SO": true,
"CH-SZ": true, "CH-TG": true, "CH-TI": true, "CH-UR": true, "CH-VD": true,
"CH-VS": true, "CH-ZG": true, "CH-ZH": true, "CI-AB": true, "CI-BS": true,
"CI-CM": true, "CI-DN": true, "CI-GD": true, "CI-LC": true, "CI-LG": true,
"CI-MG": true, "CI-SM": true, "CI-SV": true, "CI-VB": true, "CI-WR": true,
"CI-YM": true, "CI-ZZ": true, "CL-AI": true, "CL-AN": true, "CL-AP": true,
"CL-AR": true, "CL-AT": true, "CL-BI": true, "CL-CO": true, "CL-LI": true,
"CL-LL": true, "CL-LR": true, "CL-MA": true, "CL-ML": true, "CL-NB": true, "CL-RM": true,
"CL-TA": true, "CL-VS": true, "CM-AD": true, "CM-CE": true, "CM-EN": true,
"CM-ES": true, "CM-LT": true, "CM-NO": true, "CM-NW": true, "CM-OU": true,
"CM-SU": true, "CM-SW": true, "CN-AH": true, "CN-BJ": true, "CN-CQ": true,
"CN-FJ": true, "CN-GS": true, "CN-GD": true, "CN-GX": true, "CN-GZ": true,
"CN-HI": true, "CN-HE": true, "CN-HL": true, "CN-HA": true, "CN-HB": true,
"CN-HN": true, "CN-JS": true, "CN-JX": true, "CN-JL": true, "CN-LN": true,
"CN-NM": true, "CN-NX": true, "CN-QH": true, "CN-SN": true, "CN-SD": true, "CN-SH": true,
"CN-SX": true, "CN-SC": true, "CN-TJ": true, "CN-XJ": true, "CN-XZ": true, "CN-YN": true,
"CN-ZJ": true, "CO-AMA": true, "CO-ANT": true, "CO-ARA": true, "CO-ATL": true,
"CO-BOL": true, "CO-BOY": true, "CO-CAL": true, "CO-CAQ": true, "CO-CAS": true,
"CO-CAU": true, "CO-CES": true, "CO-CHO": true, "CO-COR": true, "CO-CUN": true,
"CO-DC": true, "CO-GUA": true, "CO-GUV": true, "CO-HUI": true, "CO-LAG": true,
"CO-MAG": true, "CO-MET": true, "CO-NAR": true, "CO-NSA": true, "CO-PUT": true,
"CO-QUI": true, "CO-RIS": true, "CO-SAN": true, "CO-SAP": true, "CO-SUC": true,
"CO-TOL": true, "CO-VAC": true, "CO-VAU": true, "CO-VID": true, "CR-A": true,
"CR-C": true, "CR-G": true, "CR-H": true, "CR-L": true, "CR-P": true,
"CR-SJ": true, "CU-01": true, "CU-02": true, "CU-03": true, "CU-04": true,
"CU-05": true, "CU-06": true, "CU-07": true, "CU-08": true, "CU-09": true,
"CU-10": true, "CU-11": true, "CU-12": true, "CU-13": true, "CU-14": true, "CU-15": true,
"CU-16": true, "CU-99": true, "CV-B": true, "CV-BR": true, "CV-BV": true, "CV-CA": true,
"CV-CF": true, "CV-CR": true, "CV-MA": true, "CV-MO": true, "CV-PA": true,
"CV-PN": true, "CV-PR": true, "CV-RB": true, "CV-RG": true, "CV-RS": true,
"CV-S": true, "CV-SD": true, "CV-SF": true, "CV-SL": true, "CV-SM": true,
"CV-SO": true, "CV-SS": true, "CV-SV": true, "CV-TA": true, "CV-TS": true,
"CY-01": true, "CY-02": true, "CY-03": true, "CY-04": true, "CY-05": true,
"CY-06": true, "CZ-10": true, "CZ-101": true, "CZ-102": true, "CZ-103": true,
"CZ-104": true, "CZ-105": true, "CZ-106": true, "CZ-107": true, "CZ-108": true,
"CZ-109": true, "CZ-110": true, "CZ-111": true, "CZ-112": true, "CZ-113": true,
"CZ-114": true, "CZ-115": true, "CZ-116": true, "CZ-117": true, "CZ-118": true,
"CZ-119": true, "CZ-120": true, "CZ-121": true, "CZ-122": true, "CZ-20": true,
"CZ-201": true, "CZ-202": true, "CZ-203": true, "CZ-204": true, "CZ-205": true,
"CZ-206": true, "CZ-207": true, "CZ-208": true, "CZ-209": true, "CZ-20A": true,
"CZ-20B": true, "CZ-20C": true, "CZ-31": true, "CZ-311": true, "CZ-312": true,
"CZ-313": true, "CZ-314": true, "CZ-315": true, "CZ-316": true, "CZ-317": true,
"CZ-32": true, "CZ-321": true, "CZ-322": true, "CZ-323": true, "CZ-324": true,
"CZ-325": true, "CZ-326": true, "CZ-327": true, "CZ-41": true, "CZ-411": true,
"CZ-412": true, "CZ-413": true, "CZ-42": true, "CZ-421": true, "CZ-422": true,
"CZ-423": true, "CZ-424": true, "CZ-425": true, "CZ-426": true, "CZ-427": true,
"CZ-51": true, "CZ-511": true, "CZ-512": true, "CZ-513": true, "CZ-514": true,
"CZ-52": true, "CZ-521": true, "CZ-522": true, "CZ-523": true, "CZ-524": true,
"CZ-525": true, "CZ-53": true, "CZ-531": true, "CZ-532": true, "CZ-533": true,
"CZ-534": true, "CZ-63": true, "CZ-631": true, "CZ-632": true, "CZ-633": true,
"CZ-634": true, "CZ-635": true, "CZ-64": true, "CZ-641": true, "CZ-642": true,
"CZ-643": true, "CZ-644": true, "CZ-645": true, "CZ-646": true, "CZ-647": true,
"CZ-71": true, "CZ-711": true, "CZ-712": true, "CZ-713": true, "CZ-714": true,
"CZ-715": true, "CZ-72": true, "CZ-721": true, "CZ-722": true, "CZ-723": true,
"CZ-724": true, "CZ-80": true, "CZ-801": true, "CZ-802": true, "CZ-803": true,
"CZ-804": true, "CZ-805": true, "CZ-806": true, "DE-BB": true, "DE-BE": true,
"DE-BW": true, "DE-BY": true, "DE-HB": true, "DE-HE": true, "DE-HH": true,
"DE-MV": true, "DE-NI": true, "DE-NW": true, "DE-RP": true, "DE-SH": true,
"DE-SL": true, "DE-SN": true, "DE-ST": true, "DE-TH": true, "DJ-AR": true,
"DJ-AS": true, "DJ-DI": true, "DJ-DJ": true, "DJ-OB": true, "DJ-TA": true,
"DK-81": true, "DK-82": true, "DK-83": true, "DK-84": true, "DK-85": true,
"DM-01": true, "DM-02": true, "DM-03": true, "DM-04": true, "DM-05": true,
"DM-06": true, "DM-07": true, "DM-08": true, "DM-09": true, "DM-10": true,
"DO-01": true, "DO-02": true, "DO-03": true, "DO-04": true, "DO-05": true,
"DO-06": true, "DO-07": true, "DO-08": true, "DO-09": true, "DO-10": true,
"DO-11": true, "DO-12": true, "DO-13": true, "DO-14": true, "DO-15": true,
"DO-16": true, "DO-17": true, "DO-18": true, "DO-19": true, "DO-20": true,
"DO-21": true, "DO-22": true, "DO-23": true, "DO-24": true, "DO-25": true,
"DO-26": true, "DO-27": true, "DO-28": true, "DO-29": true, "DO-30": true, "DO-31": true,
"DZ-01": true, "DZ-02": true, "DZ-03": true, "DZ-04": true, "DZ-05": true,
"DZ-06": true, "DZ-07": true, "DZ-08": true, "DZ-09": true, "DZ-10": true,
"DZ-11": true, "DZ-12": true, "DZ-13": true, "DZ-14": true, "DZ-15": true,
"DZ-16": true, "DZ-17": true, "DZ-18": true, "DZ-19": true, "DZ-20": true,
"DZ-21": true, "DZ-22": true, "DZ-23": true, "DZ-24": true, "DZ-25": true,
"DZ-26": true, "DZ-27": true, "DZ-28": true, "DZ-29": true, "DZ-30": true,
"DZ-31": true, "DZ-32": true, "DZ-33": true, "DZ-34": true, "DZ-35": true,
"DZ-36": true, "DZ-37": true, "DZ-38": true, "DZ-39": true, "DZ-40": true,
"DZ-41": true, "DZ-42": true, "DZ-43": true, "DZ-44": true, "DZ-45": true,
"DZ-46": true, "DZ-47": true, "DZ-48": true, "DZ-49": true, "DZ-51": true,
"DZ-53": true, "DZ-55": true, "DZ-56": true, "DZ-57": true, "EC-A": true, "EC-B": true,
"EC-C": true, "EC-D": true, "EC-E": true, "EC-F": true, "EC-G": true,
"EC-H": true, "EC-I": true, "EC-L": true, "EC-M": true, "EC-N": true,
"EC-O": true, "EC-P": true, "EC-R": true, "EC-S": true, "EC-SD": true,
"EC-SE": true, "EC-T": true, "EC-U": true, "EC-W": true, "EC-X": true,
"EC-Y": true, "EC-Z": true, "EE-37": true, "EE-39": true, "EE-44": true, "EE-45": true,
"EE-49": true, "EE-50": true, "EE-51": true, "EE-52": true, "EE-56": true, "EE-57": true,
"EE-59": true, "EE-60": true, "EE-64": true, "EE-65": true, "EE-67": true, "EE-68": true,
"EE-70": true, "EE-71": true, "EE-74": true, "EE-78": true, "EE-79": true, "EE-81": true, "EE-82": true,
"EE-84": true, "EE-86": true, "EE-87": true, "EG-ALX": true, "EG-ASN": true, "EG-AST": true,
"EG-BA": true, "EG-BH": true, "EG-BNS": true, "EG-C": true, "EG-DK": true,
"EG-DT": true, "EG-FYM": true, "EG-GH": true, "EG-GZ": true, "EG-HU": true,
"EG-IS": true, "EG-JS": true, "EG-KB": true, "EG-KFS": true, "EG-KN": true,
"EG-LX": true, "EG-MN": true, "EG-MNF": true, "EG-MT": true, "EG-PTS": true, "EG-SHG": true,
"EG-SHR": true, "EG-SIN": true, "EG-SU": true, "EG-SUZ": true, "EG-WAD": true,
"ER-AN": true, "ER-DK": true, "ER-DU": true, "ER-GB": true, "ER-MA": true,
"ER-SK": true, "ES-A": true, "ES-AB": true, "ES-AL": true, "ES-AN": true,
"ES-AR": true, "ES-AS": true, "ES-AV": true, "ES-B": true, "ES-BA": true,
"ES-BI": true, "ES-BU": true, "ES-C": true, "ES-CA": true, "ES-CB": true,
"ES-CC": true, "ES-CE": true, "ES-CL": true, "ES-CM": true, "ES-CN": true,
"ES-CO": true, "ES-CR": true, "ES-CS": true, "ES-CT": true, "ES-CU": true,
"ES-EX": true, "ES-GA": true, "ES-GC": true, "ES-GI": true, "ES-GR": true,
"ES-GU": true, "ES-H": true, "ES-HU": true, "ES-IB": true, "ES-J": true,
"ES-L": true, "ES-LE": true, "ES-LO": true, "ES-LU": true, "ES-M": true,
"ES-MA": true, "ES-MC": true, "ES-MD": true, "ES-ML": true, "ES-MU": true,
"ES-NA": true, "ES-NC": true, "ES-O": true, "ES-OR": true, "ES-P": true,
"ES-PM": true, "ES-PO": true, "ES-PV": true, "ES-RI": true, "ES-S": true,
"ES-SA": true, "ES-SE": true, "ES-SG": true, "ES-SO": true, "ES-SS": true,
"ES-T": true, "ES-TE": true, "ES-TF": true, "ES-TO": true, "ES-V": true,
"ES-VA": true, "ES-VC": true, "ES-VI": true, "ES-Z": true, "ES-ZA": true,
"ET-AA": true, "ET-AF": true, "ET-AM": true, "ET-BE": true, "ET-DD": true,
"ET-GA": true, "ET-HA": true, "ET-OR": true, "ET-SN": true, "ET-SO": true,
"ET-TI": true, "FI-01": true, "FI-02": true, "FI-03": true, "FI-04": true,
"FI-05": true, "FI-06": true, "FI-07": true, "FI-08": true, "FI-09": true,
"FI-10": true, "FI-11": true, "FI-12": true, "FI-13": true, "FI-14": true,
"FI-15": true, "FI-16": true, "FI-17": true, "FI-18": true, "FI-19": true,
"FJ-C": true, "FJ-E": true, "FJ-N": true, "FJ-R": true, "FJ-W": true,
"FM-KSA": true, "FM-PNI": true, "FM-TRK": true, "FM-YAP": true, "FR-01": true,
"FR-02": true, "FR-03": true, "FR-04": true, "FR-05": true, "FR-06": true,
"FR-07": true, "FR-08": true, "FR-09": true, "FR-10": true, "FR-11": true,
"FR-12": true, "FR-13": true, "FR-14": true, "FR-15": true, "FR-16": true,
"FR-17": true, "FR-18": true, "FR-19": true, "FR-20R": true, "FR-21": true, "FR-22": true,
"FR-23": true, "FR-24": true, "FR-25": true, "FR-26": true, "FR-27": true,
"FR-28": true, "FR-29": true, "FR-2A": true, "FR-2B": true, "FR-30": true,
"FR-31": true, "FR-32": true, "FR-33": true, "FR-34": true, "FR-35": true,
"FR-36": true, "FR-37": true, "FR-38": true, "FR-39": true, "FR-40": true,
"FR-41": true, "FR-42": true, "FR-43": true, "FR-44": true, "FR-45": true,
"FR-46": true, "FR-47": true, "FR-48": true, "FR-49": true, "FR-50": true,
"FR-51": true, "FR-52": true, "FR-53": true, "FR-54": true, "FR-55": true,
"FR-56": true, "FR-57": true, "FR-58": true, "FR-59": true, "FR-60": true,
"FR-61": true, "FR-62": true, "FR-63": true, "FR-64": true, "FR-65": true,
"FR-66": true, "FR-67": true, "FR-68": true, "FR-69": true, "FR-70": true,
"FR-71": true, "FR-72": true, "FR-73": true, "FR-74": true, "FR-75": true,
"FR-76": true, "FR-77": true, "FR-78": true, "FR-79": true, "FR-80": true,
"FR-81": true, "FR-82": true, "FR-83": true, "FR-84": true, "FR-85": true,
"FR-86": true, "FR-87": true, "FR-88": true, "FR-89": true, "FR-90": true,
"FR-91": true, "FR-92": true, "FR-93": true, "FR-94": true, "FR-95": true,
"FR-ARA": true, "FR-BFC": true, "FR-BL": true, "FR-BRE": true, "FR-COR": true,
"FR-CP": true, "FR-CVL": true, "FR-GES": true, "FR-GF": true, "FR-GP": true,
"FR-GUA": true, "FR-HDF": true, "FR-IDF": true, "FR-LRE": true, "FR-MAY": true,
"FR-MF": true, "FR-MQ": true, "FR-NAQ": true, "FR-NC": true, "FR-NOR": true,
"FR-OCC": true, "FR-PAC": true, "FR-PDL": true, "FR-PF": true, "FR-PM": true,
"FR-RE": true, "FR-TF": true, "FR-WF": true, "FR-YT": true, "GA-1": true,
"GA-2": true, "GA-3": true, "GA-4": true, "GA-5": true, "GA-6": true,
"GA-7": true, "GA-8": true, "GA-9": true, "GB-ABC": true, "GB-ABD": true,
"GB-ABE": true, "GB-AGB": true, "GB-AGY": true, "GB-AND": true, "GB-ANN": true,
"GB-ANS": true, "GB-BAS": true, "GB-BBD": true, "GB-BDF": true, "GB-BDG": true,
"GB-BEN": true, "GB-BEX": true, "GB-BFS": true, "GB-BGE": true, "GB-BGW": true,
"GB-BIR": true, "GB-BKM": true, "GB-BMH": true, "GB-BNE": true, "GB-BNH": true,
"GB-BNS": true, "GB-BOL": true, "GB-BPL": true, "GB-BRC": true, "GB-BRD": true,
"GB-BRY": true, "GB-BST": true, "GB-BUR": true, "GB-CAM": true, "GB-CAY": true,
"GB-CBF": true, "GB-CCG": true, "GB-CGN": true, "GB-CHE": true, "GB-CHW": true,
"GB-CLD": true, "GB-CLK": true, "GB-CMA": true, "GB-CMD": true, "GB-CMN": true,
"GB-CON": true, "GB-COV": true, "GB-CRF": true, "GB-CRY": true, "GB-CWY": true,
"GB-DAL": true, "GB-DBY": true, "GB-DEN": true, "GB-DER": true, "GB-DEV": true,
"GB-DGY": true, "GB-DNC": true, "GB-DND": true, "GB-DOR": true, "GB-DRS": true,
"GB-DUD": true, "GB-DUR": true, "GB-EAL": true, "GB-EAW": true, "GB-EAY": true,
"GB-EDH": true, "GB-EDU": true, "GB-ELN": true, "GB-ELS": true, "GB-ENF": true,
"GB-ENG": true, "GB-ERW": true, "GB-ERY": true, "GB-ESS": true, "GB-ESX": true,
"GB-FAL": true, "GB-FIF": true, "GB-FLN": true, "GB-FMO": true, "GB-GAT": true,
"GB-GBN": true, "GB-GLG": true, "GB-GLS": true, "GB-GRE": true, "GB-GWN": true,
"GB-HAL": true, "GB-HAM": true, "GB-HAV": true, "GB-HCK": true, "GB-HEF": true,
"GB-HIL": true, "GB-HLD": true, "GB-HMF": true, "GB-HNS": true, "GB-HPL": true,
"GB-HRT": true, "GB-HRW": true, "GB-HRY": true, "GB-IOS": true, "GB-IOW": true,
"GB-ISL": true, "GB-IVC": true, "GB-KEC": true, "GB-KEN": true, "GB-KHL": true,
"GB-KIR": true, "GB-KTT": true, "GB-KWL": true, "GB-LAN": true, "GB-LBC": true,
"GB-LBH": true, "GB-LCE": true, "GB-LDS": true, "GB-LEC": true, "GB-LEW": true,
"GB-LIN": true, "GB-LIV": true, "GB-LND": true, "GB-LUT": true, "GB-MAN": true,
"GB-MDB": true, "GB-MDW": true, "GB-MEA": true, "GB-MIK": true, "GD-01": true,
"GB-MLN": true, "GB-MON": true, "GB-MRT": true, "GB-MRY": true, "GB-MTY": true,
"GB-MUL": true, "GB-NAY": true, "GB-NBL": true, "GB-NEL": true, "GB-NET": true,
"GB-NFK": true, "GB-NGM": true, "GB-NIR": true, "GB-NLK": true, "GB-NLN": true,
"GB-NMD": true, "GB-NSM": true, "GB-NTH": true, "GB-NTL": true, "GB-NTT": true,
"GB-NTY": true, "GB-NWM": true, "GB-NWP": true, "GB-NYK": true, "GB-OLD": true,
"GB-ORK": true, "GB-OXF": true, "GB-PEM": true, "GB-PKN": true, "GB-PLY": true,
"GB-POL": true, "GB-POR": true, "GB-POW": true, "GB-PTE": true, "GB-RCC": true,
"GB-RCH": true, "GB-RCT": true, "GB-RDB": true, "GB-RDG": true, "GB-RFW": true,
"GB-RIC": true, "GB-ROT": true, "GB-RUT": true, "GB-SAW": true, "GB-SAY": true,
"GB-SCB": true, "GB-SCT": true, "GB-SFK": true, "GB-SFT": true, "GB-SGC": true,
"GB-SHF": true, "GB-SHN": true, "GB-SHR": true, "GB-SKP": true, "GB-SLF": true,
"GB-SLG": true, "GB-SLK": true, "GB-SND": true, "GB-SOL": true, "GB-SOM": true,
"GB-SOS": true, "GB-SRY": true, "GB-STE": true, "GB-STG": true, "GB-STH": true,
"GB-STN": true, "GB-STS": true, "GB-STT": true, "GB-STY": true, "GB-SWA": true,
"GB-SWD": true, "GB-SWK": true, "GB-TAM": true, "GB-TFW": true, "GB-THR": true,
"GB-TOB": true, "GB-TOF": true, "GB-TRF": true, "GB-TWH": true, "GB-UKM": true,
"GB-VGL": true, "GB-WAR": true, "GB-WBK": true, "GB-WDU": true, "GB-WFT": true,
"GB-WGN": true, "GB-WIL": true, "GB-WKF": true, "GB-WLL": true, "GB-WLN": true,
"GB-WLS": true, "GB-WLV": true, "GB-WND": true, "GB-WNM": true, "GB-WOK": true,
"GB-WOR": true, "GB-WRL": true, "GB-WRT": true, "GB-WRX": true, "GB-WSM": true,
"GB-WSX": true, "GB-YOR": true, "GB-ZET": true, "GD-02": true, "GD-03": true,
"GD-04": true, "GD-05": true, "GD-06": true, "GD-10": true, "GE-AB": true,
"GE-AJ": true, "GE-GU": true, "GE-IM": true, "GE-KA": true, "GE-KK": true,
"GE-MM": true, "GE-RL": true, "GE-SJ": true, "GE-SK": true, "GE-SZ": true,
"GE-TB": true, "GH-AA": true, "GH-AH": true, "GH-AF": true, "GH-BA": true, "GH-BO": true, "GH-BE": true, "GH-CP": true,
"GH-EP": true, "GH-NP": true, "GH-TV": true, "GH-UE": true, "GH-UW": true,
"GH-WP": true, "GL-AV": true, "GL-KU": true, "GL-QA": true, "GL-QT": true, "GL-QE": true, "GL-SM": true,
"GM-B": true, "GM-L": true, "GM-M": true, "GM-N": true, "GM-U": true,
"GM-W": true, "GN-B": true, "GN-BE": true, "GN-BF": true, "GN-BK": true,
"GN-C": true, "GN-CO": true, "GN-D": true, "GN-DB": true, "GN-DI": true,
"GN-DL": true, "GN-DU": true, "GN-F": true, "GN-FA": true, "GN-FO": true,
"GN-FR": true, "GN-GA": true, "GN-GU": true, "GN-K": true, "GN-KA": true,
"GN-KB": true, "GN-KD": true, "GN-KE": true, "GN-KN": true, "GN-KO": true,
"GN-KS": true, "GN-L": true, "GN-LA": true, "GN-LE": true, "GN-LO": true,
"GN-M": true, "GN-MC": true, "GN-MD": true, "GN-ML": true, "GN-MM": true,
"GN-N": true, "GN-NZ": true, "GN-PI": true, "GN-SI": true, "GN-TE": true,
"GN-TO": true, "GN-YO": true, "GQ-AN": true, "GQ-BN": true, "GQ-BS": true,
"GQ-C": true, "GQ-CS": true, "GQ-I": true, "GQ-KN": true, "GQ-LI": true,
"GQ-WN": true, "GR-01": true, "GR-03": true, "GR-04": true, "GR-05": true,
"GR-06": true, "GR-07": true, "GR-11": true, "GR-12": true, "GR-13": true,
"GR-14": true, "GR-15": true, "GR-16": true, "GR-17": true, "GR-21": true,
"GR-22": true, "GR-23": true, "GR-24": true, "GR-31": true, "GR-32": true,
"GR-33": true, "GR-34": true, "GR-41": true, "GR-42": true, "GR-43": true,
"GR-44": true, "GR-51": true, "GR-52": true, "GR-53": true, "GR-54": true,
"GR-55": true, "GR-56": true, "GR-57": true, "GR-58": true, "GR-59": true,
"GR-61": true, "GR-62": true, "GR-63": true, "GR-64": true, "GR-69": true,
"GR-71": true, "GR-72": true, "GR-73": true, "GR-81": true, "GR-82": true,
"GR-83": true, "GR-84": true, "GR-85": true, "GR-91": true, "GR-92": true,
"GR-93": true, "GR-94": true, "GR-A": true, "GR-A1": true, "GR-B": true,
"GR-C": true, "GR-D": true, "GR-E": true, "GR-F": true, "GR-G": true,
"GR-H": true, "GR-I": true, "GR-J": true, "GR-K": true, "GR-L": true,
"GR-M": true, "GT-01": true, "GT-02": true, "GT-03": true, "GT-04": true,
"GT-05": true, "GT-06": true, "GT-07": true, "GT-08": true, "GT-09": true,
"GT-10": true, "GT-11": true, "GT-12": true, "GT-13": true, "GT-14": true,
"GT-15": true, "GT-16": true, "GT-17": true, "GT-18": true, "GT-19": true,
"GT-20": true, "GT-21": true, "GT-22": true, "GW-BA": true, "GW-BL": true,
"GW-BM": true, "GW-BS": true, "GW-CA": true, "GW-GA": true, "GW-L": true,
"GW-N": true, "GW-OI": true, "GW-QU": true, "GW-S": true, "GW-TO": true,
"GY-BA": true, "GY-CU": true, "GY-DE": true, "GY-EB": true, "GY-ES": true,
"GY-MA": true, "GY-PM": true, "GY-PT": true, "GY-UD": true, "GY-UT": true,
"HN-AT": true, "HN-CH": true, "HN-CL": true, "HN-CM": true, "HN-CP": true,
"HN-CR": true, "HN-EP": true, "HN-FM": true, "HN-GD": true, "HN-IB": true,
"HN-IN": true, "HN-LE": true, "HN-LP": true, "HN-OC": true, "HN-OL": true,
"HN-SB": true, "HN-VA": true, "HN-YO": true, "HR-01": true, "HR-02": true,
"HR-03": true, "HR-04": true, "HR-05": true, "HR-06": true, "HR-07": true,
"HR-08": true, "HR-09": true, "HR-10": true, "HR-11": true, "HR-12": true,
"HR-13": true, "HR-14": true, "HR-15": true, "HR-16": true, "HR-17": true,
"HR-18": true, "HR-19": true, "HR-20": true, "HR-21": true, "HT-AR": true,
"HT-CE": true, "HT-GA": true, "HT-ND": true, "HT-NE": true, "HT-NO": true, "HT-NI": true,
"HT-OU": true, "HT-SD": true, "HT-SE": true, "HU-BA": true, "HU-BC": true,
"HU-BE": true, "HU-BK": true, "HU-BU": true, "HU-BZ": true, "HU-CS": true,
"HU-DE": true, "HU-DU": true, "HU-EG": true, "HU-ER": true, "HU-FE": true,
"HU-GS": true, "HU-GY": true, "HU-HB": true, "HU-HE": true, "HU-HV": true,
"HU-JN": true, "HU-KE": true, "HU-KM": true, "HU-KV": true, "HU-MI": true,
"HU-NK": true, "HU-NO": true, "HU-NY": true, "HU-PE": true, "HU-PS": true,
"HU-SD": true, "HU-SF": true, "HU-SH": true, "HU-SK": true, "HU-SN": true,
"HU-SO": true, "HU-SS": true, "HU-ST": true, "HU-SZ": true, "HU-TB": true,
"HU-TO": true, "HU-VA": true, "HU-VE": true, "HU-VM": true, "HU-ZA": true,
"HU-ZE": true, "ID-AC": true, "ID-BA": true, "ID-BB": true, "ID-BE": true,
"ID-BT": true, "ID-GO": true, "ID-IJ": true, "ID-JA": true, "ID-JB": true,
"ID-JI": true, "ID-JK": true, "ID-JT": true, "ID-JW": true, "ID-KA": true,
"ID-KB": true, "ID-KI": true, "ID-KU": true, "ID-KR": true, "ID-KS": true,
"ID-KT": true, "ID-LA": true, "ID-MA": true, "ID-ML": true, "ID-MU": true,
"ID-NB": true, "ID-NT": true, "ID-NU": true, "ID-PA": true, "ID-PB": true,
"ID-PE": true, "ID-PP": true, "ID-PS": true, "ID-PT": true, "ID-RI": true,
"ID-SA": true, "ID-SB": true, "ID-SG": true, "ID-SL": true, "ID-SM": true,
"ID-SN": true, "ID-SR": true, "ID-SS": true, "ID-ST": true, "ID-SU": true,
"ID-YO": true, "IE-C": true, "IE-CE": true, "IE-CN": true, "IE-CO": true,
"IE-CW": true, "IE-D": true, "IE-DL": true, "IE-G": true, "IE-KE": true,
"IE-KK": true, "IE-KY": true, "IE-L": true, "IE-LD": true, "IE-LH": true,
"IE-LK": true, "IE-LM": true, "IE-LS": true, "IE-M": true, "IE-MH": true,
"IE-MN": true, "IE-MO": true, "IE-OY": true, "IE-RN": true, "IE-SO": true,
"IE-TA": true, "IE-U": true, "IE-WD": true, "IE-WH": true, "IE-WW": true,
"IE-WX": true, "IL-D": true, "IL-HA": true, "IL-JM": true, "IL-M": true,
"IL-TA": true, "IL-Z": true, "IN-AN": true, "IN-AP": true, "IN-AR": true,
"IN-AS": true, "IN-BR": true, "IN-CH": true, "IN-CT": true, "IN-DH": true,
"IN-DL": true, "IN-DN": true, "IN-GA": true, "IN-GJ": true, "IN-HP": true,
"IN-HR": true, "IN-JH": true, "IN-JK": true, "IN-KA": true, "IN-KL": true,
"IN-LD": true, "IN-MH": true, "IN-ML": true, "IN-MN": true, "IN-MP": true,
"IN-MZ": true, "IN-NL": true, "IN-TG": true, "IN-OR": true, "IN-PB": true, "IN-PY": true,
"IN-RJ": true, "IN-SK": true, "IN-TN": true, "IN-TR": true, "IN-UP": true,
"IN-UT": true, "IN-WB": true, "IQ-AN": true, "IQ-AR": true, "IQ-BA": true,
"IQ-BB": true, "IQ-BG": true, "IQ-DA": true, "IQ-DI": true, "IQ-DQ": true,
"IQ-KA": true, "IQ-KI": true, "IQ-MA": true, "IQ-MU": true, "IQ-NA": true, "IQ-NI": true,
"IQ-QA": true, "IQ-SD": true, "IQ-SW": true, "IQ-SU": true, "IQ-TS": true, "IQ-WA": true,
"IR-00": true, "IR-01": true, "IR-02": true, "IR-03": true, "IR-04": true, "IR-05": true,
"IR-06": true, "IR-07": true, "IR-08": true, "IR-09": true, "IR-10": true, "IR-11": true,
"IR-12": true, "IR-13": true, "IR-14": true, "IR-15": true, "IR-16": true,
"IR-17": true, "IR-18": true, "IR-19": true, "IR-20": true, "IR-21": true,
"IR-22": true, "IR-23": true, "IR-24": true, "IR-25": true, "IR-26": true,
"IR-27": true, "IR-28": true, "IR-29": true, "IR-30": true, "IR-31": true,
"IS-0": true, "IS-1": true, "IS-2": true, "IS-3": true, "IS-4": true,
"IS-5": true, "IS-6": true, "IS-7": true, "IS-8": true, "IT-21": true,
"IT-23": true, "IT-25": true, "IT-32": true, "IT-34": true, "IT-36": true,
"IT-42": true, "IT-45": true, "IT-52": true, "IT-55": true, "IT-57": true,
"IT-62": true, "IT-65": true, "IT-67": true, "IT-72": true, "IT-75": true,
"IT-77": true, "IT-78": true, "IT-82": true, "IT-88": true, "IT-AG": true,
"IT-AL": true, "IT-AN": true, "IT-AO": true, "IT-AP": true, "IT-AQ": true,
"IT-AR": true, "IT-AT": true, "IT-AV": true, "IT-BA": true, "IT-BG": true,
"IT-BI": true, "IT-BL": true, "IT-BN": true, "IT-BO": true, "IT-BR": true,
"IT-BS": true, "IT-BT": true, "IT-BZ": true, "IT-CA": true, "IT-CB": true,
"IT-CE": true, "IT-CH": true, "IT-CI": true, "IT-CL": true, "IT-CN": true,
"IT-CO": true, "IT-CR": true, "IT-CS": true, "IT-CT": true, "IT-CZ": true,
"IT-EN": true, "IT-FC": true, "IT-FE": true, "IT-FG": true, "IT-FI": true,
"IT-FM": true, "IT-FR": true, "IT-GE": true, "IT-GO": true, "IT-GR": true,
"IT-IM": true, "IT-IS": true, "IT-KR": true, "IT-LC": true, "IT-LE": true,
"IT-LI": true, "IT-LO": true, "IT-LT": true, "IT-LU": true, "IT-MB": true,
"IT-MC": true, "IT-ME": true, "IT-MI": true, "IT-MN": true, "IT-MO": true,
"IT-MS": true, "IT-MT": true, "IT-NA": true, "IT-NO": true, "IT-NU": true,
"IT-OG": true, "IT-OR": true, "IT-OT": true, "IT-PA": true, "IT-PC": true,
"IT-PD": true, "IT-PE": true, "IT-PG": true, "IT-PI": true, "IT-PN": true,
"IT-PO": true, "IT-PR": true, "IT-PT": true, "IT-PU": true, "IT-PV": true,
"IT-PZ": true, "IT-RA": true, "IT-RC": true, "IT-RE": true, "IT-RG": true,
"IT-RI": true, "IT-RM": true, "IT-RN": true, "IT-RO": true, "IT-SA": true,
"IT-SI": true, "IT-SO": true, "IT-SP": true, "IT-SR": true, "IT-SS": true,
"IT-SV": true, "IT-TA": true, "IT-TE": true, "IT-TN": true, "IT-TO": true,
"IT-TP": true, "IT-TR": true, "IT-TS": true, "IT-TV": true, "IT-UD": true,
"IT-VA": true, "IT-VB": true, "IT-VC": true, "IT-VE": true, "IT-VI": true,
"IT-VR": true, "IT-VS": true, "IT-VT": true, "IT-VV": true, "JM-01": true,
"JM-02": true, "JM-03": true, "JM-04": true, "JM-05": true, "JM-06": true,
"JM-07": true, "JM-08": true, "JM-09": true, "JM-10": true, "JM-11": true,
"JM-12": true, "JM-13": true, "JM-14": true, "JO-AJ": true, "JO-AM": true,
"JO-AQ": true, "JO-AT": true, "JO-AZ": true, "JO-BA": true, "JO-IR": true,
"JO-JA": true, "JO-KA": true, "JO-MA": true, "JO-MD": true, "JO-MN": true,
"JP-01": true, "JP-02": true, "JP-03": true, "JP-04": true, "JP-05": true,
"JP-06": true, "JP-07": true, "JP-08": true, "JP-09": true, "JP-10": true,
"JP-11": true, "JP-12": true, "JP-13": true, "JP-14": true, "JP-15": true,
"JP-16": true, "JP-17": true, "JP-18": true, "JP-19": true, "JP-20": true,
"JP-21": true, "JP-22": true, "JP-23": true, "JP-24": true, "JP-25": true,
"JP-26": true, "JP-27": true, "JP-28": true, "JP-29": true, "JP-30": true,
"JP-31": true, "JP-32": true, "JP-33": true, "JP-34": true, "JP-35": true,
"JP-36": true, "JP-37": true, "JP-38": true, "JP-39": true, "JP-40": true,
"JP-41": true, "JP-42": true, "JP-43": true, "JP-44": true, "JP-45": true,
"JP-46": true, "JP-47": true, "KE-01": true, "KE-02": true, "KE-03": true,
"KE-04": true, "KE-05": true, "KE-06": true, "KE-07": true, "KE-08": true,
"KE-09": true, "KE-10": true, "KE-11": true, "KE-12": true, "KE-13": true,
"KE-14": true, "KE-15": true, "KE-16": true, "KE-17": true, "KE-18": true,
"KE-19": true, "KE-20": true, "KE-21": true, "KE-22": true, "KE-23": true,
"KE-24": true, "KE-25": true, "KE-26": true, "KE-27": true, "KE-28": true,
"KE-29": true, "KE-30": true, "KE-31": true, "KE-32": true, "KE-33": true,
"KE-34": true, "KE-35": true, "KE-36": true, "KE-37": true, "KE-38": true,
"KE-39": true, "KE-40": true, "KE-41": true, "KE-42": true, "KE-43": true,
"KE-44": true, "KE-45": true, "KE-46": true, "KE-47": true, "KG-B": true,
"KG-C": true, "KG-GB": true, "KG-GO": true, "KG-J": true, "KG-N": true, "KG-O": true,
"KG-T": true, "KG-Y": true, "KH-1": true, "KH-10": true, "KH-11": true,
"KH-12": true, "KH-13": true, "KH-14": true, "KH-15": true, "KH-16": true,
"KH-17": true, "KH-18": true, "KH-19": true, "KH-2": true, "KH-20": true,
"KH-21": true, "KH-22": true, "KH-23": true, "KH-24": true, "KH-3": true,
"KH-4": true, "KH-5": true, "KH-6": true, "KH-7": true, "KH-8": true,
"KH-9": true, "KI-G": true, "KI-L": true, "KI-P": true, "KM-A": true,
"KM-G": true, "KM-M": true, "KN-01": true, "KN-02": true, "KN-03": true,
"KN-04": true, "KN-05": true, "KN-06": true, "KN-07": true, "KN-08": true,
"KN-09": true, "KN-10": true, "KN-11": true, "KN-12": true, "KN-13": true,
"KN-15": true, "KN-K": true, "KN-N": true, "KP-01": true, "KP-02": true,
"KP-03": true, "KP-04": true, "KP-05": true, "KP-06": true, "KP-07": true,
"KP-08": true, "KP-09": true, "KP-10": true, "KP-13": true, "KR-11": true,
"KR-26": true, "KR-27": true, "KR-28": true, "KR-29": true, "KR-30": true,
"KR-31": true, "KR-41": true, "KR-42": true, "KR-43": true, "KR-44": true,
"KR-45": true, "KR-46": true, "KR-47": true, "KR-48": true, "KR-49": true,
"KW-AH": true, "KW-FA": true, "KW-HA": true, "KW-JA": true, "KW-KU": true,
"KW-MU": true, "KZ-10": true, "KZ-75": true, "KZ-19": true, "KZ-11": true,
"KZ-15": true, "KZ-71": true, "KZ-23": true, "KZ-27": true, "KZ-47": true,
"KZ-55": true, "KZ-35": true, "KZ-39": true, "KZ-43": true, "KZ-63": true,
"KZ-79": true, "KZ-59": true, "KZ-61": true, "KZ-62": true, "KZ-31": true,
"KZ-33": true, "LA-AT": true, "LA-BK": true, "LA-BL": true,
"LA-CH": true, "LA-HO": true, "LA-KH": true, "LA-LM": true, "LA-LP": true,
"LA-OU": true, "LA-PH": true, "LA-SL": true, "LA-SV": true, "LA-VI": true,
"LA-VT": true, "LA-XA": true, "LA-XE": true, "LA-XI": true, "LA-XS": true,
"LB-AK": true, "LB-AS": true, "LB-BA": true, "LB-BH": true, "LB-BI": true,
"LB-JA": true, "LB-JL": true, "LB-NA": true, "LC-01": true, "LC-02": true,
"LC-03": true, "LC-05": true, "LC-06": true, "LC-07": true, "LC-08": true,
"LC-10": true, "LC-11": true, "LI-01": true, "LI-02": true,
"LI-03": true, "LI-04": true, "LI-05": true, "LI-06": true, "LI-07": true,
"LI-08": true, "LI-09": true, "LI-10": true, "LI-11": true, "LK-1": true,
"LK-11": true, "LK-12": true, "LK-13": true, "LK-2": true, "LK-21": true,
"LK-22": true, "LK-23": true, "LK-3": true, "LK-31": true, "LK-32": true,
"LK-33": true, "LK-4": true, "LK-41": true, "LK-42": true, "LK-43": true,
"LK-44": true, "LK-45": true, "LK-5": true, "LK-51": true, "LK-52": true,
"LK-53": true, "LK-6": true, "LK-61": true, "LK-62": true, "LK-7": true,
"LK-71": true, "LK-72": true, "LK-8": true, "LK-81": true, "LK-82": true,
"LK-9": true, "LK-91": true, "LK-92": true, "LR-BG": true, "LR-BM": true,
"LR-CM": true, "LR-GB": true, "LR-GG": true, "LR-GK": true, "LR-LO": true,
"LR-MG": true, "LR-MO": true, "LR-MY": true, "LR-NI": true, "LR-RI": true,
"LR-SI": true, "LS-A": true, "LS-B": true, "LS-C": true, "LS-D": true,
"LS-E": true, "LS-F": true, "LS-G": true, "LS-H": true, "LS-J": true,
"LS-K": true, "LT-AL": true, "LT-KL": true, "LT-KU": true, "LT-MR": true,
"LT-PN": true, "LT-SA": true, "LT-TA": true, "LT-TE": true, "LT-UT": true,
"LT-VL": true, "LU-CA": true, "LU-CL": true, "LU-DI": true, "LU-EC": true,
"LU-ES": true, "LU-GR": true, "LU-LU": true, "LU-ME": true, "LU-RD": true,
"LU-RM": true, "LU-VD": true, "LU-WI": true, "LU-D": true, "LU-G": true, "LU-L": true,
"LV-001": true, "LV-111": true, "LV-112": true, "LV-113": true,
"LV-002": true, "LV-003": true, "LV-004": true, "LV-005": true, "LV-006": true,
"LV-007": true, "LV-008": true, "LV-009": true, "LV-010": true, "LV-011": true,
"LV-012": true, "LV-013": true, "LV-014": true, "LV-015": true, "LV-016": true,
"LV-017": true, "LV-018": true, "LV-019": true, "LV-020": true, "LV-021": true,
"LV-022": true, "LV-023": true, "LV-024": true, "LV-025": true, "LV-026": true,
"LV-027": true, "LV-028": true, "LV-029": true, "LV-030": true, "LV-031": true,
"LV-032": true, "LV-033": true, "LV-034": true, "LV-035": true, "LV-036": true,
"LV-037": true, "LV-038": true, "LV-039": true, "LV-040": true, "LV-041": true,
"LV-042": true, "LV-043": true, "LV-044": true, "LV-045": true, "LV-046": true,
"LV-047": true, "LV-048": true, "LV-049": true, "LV-050": true, "LV-051": true,
"LV-052": true, "LV-053": true, "LV-054": true, "LV-055": true, "LV-056": true,
"LV-057": true, "LV-058": true, "LV-059": true, "LV-060": true, "LV-061": true,
"LV-062": true, "LV-063": true, "LV-064": true, "LV-065": true, "LV-066": true,
"LV-067": true, "LV-068": true, "LV-069": true, "LV-070": true, "LV-071": true,
"LV-072": true, "LV-073": true, "LV-074": true, "LV-075": true, "LV-076": true,
"LV-077": true, "LV-078": true, "LV-079": true, "LV-080": true, "LV-081": true,
"LV-082": true, "LV-083": true, "LV-084": true, "LV-085": true, "LV-086": true,
"LV-087": true, "LV-088": true, "LV-089": true, "LV-090": true, "LV-091": true,
"LV-092": true, "LV-093": true, "LV-094": true, "LV-095": true, "LV-096": true,
"LV-097": true, "LV-098": true, "LV-099": true, "LV-100": true, "LV-101": true,
"LV-102": true, "LV-103": true, "LV-104": true, "LV-105": true, "LV-106": true,
"LV-107": true, "LV-108": true, "LV-109": true, "LV-110": true, "LV-DGV": true,
"LV-JEL": true, "LV-JKB": true, "LV-JUR": true, "LV-LPX": true, "LV-REZ": true,
"LV-RIX": true, "LV-VEN": true, "LV-VMR": true, "LY-BA": true, "LY-BU": true,
"LY-DR": true, "LY-GT": true, "LY-JA": true, "LY-JB": true, "LY-JG": true,
"LY-JI": true, "LY-JU": true, "LY-KF": true, "LY-MB": true, "LY-MI": true,
"LY-MJ": true, "LY-MQ": true, "LY-NL": true, "LY-NQ": true, "LY-SB": true,
"LY-SR": true, "LY-TB": true, "LY-WA": true, "LY-WD": true, "LY-WS": true,
"LY-ZA": true, "MA-01": true, "MA-02": true, "MA-03": true, "MA-04": true,
"MA-05": true, "MA-06": true, "MA-07": true, "MA-08": true, "MA-09": true,
"MA-10": true, "MA-11": true, "MA-12": true, "MA-13": true, "MA-14": true,
"MA-15": true, "MA-16": true, "MA-AGD": true, "MA-AOU": true, "MA-ASZ": true,
"MA-AZI": true, "MA-BEM": true, "MA-BER": true, "MA-BES": true, "MA-BOD": true,
"MA-BOM": true, "MA-CAS": true, "MA-CHE": true, "MA-CHI": true, "MA-CHT": true,
"MA-ERR": true, "MA-ESI": true, "MA-ESM": true, "MA-FAH": true, "MA-FES": true,
"MA-FIG": true, "MA-GUE": true, "MA-HAJ": true, "MA-HAO": true, "MA-HOC": true,
"MA-IFR": true, "MA-INE": true, "MA-JDI": true, "MA-JRA": true, "MA-KEN": true,
"MA-KES": true, "MA-KHE": true, "MA-KHN": true, "MA-KHO": true, "MA-LAA": true,
"MA-LAR": true, "MA-MED": true, "MA-MEK": true, "MA-MMD": true, "MA-MMN": true,
"MA-MOH": true, "MA-MOU": true, "MA-NAD": true, "MA-NOU": true, "MA-OUA": true,
"MA-OUD": true, "MA-OUJ": true, "MA-RAB": true, "MA-SAF": true, "MA-SAL": true,
"MA-SEF": true, "MA-SET": true, "MA-SIK": true, "MA-SKH": true, "MA-SYB": true,
"MA-TAI": true, "MA-TAO": true, "MA-TAR": true, "MA-TAT": true, "MA-TAZ": true,
"MA-TET": true, "MA-TIZ": true, "MA-TNG": true, "MA-TNT": true, "MA-ZAG": true,
"MC-CL": true, "MC-CO": true, "MC-FO": true, "MC-GA": true, "MC-JE": true,
"MC-LA": true, "MC-MA": true, "MC-MC": true, "MC-MG": true, "MC-MO": true,
"MC-MU": true, "MC-PH": true, "MC-SD": true, "MC-SO": true, "MC-SP": true,
"MC-SR": true, "MC-VR": true, "MD-AN": true, "MD-BA": true, "MD-BD": true,
"MD-BR": true, "MD-BS": true, "MD-CA": true, "MD-CL": true, "MD-CM": true,
"MD-CR": true, "MD-CS": true, "MD-CT": true, "MD-CU": true, "MD-DO": true,
"MD-DR": true, "MD-DU": true, "MD-ED": true, "MD-FA": true, "MD-FL": true,
"MD-GA": true, "MD-GL": true, "MD-HI": true, "MD-IA": true, "MD-LE": true,
"MD-NI": true, "MD-OC": true, "MD-OR": true, "MD-RE": true, "MD-RI": true,
"MD-SD": true, "MD-SI": true, "MD-SN": true, "MD-SO": true, "MD-ST": true,
"MD-SV": true, "MD-TA": true, "MD-TE": true, "MD-UN": true, "ME-01": true,
"ME-02": true, "ME-03": true, "ME-04": true, "ME-05": true, "ME-06": true,
"ME-07": true, "ME-08": true, "ME-09": true, "ME-10": true, "ME-11": true,
"ME-12": true, "ME-13": true, "ME-14": true, "ME-15": true, "ME-16": true,
"ME-17": true, "ME-18": true, "ME-19": true, "ME-20": true, "ME-21": true, "ME-24": true,
"MG-A": true, "MG-D": true, "MG-F": true, "MG-M": true, "MG-T": true,
"MG-U": true, "MH-ALK": true, "MH-ALL": true, "MH-ARN": true, "MH-AUR": true,
"MH-EBO": true, "MH-ENI": true, "MH-JAB": true, "MH-JAL": true, "MH-KIL": true,
"MH-KWA": true, "MH-L": true, "MH-LAE": true, "MH-LIB": true, "MH-LIK": true,
"MH-MAJ": true, "MH-MAL": true, "MH-MEJ": true, "MH-MIL": true, "MH-NMK": true,
"MH-NMU": true, "MH-RON": true, "MH-T": true, "MH-UJA": true, "MH-UTI": true,
"MH-WTJ": true, "MH-WTN": true, "MK-101": true, "MK-102": true, "MK-103": true,
"MK-104": true, "MK-105": true,
"MK-106": true, "MK-107": true, "MK-108": true, "MK-109": true, "MK-201": true,
"MK-202": true, "MK-205": true, "MK-206": true, "MK-207": true, "MK-208": true,
"MK-209": true, "MK-210": true, "MK-211": true, "MK-301": true, "MK-303": true,
"MK-307": true, "MK-308": true, "MK-310": true, "MK-311": true, "MK-312": true,
"MK-401": true, "MK-402": true, "MK-403": true, "MK-404": true, "MK-405": true,
"MK-406": true, "MK-408": true, "MK-409": true, "MK-410": true, "MK-501": true,
"MK-502": true, "MK-503": true, "MK-505": true, "MK-506": true, "MK-507": true,
"MK-508": true, "MK-509": true, "MK-601": true, "MK-602": true, "MK-604": true,
"MK-605": true, "MK-606": true, "MK-607": true, "MK-608": true, "MK-609": true,
"MK-701": true, "MK-702": true, "MK-703": true, "MK-704": true, "MK-705": true,
"MK-803": true, "MK-804": true, "MK-806": true, "MK-807": true, "MK-809": true,
"MK-810": true, "MK-811": true, "MK-812": true, "MK-813": true, "MK-814": true,
"MK-816": true, "ML-1": true, "ML-2": true, "ML-3": true, "ML-4": true,
"ML-5": true, "ML-6": true, "ML-7": true, "ML-8": true, "ML-BKO": true,
"MM-01": true, "MM-02": true, "MM-03": true, "MM-04": true, "MM-05": true,
"MM-06": true, "MM-07": true, "MM-11": true, "MM-12": true, "MM-13": true,
"MM-14": true, "MM-15": true, "MM-16": true, "MM-17": true, "MM-18": true, "MN-035": true,
"MN-037": true, "MN-039": true, "MN-041": true, "MN-043": true, "MN-046": true,
"MN-047": true, "MN-049": true, "MN-051": true, "MN-053": true, "MN-055": true,
"MN-057": true, "MN-059": true, "MN-061": true, "MN-063": true, "MN-064": true,
"MN-065": true, "MN-067": true, "MN-069": true, "MN-071": true, "MN-073": true,
"MN-1": true, "MR-01": true, "MR-02": true, "MR-03": true, "MR-04": true,
"MR-05": true, "MR-06": true, "MR-07": true, "MR-08": true, "MR-09": true,
"MR-10": true, "MR-11": true, "MR-12": true, "MR-13": true, "MR-NKC": true, "MT-01": true,
"MT-02": true, "MT-03": true, "MT-04": true, "MT-05": true, "MT-06": true,
"MT-07": true, "MT-08": true, "MT-09": true, "MT-10": true, "MT-11": true,
"MT-12": true, "MT-13": true, "MT-14": true, "MT-15": true, "MT-16": true,
"MT-17": true, "MT-18": true, "MT-19": true, "MT-20": true, "MT-21": true,
"MT-22": true, "MT-23": true, "MT-24": true, "MT-25": true, "MT-26": true,
"MT-27": true, "MT-28": true, "MT-29": true, "MT-30": true, "MT-31": true,
"MT-32": true, "MT-33": true, "MT-34": true, "MT-35": true, "MT-36": true,
"MT-37": true, "MT-38": true, "MT-39": true, "MT-40": true, "MT-41": true,
"MT-42": true, "MT-43": true, "MT-44": true, "MT-45": true, "MT-46": true,
"MT-47": true, "MT-48": true, "MT-49": true, "MT-50": true, "MT-51": true,
"MT-52": true, "MT-53": true, "MT-54": true, "MT-55": true, "MT-56": true,
"MT-57": true, "MT-58": true, "MT-59": true, "MT-60": true, "MT-61": true,
"MT-62": true, "MT-63": true, "MT-64": true, "MT-65": true, "MT-66": true,
"MT-67": true, "MT-68": true, "MU-AG": true, "MU-BL": true, "MU-BR": true,
"MU-CC": true, "MU-CU": true, "MU-FL": true, "MU-GP": true, "MU-MO": true,
"MU-PA": true, "MU-PL": true, "MU-PU": true, "MU-PW": true, "MU-QB": true,
"MU-RO": true, "MU-RP": true, "MU-RR": true, "MU-SA": true, "MU-VP": true, "MV-00": true,
"MV-01": true, "MV-02": true, "MV-03": true, "MV-04": true, "MV-05": true,
"MV-07": true, "MV-08": true, "MV-12": true, "MV-13": true, "MV-14": true,
"MV-17": true, "MV-20": true, "MV-23": true, "MV-24": true, "MV-25": true,
"MV-26": true, "MV-27": true, "MV-28": true, "MV-29": true, "MV-CE": true,
"MV-MLE": true, "MV-NC": true, "MV-NO": true, "MV-SC": true, "MV-SU": true,
"MV-UN": true, "MV-US": true, "MW-BA": true, "MW-BL": true, "MW-C": true,
"MW-CK": true, "MW-CR": true, "MW-CT": true, "MW-DE": true, "MW-DO": true,
"MW-KR": true, "MW-KS": true, "MW-LI": true, "MW-LK": true, "MW-MC": true,
"MW-MG": true, "MW-MH": true, "MW-MU": true, "MW-MW": true, "MW-MZ": true,
"MW-N": true, "MW-NB": true, "MW-NE": true, "MW-NI": true, "MW-NK": true,
"MW-NS": true, "MW-NU": true, "MW-PH": true, "MW-RU": true, "MW-S": true,
"MW-SA": true, "MW-TH": true, "MW-ZO": true, "MX-AGU": true, "MX-BCN": true,
"MX-BCS": true, "MX-CAM": true, "MX-CHH": true, "MX-CHP": true, "MX-COA": true,
"MX-COL": true, "MX-CMX": true, "MX-DIF": true, "MX-DUR": true, "MX-GRO": true, "MX-GUA": true,
"MX-HID": true, "MX-JAL": true, "MX-MEX": true, "MX-MIC": true, "MX-MOR": true,
"MX-NAY": true, "MX-NLE": true, "MX-OAX": true, "MX-PUE": true, "MX-QUE": true,
"MX-ROO": true, "MX-SIN": true, "MX-SLP": true, "MX-SON": true, "MX-TAB": true,
"MX-TAM": true, "MX-TLA": true, "MX-VER": true, "MX-YUC": true, "MX-ZAC": true,
"MY-01": true, "MY-02": true, "MY-03": true, "MY-04": true, "MY-05": true,
"MY-06": true, "MY-07": true, "MY-08": true, "MY-09": true, "MY-10": true,
"MY-11": true, "MY-12": true, "MY-13": true, "MY-14": true, "MY-15": true,
"MY-16": true, "MZ-A": true, "MZ-B": true, "MZ-G": true, "MZ-I": true,
"MZ-L": true, "MZ-MPM": true, "MZ-N": true, "MZ-P": true, "MZ-Q": true,
"MZ-S": true, "MZ-T": true, "NA-CA": true, "NA-ER": true, "NA-HA": true,
"NA-KA": true, "NA-KE": true, "NA-KH": true, "NA-KU": true, "NA-KW": true, "NA-OD": true, "NA-OH": true,
"NA-OK": true, "NA-ON": true, "NA-OS": true, "NA-OT": true, "NA-OW": true,
"NE-1": true, "NE-2": true, "NE-3": true, "NE-4": true, "NE-5": true,
"NE-6": true, "NE-7": true, "NE-8": true, "NG-AB": true, "NG-AD": true,
"NG-AK": true, "NG-AN": true, "NG-BA": true, "NG-BE": true, "NG-BO": true,
"NG-BY": true, "NG-CR": true, "NG-DE": true, "NG-EB": true, "NG-ED": true,
"NG-EK": true, "NG-EN": true, "NG-FC": true, "NG-GO": true, "NG-IM": true,
"NG-JI": true, "NG-KD": true, "NG-KE": true, "NG-KN": true, "NG-KO": true,
"NG-KT": true, "NG-KW": true, "NG-LA": true, "NG-NA": true, "NG-NI": true,
"NG-OG": true, "NG-ON": true, "NG-OS": true, "NG-OY": true, "NG-PL": true,
"NG-RI": true, "NG-SO": true, "NG-TA": true, "NG-YO": true, "NG-ZA": true,
"NI-AN": true, "NI-AS": true, "NI-BO": true, "NI-CA": true, "NI-CI": true,
"NI-CO": true, "NI-ES": true, "NI-GR": true, "NI-JI": true, "NI-LE": true,
"NI-MD": true, "NI-MN": true, "NI-MS": true, "NI-MT": true, "NI-NS": true,
"NI-RI": true, "NI-SJ": true, "NL-AW": true, "NL-BQ1": true, "NL-BQ2": true,
"NL-BQ3": true, "NL-CW": true, "NL-DR": true, "NL-FL": true, "NL-FR": true,
"NL-GE": true, "NL-GR": true, "NL-LI": true, "NL-NB": true, "NL-NH": true,
"NL-OV": true, "NL-SX": true, "NL-UT": true, "NL-ZE": true, "NL-ZH": true,
"NO-03": true, "NO-11": true, "NO-15": true, "NO-16": true, "NO-17": true,
"NO-18": true, "NO-21": true, "NO-30": true, "NO-34": true, "NO-38": true,
"NO-42": true, "NO-46": true, "NO-50": true, "NO-54": true,
"NO-22": true, "NP-1": true, "NP-2": true, "NP-3": true, "NP-4": true,
"NP-5": true, "NP-BA": true, "NP-BH": true, "NP-DH": true, "NP-GA": true,
"NP-JA": true, "NP-KA": true, "NP-KO": true, "NP-LU": true, "NP-MA": true,
"NP-ME": true, "NP-NA": true, "NP-RA": true, "NP-SA": true, "NP-SE": true,
"NR-01": true, "NR-02": true, "NR-03": true, "NR-04": true, "NR-05": true,
"NR-06": true, "NR-07": true, "NR-08": true, "NR-09": true, "NR-10": true,
"NR-11": true, "NR-12": true, "NR-13": true, "NR-14": true, "NZ-AUK": true,
"NZ-BOP": true, "NZ-CAN": true, "NZ-CIT": true, "NZ-GIS": true, "NZ-HKB": true,
"NZ-MBH": true, "NZ-MWT": true, "NZ-N": true, "NZ-NSN": true, "NZ-NTL": true,
"NZ-OTA": true, "NZ-S": true, "NZ-STL": true, "NZ-TAS": true, "NZ-TKI": true,
"NZ-WGN": true, "NZ-WKO": true, "NZ-WTC": true, "OM-BA": true, "OM-BS": true, "OM-BU": true, "OM-BJ": true,
"OM-DA": true, "OM-MA": true, "OM-MU": true, "OM-SH": true, "OM-SJ": true, "OM-SS": true, "OM-WU": true,
"OM-ZA": true, "OM-ZU": true, "PA-1": true, "PA-2": true, "PA-3": true,
"PA-4": true, "PA-5": true, "PA-6": true, "PA-7": true, "PA-8": true,
"PA-9": true, "PA-EM": true, "PA-KY": true, "PA-NB": true, "PE-AMA": true,
"PE-ANC": true, "PE-APU": true, "PE-ARE": true, "PE-AYA": true, "PE-CAJ": true,
"PE-CAL": true, "PE-CUS": true, "PE-HUC": true, "PE-HUV": true, "PE-ICA": true,
"PE-JUN": true, "PE-LAL": true, "PE-LAM": true, "PE-LIM": true, "PE-LMA": true,
"PE-LOR": true, "PE-MDD": true, "PE-MOQ": true, "PE-PAS": true, "PE-PIU": true,
"PE-PUN": true, "PE-SAM": true, "PE-TAC": true, "PE-TUM": true, "PE-UCA": true,
"PG-CPK": true, "PG-CPM": true, "PG-EBR": true, "PG-EHG": true, "PG-EPW": true,
"PG-ESW": true, "PG-GPK": true, "PG-MBA": true, "PG-MPL": true, "PG-MPM": true,
"PG-MRL": true, "PG-NCD": true, "PG-NIK": true, "PG-NPP": true, "PG-NSB": true,
"PG-SAN": true, "PG-SHM": true, "PG-WBK": true, "PG-WHM": true, "PG-WPD": true,
"PH-00": true, "PH-01": true, "PH-02": true, "PH-03": true, "PH-05": true,
"PH-06": true, "PH-07": true, "PH-08": true, "PH-09": true, "PH-10": true,
"PH-11": true, "PH-12": true, "PH-13": true, "PH-14": true, "PH-15": true,
"PH-40": true, "PH-41": true, "PH-ABR": true, "PH-AGN": true, "PH-AGS": true,
"PH-AKL": true, "PH-ALB": true, "PH-ANT": true, "PH-APA": true, "PH-AUR": true,
"PH-BAN": true, "PH-BAS": true, "PH-BEN": true, "PH-BIL": true, "PH-BOH": true,
"PH-BTG": true, "PH-BTN": true, "PH-BUK": true, "PH-BUL": true, "PH-CAG": true,
"PH-CAM": true, "PH-CAN": true, "PH-CAP": true, "PH-CAS": true, "PH-CAT": true,
"PH-CAV": true, "PH-CEB": true, "PH-COM": true, "PH-DAO": true, "PH-DAS": true,
"PH-DAV": true, "PH-DIN": true, "PH-EAS": true, "PH-GUI": true, "PH-IFU": true,
"PH-ILI": true, "PH-ILN": true, "PH-ILS": true, "PH-ISA": true, "PH-KAL": true,
"PH-LAG": true, "PH-LAN": true, "PH-LAS": true, "PH-LEY": true, "PH-LUN": true,
"PH-MAD": true, "PH-MAG": true, "PH-MAS": true, "PH-MDC": true, "PH-MDR": true,
"PH-MOU": true, "PH-MSC": true, "PH-MSR": true, "PH-NCO": true, "PH-NEC": true,
"PH-NER": true, "PH-NSA": true, "PH-NUE": true, "PH-NUV": true, "PH-PAM": true,
"PH-PAN": true, "PH-PLW": true, "PH-QUE": true, "PH-QUI": true, "PH-RIZ": true,
"PH-ROM": true, "PH-SAR": true, "PH-SCO": true, "PH-SIG": true, "PH-SLE": true,
"PH-SLU": true, "PH-SOR": true, "PH-SUK": true, "PH-SUN": true, "PH-SUR": true,
"PH-TAR": true, "PH-TAW": true, "PH-WSA": true, "PH-ZAN": true, "PH-ZAS": true,
"PH-ZMB": true, "PH-ZSI": true, "PK-BA": true, "PK-GB": true, "PK-IS": true,
"PK-JK": true, "PK-KP": true, "PK-PB": true, "PK-SD": true, "PK-TA": true,
"PL-02": true, "PL-04": true, "PL-06": true, "PL-08": true, "PL-10": true,
"PL-12": true, "PL-14": true, "PL-16": true, "PL-18": true, "PL-20": true,
"PL-22": true, "PL-24": true, "PL-26": true, "PL-28": true, "PL-30": true, "PL-32": true,
"PS-BTH": true, "PS-DEB": true, "PS-GZA": true, "PS-HBN": true,
"PS-JEM": true, "PS-JEN": true, "PS-JRH": true, "PS-KYS": true, "PS-NBS": true,
"PS-NGZ": true, "PS-QQA": true, "PS-RBH": true, "PS-RFH": true, "PS-SLT": true,
"PS-TBS": true, "PS-TKM": true, "PT-01": true, "PT-02": true, "PT-03": true,
"PT-04": true, "PT-05": true, "PT-06": true, "PT-07": true, "PT-08": true,
"PT-09": true, "PT-10": true, "PT-11": true, "PT-12": true, "PT-13": true,
"PT-14": true, "PT-15": true, "PT-16": true, "PT-17": true, "PT-18": true,
"PT-20": true, "PT-30": true, "PW-002": true, "PW-004": true, "PW-010": true,
"PW-050": true, "PW-100": true, "PW-150": true, "PW-212": true, "PW-214": true,
"PW-218": true, "PW-222": true, "PW-224": true, "PW-226": true, "PW-227": true,
"PW-228": true, "PW-350": true, "PW-370": true, "PY-1": true, "PY-10": true,
"PY-11": true, "PY-12": true, "PY-13": true, "PY-14": true, "PY-15": true,
"PY-16": true, "PY-19": true, "PY-2": true, "PY-3": true, "PY-4": true,
"PY-5": true, "PY-6": true, "PY-7": true, "PY-8": true, "PY-9": true,
"PY-ASU": true, "QA-DA": true, "QA-KH": true, "QA-MS": true, "QA-RA": true,
"QA-US": true, "QA-WA": true, "QA-ZA": true, "RO-AB": true, "RO-AG": true,
"RO-AR": true, "RO-B": true, "RO-BC": true, "RO-BH": true, "RO-BN": true,
"RO-BR": true, "RO-BT": true, "RO-BV": true, "RO-BZ": true, "RO-CJ": true,
"RO-CL": true, "RO-CS": true, "RO-CT": true, "RO-CV": true, "RO-DB": true,
"RO-DJ": true, "RO-GJ": true, "RO-GL": true, "RO-GR": true, "RO-HD": true,
"RO-HR": true, "RO-IF": true, "RO-IL": true, "RO-IS": true, "RO-MH": true,
"RO-MM": true, "RO-MS": true, "RO-NT": true, "RO-OT": true, "RO-PH": true,
"RO-SB": true, "RO-SJ": true, "RO-SM": true, "RO-SV": true, "RO-TL": true,
"RO-TM": true, "RO-TR": true, "RO-VL": true, "RO-VN": true, "RO-VS": true,
"RS-00": true, "RS-01": true, "RS-02": true, "RS-03": true, "RS-04": true,
"RS-05": true, "RS-06": true, "RS-07": true, "RS-08": true, "RS-09": true,
"RS-10": true, "RS-11": true, "RS-12": true, "RS-13": true, "RS-14": true,
"RS-15": true, "RS-16": true, "RS-17": true, "RS-18": true, "RS-19": true,
"RS-20": true, "RS-21": true, "RS-22": true, "RS-23": true, "RS-24": true,
"RS-25": true, "RS-26": true, "RS-27": true, "RS-28": true, "RS-29": true,
"RS-KM": true, "RS-VO": true, "RU-AD": true, "RU-AL": true, "RU-ALT": true,
"RU-AMU": true, "RU-ARK": true, "RU-AST": true, "RU-BA": true, "RU-BEL": true,
"RU-BRY": true, "RU-BU": true, "RU-CE": true, "RU-CHE": true, "RU-CHU": true,
"RU-CU": true, "RU-DA": true, "RU-IN": true, "RU-IRK": true, "RU-IVA": true,
"RU-KAM": true, "RU-KB": true, "RU-KC": true, "RU-KDA": true, "RU-KEM": true,
"RU-KGD": true, "RU-KGN": true, "RU-KHA": true, "RU-KHM": true, "RU-KIR": true,
"RU-KK": true, "RU-KL": true, "RU-KLU": true, "RU-KO": true, "RU-KOS": true,
"RU-KR": true, "RU-KRS": true, "RU-KYA": true, "RU-LEN": true, "RU-LIP": true,
"RU-MAG": true, "RU-ME": true, "RU-MO": true, "RU-MOS": true, "RU-MOW": true,
"RU-MUR": true, "RU-NEN": true, "RU-NGR": true, "RU-NIZ": true, "RU-NVS": true,
"RU-OMS": true, "RU-ORE": true, "RU-ORL": true, "RU-PER": true, "RU-PNZ": true,
"RU-PRI": true, "RU-PSK": true, "RU-ROS": true, "RU-RYA": true, "RU-SA": true,
"RU-SAK": true, "RU-SAM": true, "RU-SAR": true, "RU-SE": true, "RU-SMO": true,
"RU-SPE": true, "RU-STA": true, "RU-SVE": true, "RU-TA": true, "RU-TAM": true,
"RU-TOM": true, "RU-TUL": true, "RU-TVE": true, "RU-TY": true, "RU-TYU": true,
"RU-UD": true, "RU-ULY": true, "RU-VGG": true, "RU-VLA": true, "RU-VLG": true,
"RU-VOR": true, "RU-YAN": true, "RU-YAR": true, "RU-YEV": true, "RU-ZAB": true,
"RW-01": true, "RW-02": true, "RW-03": true, "RW-04": true, "RW-05": true,
"SA-01": true, "SA-02": true, "SA-03": true, "SA-04": true, "SA-05": true,
"SA-06": true, "SA-07": true, "SA-08": true, "SA-09": true, "SA-10": true,
"SA-11": true, "SA-12": true, "SA-14": true, "SB-CE": true, "SB-CH": true,
"SB-CT": true, "SB-GU": true, "SB-IS": true, "SB-MK": true, "SB-ML": true,
"SB-RB": true, "SB-TE": true, "SB-WE": true, "SC-01": true, "SC-02": true,
"SC-03": true, "SC-04": true, "SC-05": true, "SC-06": true, "SC-07": true,
"SC-08": true, "SC-09": true, "SC-10": true, "SC-11": true, "SC-12": true,
"SC-13": true, "SC-14": true, "SC-15": true, "SC-16": true, "SC-17": true,
"SC-18": true, "SC-19": true, "SC-20": true, "SC-21": true, "SC-22": true,
"SC-23": true, "SC-24": true, "SC-25": true, "SD-DC": true, "SD-DE": true,
"SD-DN": true, "SD-DS": true, "SD-DW": true, "SD-GD": true, "SD-GK": true, "SD-GZ": true,
"SD-KA": true, "SD-KH": true, "SD-KN": true, "SD-KS": true, "SD-NB": true,
"SD-NO": true, "SD-NR": true, "SD-NW": true, "SD-RS": true, "SD-SI": true,
"SE-AB": true, "SE-AC": true, "SE-BD": true, "SE-C": true, "SE-D": true,
"SE-E": true, "SE-F": true, "SE-G": true, "SE-H": true, "SE-I": true,
"SE-K": true, "SE-M": true, "SE-N": true, "SE-O": true, "SE-S": true,
"SE-T": true, "SE-U": true, "SE-W": true, "SE-X": true, "SE-Y": true,
"SE-Z": true, "SG-01": true, "SG-02": true, "SG-03": true, "SG-04": true,
"SG-05": true, "SH-AC": true, "SH-HL": true, "SH-TA": true, "SI-001": true,
"SI-002": true, "SI-003": true, "SI-004": true, "SI-005": true, "SI-006": true,
"SI-007": true, "SI-008": true, "SI-009": true, "SI-010": true, "SI-011": true,
"SI-012": true, "SI-013": true, "SI-014": true, "SI-015": true, "SI-016": true,
"SI-017": true, "SI-018": true, "SI-019": true, "SI-020": true, "SI-021": true,
"SI-022": true, "SI-023": true, "SI-024": true, "SI-025": true, "SI-026": true,
"SI-027": true, "SI-028": true, "SI-029": true, "SI-030": true, "SI-031": true,
"SI-032": true, "SI-033": true, "SI-034": true, "SI-035": true, "SI-036": true,
"SI-037": true, "SI-038": true, "SI-039": true, "SI-040": true, "SI-041": true,
"SI-042": true, "SI-043": true, "SI-044": true, "SI-045": true, "SI-046": true,
"SI-047": true, "SI-048": true, "SI-049": true, "SI-050": true, "SI-051": true,
"SI-052": true, "SI-053": true, "SI-054": true, "SI-055": true, "SI-056": true,
"SI-057": true, "SI-058": true, "SI-059": true, "SI-060": true, "SI-061": true,
"SI-062": true, "SI-063": true, "SI-064": true, "SI-065": true, "SI-066": true,
"SI-067": true, "SI-068": true, "SI-069": true, "SI-070": true, "SI-071": true,
"SI-072": true, "SI-073": true, "SI-074": true, "SI-075": true, "SI-076": true,
"SI-077": true, "SI-078": true, "SI-079": true, "SI-080": true, "SI-081": true,
"SI-082": true, "SI-083": true, "SI-084": true, "SI-085": true, "SI-086": true,
"SI-087": true, "SI-088": true, "SI-089": true, "SI-090": true, "SI-091": true,
"SI-092": true, "SI-093": true, "SI-094": true, "SI-095": true, "SI-096": true,
"SI-097": true, "SI-098": true, "SI-099": true, "SI-100": true, "SI-101": true,
"SI-102": true, "SI-103": true, "SI-104": true, "SI-105": true, "SI-106": true,
"SI-107": true, "SI-108": true, "SI-109": true, "SI-110": true, "SI-111": true,
"SI-112": true, "SI-113": true, "SI-114": true, "SI-115": true, "SI-116": true,
"SI-117": true, "SI-118": true, "SI-119": true, "SI-120": true, "SI-121": true,
"SI-122": true, "SI-123": true, "SI-124": true, "SI-125": true, "SI-126": true,
"SI-127": true, "SI-128": true, "SI-129": true, "SI-130": true, "SI-131": true,
"SI-132": true, "SI-133": true, "SI-134": true, "SI-135": true, "SI-136": true,
"SI-137": true, "SI-138": true, "SI-139": true, "SI-140": true, "SI-141": true,
"SI-142": true, "SI-143": true, "SI-144": true, "SI-146": true, "SI-147": true,
"SI-148": true, "SI-149": true, "SI-150": true, "SI-151": true, "SI-152": true,
"SI-153": true, "SI-154": true, "SI-155": true, "SI-156": true, "SI-157": true,
"SI-158": true, "SI-159": true, "SI-160": true, "SI-161": true, "SI-162": true,
"SI-163": true, "SI-164": true, "SI-165": true, "SI-166": true, "SI-167": true,
"SI-168": true, "SI-169": true, "SI-170": true, "SI-171": true, "SI-172": true,
"SI-173": true, "SI-174": true, "SI-175": true, "SI-176": true, "SI-177": true,
"SI-178": true, "SI-179": true, "SI-180": true, "SI-181": true, "SI-182": true,
"SI-183": true, "SI-184": true, "SI-185": true, "SI-186": true, "SI-187": true,
"SI-188": true, "SI-189": true, "SI-190": true, "SI-191": true, "SI-192": true,
"SI-193": true, "SI-194": true, "SI-195": true, "SI-196": true, "SI-197": true,
"SI-198": true, "SI-199": true, "SI-200": true, "SI-201": true, "SI-202": true,
"SI-203": true, "SI-204": true, "SI-205": true, "SI-206": true, "SI-207": true,
"SI-208": true, "SI-209": true, "SI-210": true, "SI-211": true, "SI-212": true, "SI-213": true, "SK-BC": true,
"SK-BL": true, "SK-KI": true, "SK-NI": true, "SK-PV": true, "SK-TA": true,
"SK-TC": true, "SK-ZI": true, "SL-E": true, "SL-N": true, "SL-S": true,
"SL-W": true, "SM-01": true, "SM-02": true, "SM-03": true, "SM-04": true,
"SM-05": true, "SM-06": true, "SM-07": true, "SM-08": true, "SM-09": true,
"SN-DB": true, "SN-DK": true, "SN-FK": true, "SN-KA": true, "SN-KD": true,
"SN-KE": true, "SN-KL": true, "SN-LG": true, "SN-MT": true, "SN-SE": true,
"SN-SL": true, "SN-TC": true, "SN-TH": true, "SN-ZG": true, "SO-AW": true,
"SO-BK": true, "SO-BN": true, "SO-BR": true, "SO-BY": true, "SO-GA": true,
"SO-GE": true, "SO-HI": true, "SO-JD": true, "SO-JH": true, "SO-MU": true,
"SO-NU": true, "SO-SA": true, "SO-SD": true, "SO-SH": true, "SO-SO": true,
"SO-TO": true, "SO-WO": true, "SR-BR": true, "SR-CM": true, "SR-CR": true,
"SR-MA": true, "SR-NI": true, "SR-PM": true, "SR-PR": true, "SR-SA": true,
"SR-SI": true, "SR-WA": true, "SS-BN": true, "SS-BW": true, "SS-EC": true,
"SS-EE8": true, "SS-EE": true, "SS-EW": true, "SS-JG": true, "SS-LK": true, "SS-NU": true,
"SS-UY": true, "SS-WR": true, "ST-01": true, "ST-P": true, "ST-S": true, "SV-AH": true,
"SV-CA": true, "SV-CH": true, "SV-CU": true, "SV-LI": true, "SV-MO": true,
"SV-PA": true, "SV-SA": true, "SV-SM": true, "SV-SO": true, "SV-SS": true,
"SV-SV": true, "SV-UN": true, "SV-US": true, "SY-DI": true, "SY-DR": true,
"SY-DY": true, "SY-HA": true, "SY-HI": true, "SY-HL": true, "SY-HM": true,
"SY-ID": true, "SY-LA": true, "SY-QU": true, "SY-RA": true, "SY-RD": true,
"SY-SU": true, "SY-TA": true, "SZ-HH": true, "SZ-LU": true, "SZ-MA": true,
"SZ-SH": true, "TD-BA": true, "TD-BG": true, "TD-BO": true, "TD-CB": true,
"TD-EN": true, "TD-GR": true, "TD-HL": true, "TD-KA": true, "TD-LC": true,
"TD-LO": true, "TD-LR": true, "TD-MA": true, "TD-MC": true, "TD-ME": true,
"TD-MO": true, "TD-ND": true, "TD-OD": true, "TD-SA": true, "TD-SI": true,
"TD-TA": true, "TD-TI": true, "TD-WF": true, "TG-C": true, "TG-K": true,
"TG-M": true, "TG-P": true, "TG-S": true, "TH-10": true, "TH-11": true,
"TH-12": true, "TH-13": true, "TH-14": true, "TH-15": true, "TH-16": true,
"TH-17": true, "TH-18": true, "TH-19": true, "TH-20": true, "TH-21": true,
"TH-22": true, "TH-23": true, "TH-24": true, "TH-25": true, "TH-26": true,
"TH-27": true, "TH-30": true, "TH-31": true, "TH-32": true, "TH-33": true,
"TH-34": true, "TH-35": true, "TH-36": true, "TH-37": true, "TH-38": true, "TH-39": true,
"TH-40": true, "TH-41": true, "TH-42": true, "TH-43": true, "TH-44": true,
"TH-45": true, "TH-46": true, "TH-47": true, "TH-48": true, "TH-49": true,
"TH-50": true, "TH-51": true, "TH-52": true, "TH-53": true, "TH-54": true,
"TH-55": true, "TH-56": true, "TH-57": true, "TH-58": true, "TH-60": true,
"TH-61": true, "TH-62": true, "TH-63": true, "TH-64": true, "TH-65": true,
"TH-66": true, "TH-67": true, "TH-70": true, "TH-71": true, "TH-72": true,
"TH-73": true, "TH-74": true, "TH-75": true, "TH-76": true, "TH-77": true,
"TH-80": true, "TH-81": true, "TH-82": true, "TH-83": true, "TH-84": true,
"TH-85": true, "TH-86": true, "TH-90": true, "TH-91": true, "TH-92": true,
"TH-93": true, "TH-94": true, "TH-95": true, "TH-96": true, "TH-S": true,
"TJ-GB": true, "TJ-KT": true, "TJ-SU": true, "TJ-DU": true, "TJ-RA": true, "TL-AL": true, "TL-AN": true,
"TL-BA": true, "TL-BO": true, "TL-CO": true, "TL-DI": true, "TL-ER": true,
"TL-LA": true, "TL-LI": true, "TL-MF": true, "TL-MT": true, "TL-OE": true,
"TL-VI": true, "TM-A": true, "TM-B": true, "TM-D": true, "TM-L": true,
"TM-M": true, "TM-S": true, "TN-11": true, "TN-12": true, "TN-13": true,
"TN-14": true, "TN-21": true, "TN-22": true, "TN-23": true, "TN-31": true,
"TN-32": true, "TN-33": true, "TN-34": true, "TN-41": true, "TN-42": true,
"TN-43": true, "TN-51": true, "TN-52": true, "TN-53": true, "TN-61": true,
"TN-71": true, "TN-72": true, "TN-73": true, "TN-81": true, "TN-82": true,
"TN-83": true, "TO-01": true, "TO-02": true, "TO-03": true, "TO-04": true,
"TO-05": true, "TR-01": true, "TR-02": true, "TR-03": true, "TR-04": true,
"TR-05": true, "TR-06": true, "TR-07": true, "TR-08": true, "TR-09": true,
"TR-10": true, "TR-11": true, "TR-12": true, "TR-13": true, "TR-14": true,
"TR-15": true, "TR-16": true, "TR-17": true, "TR-18": true, "TR-19": true,
"TR-20": true, "TR-21": true, "TR-22": true, "TR-23": true, "TR-24": true,
"TR-25": true, "TR-26": true, "TR-27": true, "TR-28": true, "TR-29": true,
"TR-30": true, "TR-31": true, "TR-32": true, "TR-33": true, "TR-34": true,
"TR-35": true, "TR-36": true, "TR-37": true, "TR-38": true, "TR-39": true,
"TR-40": true, "TR-41": true, "TR-42": true, "TR-43": true, "TR-44": true,
"TR-45": true, "TR-46": true, "TR-47": true, "TR-48": true, "TR-49": true,
"TR-50": true, "TR-51": true, "TR-52": true, "TR-53": true, "TR-54": true,
"TR-55": true, "TR-56": true, "TR-57": true, "TR-58": true, "TR-59": true,
"TR-60": true, "TR-61": true, "TR-62": true, "TR-63": true, "TR-64": true,
"TR-65": true, "TR-66": true, "TR-67": true, "TR-68": true, "TR-69": true,
"TR-70": true, "TR-71": true, "TR-72": true, "TR-73": true, "TR-74": true,
"TR-75": true, "TR-76": true, "TR-77": true, "TR-78": true, "TR-79": true,
"TR-80": true, "TR-81": true, "TT-ARI": true, "TT-CHA": true, "TT-CTT": true,
"TT-DMN": true, "TT-ETO": true, "TT-MRC": true, "TT-TOB": true, "TT-PED": true, "TT-POS": true, "TT-PRT": true,
"TT-PTF": true, "TT-RCM": true, "TT-SFO": true, "TT-SGE": true, "TT-SIP": true,
"TT-SJL": true, "TT-TUP": true, "TT-WTO": true, "TV-FUN": true, "TV-NIT": true,
"TV-NKF": true, "TV-NKL": true, "TV-NMA": true, "TV-NMG": true, "TV-NUI": true,
"TV-VAI": true, "TW-CHA": true, "TW-CYI": true, "TW-CYQ": true, "TW-KIN": true, "TW-HSQ": true,
"TW-HSZ": true, "TW-HUA": true, "TW-LIE": true, "TW-ILA": true, "TW-KEE": true, "TW-KHH": true,
"TW-KHQ": true, "TW-MIA": true, "TW-NAN": true, "TW-NWT": true, "TW-PEN": true, "TW-PIF": true,
"TW-TAO": true, "TW-TNN": true, "TW-TNQ": true, "TW-TPE": true, "TW-TPQ": true,
"TW-TTT": true, "TW-TXG": true, "TW-TXQ": true, "TW-YUN": true, "TZ-01": true,
"TZ-02": true, "TZ-03": true, "TZ-04": true, "TZ-05": true, "TZ-06": true,
"TZ-07": true, "TZ-08": true, "TZ-09": true, "TZ-10": true, "TZ-11": true,
"TZ-12": true, "TZ-13": true, "TZ-14": true, "TZ-15": true, "TZ-16": true,
"TZ-17": true, "TZ-18": true, "TZ-19": true, "TZ-20": true, "TZ-21": true,
"TZ-22": true, "TZ-23": true, "TZ-24": true, "TZ-25": true, "TZ-26": true, "TZ-27": true, "TZ-28": true, "TZ-29": true, "TZ-30": true, "TZ-31": true,
"UA-05": true, "UA-07": true, "UA-09": true, "UA-12": true, "UA-14": true,
"UA-18": true, "UA-21": true, "UA-23": true, "UA-26": true, "UA-30": true,
"UA-32": true, "UA-35": true, "UA-40": true, "UA-43": true, "UA-46": true,
"UA-48": true, "UA-51": true, "UA-53": true, "UA-56": true, "UA-59": true,
"UA-61": true, "UA-63": true, "UA-65": true, "UA-68": true, "UA-71": true,
"UA-74": true, "UA-77": true, "UG-101": true, "UG-102": true, "UG-103": true,
"UG-104": true, "UG-105": true, "UG-106": true, "UG-107": true, "UG-108": true,
"UG-109": true, "UG-110": true, "UG-111": true, "UG-112": true, "UG-113": true,
"UG-114": true, "UG-115": true, "UG-116": true, "UG-201": true, "UG-202": true,
"UG-203": true, "UG-204": true, "UG-205": true, "UG-206": true, "UG-207": true,
"UG-208": true, "UG-209": true, "UG-210": true, "UG-211": true, "UG-212": true,
"UG-213": true, "UG-214": true, "UG-215": true, "UG-216": true, "UG-217": true,
"UG-218": true, "UG-219": true, "UG-220": true, "UG-221": true, "UG-222": true,
"UG-223": true, "UG-224": true, "UG-301": true, "UG-302": true, "UG-303": true,
"UG-304": true, "UG-305": true, "UG-306": true, "UG-307": true, "UG-308": true,
"UG-309": true, "UG-310": true, "UG-311": true, "UG-312": true, "UG-313": true,
"UG-314": true, "UG-315": true, "UG-316": true, "UG-317": true, "UG-318": true,
"UG-319": true, "UG-320": true, "UG-321": true, "UG-401": true, "UG-402": true,
"UG-403": true, "UG-404": true, "UG-405": true, "UG-406": true, "UG-407": true,
"UG-408": true, "UG-409": true, "UG-410": true, "UG-411": true, "UG-412": true,
"UG-413": true, "UG-414": true, "UG-415": true, "UG-416": true, "UG-417": true,
"UG-418": true, "UG-419": true, "UG-C": true, "UG-E": true, "UG-N": true,
"UG-W": true, "UG-322": true, "UG-323": true, "UG-420": true, "UG-117": true,
"UG-118": true, "UG-225": true, "UG-120": true, "UG-226": true,
"UG-121": true, "UG-122": true, "UG-227": true, "UG-421": true,
"UG-325": true, "UG-228": true, "UG-123": true, "UG-422": true,
"UG-326": true, "UG-229": true, "UG-124": true, "UG-423": true,
"UG-230": true, "UG-327": true, "UG-424": true, "UG-328": true,
"UG-425": true, "UG-426": true, "UG-330": true,
"UM-67": true, "UM-71": true, "UM-76": true, "UM-79": true,
"UM-81": true, "UM-84": true, "UM-86": true, "UM-89": true, "UM-95": true,
"US-AK": true, "US-AL": true, "US-AR": true, "US-AS": true, "US-AZ": true,
"US-CA": true, "US-CO": true, "US-CT": true, "US-DC": true, "US-DE": true,
"US-FL": true, "US-GA": true, "US-GU": true, "US-HI": true, "US-IA": true,
"US-ID": true, "US-IL": true, "US-IN": true, "US-KS": true, "US-KY": true,
"US-LA": true, "US-MA": true, "US-MD": true, "US-ME": true, "US-MI": true,
"US-MN": true, "US-MO": true, "US-MP": true, "US-MS": true, "US-MT": true,
"US-NC": true, "US-ND": true, "US-NE": true, "US-NH": true, "US-NJ": true,
"US-NM": true, "US-NV": true, "US-NY": true, "US-OH": true, "US-OK": true,
"US-OR": true, "US-PA": true, "US-PR": true, "US-RI": true, "US-SC": true,
"US-SD": true, "US-TN": true, "US-TX": true, "US-UM": true, "US-UT": true,
"US-VA": true, "US-VI": true, "US-VT": true, "US-WA": true, "US-WI": true,
"US-WV": true, "US-WY": true, "UY-AR": true, "UY-CA": true, "UY-CL": true,
"UY-CO": true, "UY-DU": true, "UY-FD": true, "UY-FS": true, "UY-LA": true,
"UY-MA": true, "UY-MO": true, "UY-PA": true, "UY-RN": true, "UY-RO": true,
"UY-RV": true, "UY-SA": true, "UY-SJ": true, "UY-SO": true, "UY-TA": true,
"UY-TT": true, "UZ-AN": true, "UZ-BU": true, "UZ-FA": true, "UZ-JI": true,
"UZ-NG": true, "UZ-NW": true, "UZ-QA": true, "UZ-QR": true, "UZ-SA": true,
"UZ-SI": true, "UZ-SU": true, "UZ-TK": true, "UZ-TO": true, "UZ-XO": true,
"VC-01": true, "VC-02": true, "VC-03": true, "VC-04": true, "VC-05": true,
"VC-06": true, "VE-A": true, "VE-B": true, "VE-C": true, "VE-D": true,
"VE-E": true, "VE-F": true, "VE-G": true, "VE-H": true, "VE-I": true,
"VE-J": true, "VE-K": true, "VE-L": true, "VE-M": true, "VE-N": true,
"VE-O": true, "VE-P": true, "VE-R": true, "VE-S": true, "VE-T": true,
"VE-U": true, "VE-V": true, "VE-W": true, "VE-X": true, "VE-Y": true,
"VE-Z": true, "VN-01": true, "VN-02": true, "VN-03": true, "VN-04": true,
"VN-05": true, "VN-06": true, "VN-07": true, "VN-09": true, "VN-13": true,
"VN-14": true, "VN-15": true, "VN-18": true, "VN-20": true, "VN-21": true,
"VN-22": true, "VN-23": true, "VN-24": true, "VN-25": true, "VN-26": true,
"VN-27": true, "VN-28": true, "VN-29": true, "VN-30": true, "VN-31": true,
"VN-32": true, "VN-33": true, "VN-34": true, "VN-35": true, "VN-36": true,
"VN-37": true, "VN-39": true, "VN-40": true, "VN-41": true, "VN-43": true,
"VN-44": true, "VN-45": true, "VN-46": true, "VN-47": true, "VN-49": true,
"VN-50": true, "VN-51": true, "VN-52": true, "VN-53": true, "VN-54": true,
"VN-55": true, "VN-56": true, "VN-57": true, "VN-58": true, "VN-59": true,
"VN-61": true, "VN-63": true, "VN-66": true, "VN-67": true, "VN-68": true,
"VN-69": true, "VN-70": true, "VN-71": true, "VN-72": true, "VN-73": true,
"VN-CT": true, "VN-DN": true, "VN-HN": true, "VN-HP": true, "VN-SG": true,
"VU-MAP": true, "VU-PAM": true, "VU-SAM": true, "VU-SEE": true, "VU-TAE": true,
"VU-TOB": true, "WF-SG": true, "WF-UV": true, "WS-AA": true, "WS-AL": true, "WS-AT": true, "WS-FA": true,
"WS-GE": true, "WS-GI": true, "WS-PA": true, "WS-SA": true, "WS-TU": true,
"WS-VF": true, "WS-VS": true, "YE-AB": true, "YE-AD": true, "YE-AM": true,
"YE-BA": true, "YE-DA": true, "YE-DH": true, "YE-HD": true, "YE-HJ": true, "YE-HU": true,
"YE-IB": true, "YE-JA": true, "YE-LA": true, "YE-MA": true, "YE-MR": true,
"YE-MU": true, "YE-MW": true, "YE-RA": true, "YE-SA": true, "YE-SD": true, "YE-SH": true,
"YE-SN": true, "YE-TA": true, "ZA-EC": true, "ZA-FS": true, "ZA-GP": true,
"ZA-LP": true, "ZA-MP": true, "ZA-NC": true, "ZA-NW": true, "ZA-WC": true,
"ZA-ZN": true, "ZA-KZN": true, "ZM-01": true, "ZM-02": true, "ZM-03": true, "ZM-04": true,
"ZM-05": true, "ZM-06": true, "ZM-07": true, "ZM-08": true, "ZM-09": true, "ZM-10": true,
"ZW-BU": true, "ZW-HA": true, "ZW-MA": true, "ZW-MC": true, "ZW-ME": true,
"ZW-MI": true, "ZW-MN": true, "ZW-MS": true, "ZW-MV": true, "ZW-MW": true,
}
================================================
FILE: gossh/gin/validator/currency_codes.go
================================================
package validator
// iso4217 is a lookup set of three-letter alphabetic currency codes
// (ISO 4217, going by the identifier and the file name). Every listed key
// maps to true, so indexing the map with a code that is not listed yields
// the zero value false. Keys are upper-case; lookups are case-sensitive.
var iso4217 = map[string]bool{
"AFN": true, "EUR": true, "ALL": true, "DZD": true, "USD": true,
"AOA": true, "XCD": true, "ARS": true, "AMD": true, "AWG": true,
"AUD": true, "AZN": true, "BSD": true, "BHD": true, "BDT": true,
"BBD": true, "BYN": true, "BZD": true, "XOF": true, "BMD": true,
"INR": true, "BTN": true, "BOB": true, "BOV": true, "BAM": true,
"BWP": true, "NOK": true, "BRL": true, "BND": true, "BGN": true,
"BIF": true, "CVE": true, "KHR": true, "XAF": true, "CAD": true,
"KYD": true, "CLP": true, "CLF": true, "CNY": true, "COP": true,
"COU": true, "KMF": true, "CDF": true, "NZD": true, "CRC": true,
"HRK": true, "CUP": true, "CUC": true, "ANG": true, "CZK": true,
"DKK": true, "DJF": true, "DOP": true, "EGP": true, "SVC": true,
"ERN": true, "SZL": true, "ETB": true, "FKP": true, "FJD": true,
"XPF": true, "GMD": true, "GEL": true, "GHS": true, "GIP": true,
"GTQ": true, "GBP": true, "GNF": true, "GYD": true, "HTG": true,
"HNL": true, "HKD": true, "HUF": true, "ISK": true, "IDR": true,
"XDR": true, "IRR": true, "IQD": true, "ILS": true, "JMD": true,
"JPY": true, "JOD": true, "KZT": true, "KES": true, "KPW": true,
"KRW": true, "KWD": true, "KGS": true, "LAK": true, "LBP": true,
"LSL": true, "ZAR": true, "LRD": true, "LYD": true, "CHF": true,
"MOP": true, "MKD": true, "MGA": true, "MWK": true, "MYR": true,
"MVR": true, "MRU": true, "MUR": true, "XUA": true, "MXN": true,
"MXV": true, "MDL": true, "MNT": true, "MAD": true, "MZN": true,
"MMK": true, "NAD": true, "NPR": true, "NIO": true, "NGN": true,
"OMR": true, "PKR": true, "PAB": true, "PGK": true, "PYG": true,
"PEN": true, "PHP": true, "PLN": true, "QAR": true, "RON": true,
"RUB": true, "RWF": true, "SHP": true, "WST": true, "STN": true,
"SAR": true, "RSD": true, "SCR": true, "SLL": true, "SGD": true,
"XSU": true, "SBD": true, "SOS": true, "SSP": true, "LKR": true,
"SDG": true, "SRD": true, "SEK": true, "CHE": true, "CHW": true,
"SYP": true, "TWD": true, "TJS": true, "TZS": true, "THB": true,
"TOP": true, "TTD": true, "TND": true, "TRY": true, "TMT": true,
"UGX": true, "UAH": true, "AED": true, "USN": true, "UYU": true,
"UYI": true, "UYW": true, "UZS": true, "VUV": true, "VES": true,
"VND": true, "YER": true, "ZMW": true, "ZWL": true, "XBA": true,
"XBB": true, "XBC": true, "XBD": true, "XTS": true, "XXX": true,
"XAU": true, "XPD": true, "XPT": true, "XAG": true,
}
// iso4217_numeric is a lookup set of numeric currency codes (presumably the
// ISO 4217 numeric counterparts of the alphabetic codes in iso4217 — verify
// against the standard before relying on a one-to-one correspondence).
// Every listed key maps to true, so an unlisted number yields false.
// Keys are plain ints, so a code conventionally written with leading zeros
// (for example 008) is stored as its integer value (8).
var iso4217_numeric = map[int]bool{
8: true, 12: true, 32: true, 36: true, 44: true,
48: true, 50: true, 51: true, 52: true, 60: true,
64: true, 68: true, 72: true, 84: true, 90: true,
96: true, 104: true, 108: true, 116: true, 124: true,
132: true, 136: true, 144: true, 152: true, 156: true,
170: true, 174: true, 188: true, 191: true, 192: true,
203: true, 208: true, 214: true, 222: true, 230: true,
232: true, 238: true, 242: true, 262: true, 270: true,
292: true, 320: true, 324: true, 328: true, 332: true,
340: true, 344: true, 348: true, 352: true, 356: true,
360: true, 364: true, 368: true, 376: true, 388: true,
392: true, 398: true, 400: true, 404: true, 408: true,
410: true, 414: true, 417: true, 418: true, 422: true,
426: true, 430: true, 434: true, 446: true, 454: true,
458: true, 462: true, 480: true, 484: true, 496: true,
498: true, 504: true, 512: true, 516: true, 524: true,
532: true, 533: true, 548: true, 554: true, 558: true,
566: true, 578: true, 586: true, 590: true, 598: true,
600: true, 604: true, 608: true, 634: true, 643: true,
646: true, 654: true, 682: true, 690: true, 694: true,
702: true, 704: true, 706: true, 710: true, 728: true,
748: true, 752: true, 756: true, 760: true, 764: true,
776: true, 780: true, 784: true, 788: true, 800: true,
807: true, 818: true, 826: true, 834: true, 840: true,
858: true, 860: true, 882: true, 886: true, 901: true,
927: true, 928: true, 929: true, 930: true, 931: true,
932: true, 933: true, 934: true, 936: true, 938: true,
940: true, 941: true, 943: true, 944: true, 946: true,
947: true, 948: true, 949: true, 950: true, 951: true,
952: true, 953: true, 955: true, 956: true, 957: true,
958: true, 959: true, 960: true, 961: true, 962: true,
963: true, 964: true, 965: true, 967: true, 968: true,
969: true, 970: true, 971: true, 972: true, 973: true,
975: true, 976: true, 977: true, 978: true, 979: true,
980: true, 981: true, 984: true, 985: true, 986: true,
990: true, 994: true, 997: true, 999: true,
}
================================================
FILE: gossh/gin/validator/errors.go
================================================
package validator
import (
"bytes"
"fmt"
"reflect"
"strings"
)
// fieldErrMsg is the fmt template for a single field failure. Its three %s
// verbs are filled, in order, with the field namespace, the field name and
// the tag that failed.
const fieldErrMsg = "Key: '%s' Error:Field validation for '%s' failed on the '%s' tag"
// ValidationErrorsTranslations is the translation return type: a map from
// string keys to translated message strings.
// NOTE(review): nothing in this file produces a value of this type, and the
// meaning of the keys is not visible here (presumably field namespaces) —
// confirm against the code that populates it.
type ValidationErrorsTranslations map[string]string
// InvalidValidationError reports that an unusable argument was handed to
// `Struct`, `StructExcept`, `StructPartial` or `Field`.
type InvalidValidationError struct {
	// Type is the reflected type of the rejected argument. It may be nil,
	// in which case Error omits the type name.
	Type reflect.Type
}

// Error implements the error interface. The message is "validator: (nil)"
// when no type was recorded and "validator: (nil <type>)" otherwise.
func (e *InvalidValidationError) Error() string {
	if typ := e.Type; typ != nil {
		return "validator: (nil " + typ.String() + ")"
	}
	return "validator: (nil)"
}
// ValidationErrors is a slice of FieldError values collected during
// validation, for building custom error messages afterwards.
type ValidationErrors []FieldError

// Error makes ValidationErrors satisfy the error interface. It is meant for
// development and debugging rather than as a production-facing message:
// each element's Error() text is written on its own line, and leading and
// trailing whitespace of the combined text is trimmed. Everything needed to
// build an application-specific message is available on the individual
// FieldError elements.
func (ve ValidationErrors) Error() string {
	var joined bytes.Buffer
	for _, fieldErr := range ve {
		joined.WriteString(fieldErr.Error())
		joined.WriteByte('\n')
	}
	return strings.TrimSpace(joined.String())
}
// FieldError exposes the details of a single failed field validation:
// which tag failed, where the field lives, and the offending value.
type FieldError interface {
// Tag returns the validation tag that failed. If the
// validation was an alias, this will return the
// alias name and not the underlying tag that failed.
//
// eg. alias "iscolor": "hexcolor|rgb|rgba|hsl|hsla"
// will return "iscolor"
Tag() string
// ActualTag returns the validation tag that failed; for an
// alias, the actual tag within the alias will be returned.
// If an 'or' validation fails the entire 'or' expression will be returned.
//
// eg. alias "iscolor": "hexcolor|rgb|rgba|hsl|hsla"
// will return "hexcolor|rgb|rgba|hsl|hsla"
ActualTag() string
// Namespace returns the namespace for the field error, with the tag
// name taking precedence over the field's actual name.
//
// eg. JSON name "User.fname"
//
// See StructNamespace() for a version that returns actual names.
//
// NOTE: this field can be blank when validating a single primitive field
// using validate.Field(...) as there is no way to extract its name
Namespace() string
// StructNamespace returns the namespace for the field error, with the field's
// actual name.
//
// eg. "User.FirstName" see Namespace for comparison
//
// NOTE: this field can be blank when validating a single primitive field
// using validate.Field(...) as there is no way to extract its name
StructNamespace() string
// Field returns the field's name with the tag name taking precedence over the
// field's actual name.
//
// eg. JSON name "fname"
// see StructField for comparison
Field() string
// StructField returns the field's actual name from the struct, when able to determine.
//
// eg. "FirstName"
// see Field for comparison
StructField() string
// Value returns the actual field's value in case needed for creating the error
// message
Value() any
// Param returns the param value, in string form for comparison; this will also
// help with generating an error message
Param() string
// Kind returns the Field's reflect Kind
//
// eg. time.Time's kind is a struct
Kind() reflect.Kind
// Type returns the Field's reflect Type
//
// eg. time.Time's type is time.Time
Type() reflect.Type
// Translate returns the FieldError's translated error.
// The method takes no translator argument.
//
// NOTE: the documented contract is that, if no registered translator can
// be found, it returns the same as calling fe.Error().
// NOTE(review): confirm that implementations honor this fallback.
Translate() string
// Error returns the FieldError's message
Error() string
}
// Compile-time check that *fieldError satisfies the error interface.
var _ error = new(fieldError)

// fieldError contains a single field's validation error along
// with other properties that may be needed for error message creation.
// It complies with the FieldError interface.
type fieldError struct {
	v         *Validate
	tag       string // tag (or alias name) that failed
	actualTag string // underlying tag that failed
	ns        string // namespace built from tag-derived names
	structNs  string // namespace built from actual struct field names
	// fieldLen and structfieldLen are the byte lengths of the trailing
	// field-name portion of ns and structNs respectively; Field and
	// StructField slice that many bytes off the end of the namespace.
	fieldLen       uint8
	structfieldLen uint8
	value          any
	param          string
	kind           reflect.Kind
	typ            reflect.Type
}

// Translate returns the translated error message. This implementation
// performs no translator lookup, so it follows the FieldError contract for
// the "no registered translator" case and returns the same text as Error().
func (fe *fieldError) Translate() string {
	return fe.Error()
}

// Tag returns the validation tag that failed.
func (fe *fieldError) Tag() string {
	return fe.tag
}

// ActualTag returns the validation tag that failed; for an
// alias, the actual tag within the alias is returned.
func (fe *fieldError) ActualTag() string {
	return fe.actualTag
}

// Namespace returns the namespace for the field error, with the tag
// name taking precedence over the field's actual name.
func (fe *fieldError) Namespace() string {
	return fe.ns
}

// StructNamespace returns the namespace for the field error, with the field's
// actual name.
func (fe *fieldError) StructNamespace() string {
	return fe.structNs
}

// Field returns the field's name with the tag name taking precedence over the
// field's actual name. It is the last fieldLen bytes of the namespace.
func (fe *fieldError) Field() string {
	return fe.ns[len(fe.ns)-int(fe.fieldLen):]
}

// StructField returns the field's actual name from the struct, when able to
// determine. It is the last structfieldLen bytes of the struct namespace.
func (fe *fieldError) StructField() string {
	return fe.structNs[len(fe.structNs)-int(fe.structfieldLen):]
}

// Value returns the actual field's value in case needed for creating the error
// message.
func (fe *fieldError) Value() any {
	return fe.value
}

// Param returns the param value, in string form for comparison; this will
// also help with generating an error message.
func (fe *fieldError) Param() string {
	return fe.param
}

// Kind returns the Field's reflect Kind.
func (fe *fieldError) Kind() reflect.Kind {
	return fe.kind
}

// Type returns the Field's reflect Type.
func (fe *fieldError) Type() reflect.Type {
	return fe.typ
}

// Error returns the fieldError's error message, formatted from fieldErrMsg
// with the namespace, the field name and the failed tag.
func (fe *fieldError) Error() string {
	return fmt.Sprintf(fieldErrMsg, fe.ns, fe.Field(), fe.tag)
}
================================================
FILE: gossh/gin/validator/field_level.go
================================================
package validator
import "reflect"
// FieldLevel contains all the information and helper functions
// needed to validate a single field.
type FieldLevel interface {
// Top returns the top level struct, if any.
Top() reflect.Value
// Parent returns the current field's parent struct, if any, or
// the comparison value if called via 'VarWithValue'.
Parent() reflect.Value
// Field returns the current field being validated.
Field() reflect.Value
// FieldName returns the field's name, with the tag
// name taking precedence over the field's actual name.
FieldName() string
// StructFieldName returns the struct field's name.
StructFieldName() string
// Param returns the param for validation against the current field.
Param() string
// GetTag returns the current validation's tag name.
GetTag() string
// ExtractType gets the actual underlying type of the field value.
// It will dive into pointers and custom types and return the
// underlying value and its kind.
ExtractType(field reflect.Value) (value reflect.Value, kind reflect.Kind, nullable bool)
// GetStructFieldOK traverses the parent struct to retrieve a specific field denoted by the namespace
// provided in the param, and returns the field, the field kind and whether it was successful in
// retrieving the field at all.
//
// NOTE: when not successful ok will be false; this can happen when a nested struct is nil and so the field
// could not be retrieved because it didn't exist.
//
// Deprecated: Use GetStructFieldOK2() instead which also return if the value is nullable.
GetStructFieldOK() (reflect.Value, reflect.Kind, bool)
// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start
// looking from and the namespace to follow, allowing more extensibility for validators.
//
// Deprecated: Use GetStructFieldOKAdvanced2() instead which also return if the value is nullable.
GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool)
// GetStructFieldOK2 traverses the parent struct to retrieve a specific field denoted by the namespace
// provided in the param, and returns the field, the field kind, whether it is a nullable type and whether
// it was successful in retrieving the field at all.
//
// NOTE: when not successful ok will be false; this can happen when a nested struct is nil and so the field
// could not be retrieved because it didn't exist.
GetStructFieldOK2() (reflect.Value, reflect.Kind, bool, bool)
// GetStructFieldOKAdvanced2 is the same as GetStructFieldOK2 except that it accepts the parent struct to start
// looking from and the namespace to follow, allowing more extensibility for validators.
GetStructFieldOKAdvanced2(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool, bool)
}
// Compile-time assertion that *validate implements FieldLevel. A typed nil
// pointer is used so the check performs no allocation at package init.
var _ FieldLevel = (*validate)(nil)
// Field returns the current field being validated (the flField value
// held on the validate instance).
func (v *validate) Field() reflect.Value {
return v.flField
}
// FieldName returns the field's name, with the tag
// name taking precedence over the field's actual name. It is read
// from the current cached field's altName.
func (v *validate) FieldName() string {
return v.cf.altName
}
// GetTag returns the tag name of the validation currently being run,
// read from the current cached tag.
func (v *validate) GetTag() string {
return v.ct.tag
}
// StructFieldName returns the struct field's name, read from the
// current cached field.
func (v *validate) StructFieldName() string {
return v.cf.name
}
// Param returns the param for validation against the current field,
// read from the current cached tag.
func (v *validate) Param() string {
return v.ct.param
}
// GetStructFieldOK looks up the field denoted by the current tag's param,
// starting from the current parent struct, and returns the field, its kind
// and whether it was found. The nullable flag reported by the internal
// lookup is discarded.
//
// Deprecated: Use GetStructFieldOK2() instead which also return if the value is nullable.
func (v *validate) GetStructFieldOK() (reflect.Value, reflect.Kind, bool) {
current, kind, _, found := v.getStructFieldOKInternal(v.slflParent, v.ct.param)
return current, kind, found
}
// GetStructFieldOKAdvanced is the same as GetStructFieldOK except that it accepts the parent struct to start
// looking from and the namespace to follow, allowing more extensibility for validators.
// It delegates to GetStructFieldOKAdvanced2 and drops the nullable flag.
//
// Deprecated: Use GetStructFieldOKAdvanced2() instead which also return if the value is nullable.
func (v *validate) GetStructFieldOKAdvanced(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) {
current, kind, _, found := v.GetStructFieldOKAdvanced2(val, namespace)
return current, kind, found
}
// GetStructFieldOK2 looks up the field denoted by the current tag's param,
// starting from the current parent struct, and returns the field, its kind,
// whether it is nullable and whether it was found.
func (v *validate) GetStructFieldOK2() (reflect.Value, reflect.Kind, bool, bool) {
return v.getStructFieldOKInternal(v.slflParent, v.ct.param)
}
// GetStructFieldOKAdvanced2 is the same as GetStructFieldOK2 except that it accepts the parent struct to start
// looking from and the namespace to follow, allowing more extensibility for validators.
func (v *validate) GetStructFieldOKAdvanced2(val reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool, bool) {
return v.getStructFieldOKInternal(val, namespace)
}
================================================
FILE: gossh/gin/validator/options.go
================================================
package validator
// Option represents a configuration option to be applied to the validator during initialization.
type Option func(*Validate)
// WithRequiredStructEnabled enables the required tag on non-pointer structs so that it is applied instead of ignored.
//
// This was made opt-in behaviour in order to maintain backward compatibility with the behaviour prior
// to being able to apply struct level validations on struct fields directly.
//
// It is recommended that you enable this, as it will be the default behaviour in v11+.
func WithRequiredStructEnabled() Option {
return func(v *Validate) {
// The option only flips the flag on the validator being configured.
v.requiredStructEnabled = true
}
}
================================================
FILE: gossh/gin/validator/postcode_regexes.go
================================================
package validator
import "regexp"
// postCodePatternDict maps a two-letter country code to the regular
// expression that a postcode of that country must match in full.
//
// Every pattern is anchored at both ends. Patterns made of several
// top-level alternatives ("GB", "CR", "LI") wrap the alternation in a
// non-capturing group so that ^ and $ apply to every alternative; without
// the group `^a|b$` only anchors the first alternative at the start and the
// last one at the end, which lets arbitrary prefixes or suffixes through.
var postCodePatternDict = map[string]string{
	"GB": `^(?:GIR[ ]?0AA|((AB|AL|B|BA|BB|BD|BH|BL|BN|BR|BS|BT|CA|CB|CF|CH|CM|CO|CR|CT|CV|CW|DA|DD|DE|DG|DH|DL|DN|DT|DY|E|EC|EH|EN|EX|FK|FY|G|GL|GY|GU|HA|HD|HG|HP|HR|HS|HU|HX|IG|IM|IP|IV|JE|KA|KT|KW|KY|L|LA|LD|LE|LL|LN|LS|LU|M|ME|MK|ML|N|NE|NG|NN|NP|NR|NW|OL|OX|PA|PE|PH|PL|PO|PR|RG|RH|RM|S|SA|SE|SG|SK|SL|SM|SN|SO|SP|SR|SS|ST|SW|SY|TA|TD|TF|TN|TQ|TR|TS|TW|UB|W|WA|WC|WD|WF|WN|WR|WS|WV|YO|ZE)(\d[\dA-Z]?[ ]?\d[ABD-HJLN-UW-Z]{2}))|BFPO[ ]?\d{1,4})$`,
	"JE": `^JE\d[\dA-Z]?[ ]?\d[ABD-HJLN-UW-Z]{2}$`,
	"GG": `^GY\d[\dA-Z]?[ ]?\d[ABD-HJLN-UW-Z]{2}$`,
	"IM": `^IM\d[\dA-Z]?[ ]?\d[ABD-HJLN-UW-Z]{2}$`,
	"US": `^\d{5}([ \-]\d{4})?$`,
	"CA": `^[ABCEGHJKLMNPRSTVXY]\d[ABCEGHJ-NPRSTV-Z][ ]?\d[ABCEGHJ-NPRSTV-Z]\d$`,
	"DE": `^\d{5}$`,
	"JP": `^\d{3}-\d{4}$`,
	"FR": `^\d{2}[ ]?\d{3}$`,
	"AU": `^\d{4}$`,
	"IT": `^\d{5}$`,
	"CH": `^\d{4}$`,
	"AT": `^\d{4}$`,
	"ES": `^\d{5}$`,
	"NL": `^\d{4}[ ]?[A-Z]{2}$`,
	"BE": `^\d{4}$`,
	"DK": `^\d{4}$`,
	"SE": `^\d{3}[ ]?\d{2}$`,
	"NO": `^\d{4}$`,
	"BR": `^\d{5}[\-]?\d{3}$`,
	"PT": `^\d{4}([\-]\d{3})?$`,
	"FI": `^\d{5}$`,
	"AX": `^22\d{3}$`,
	"KR": `^\d{3}[\-]\d{3}$`,
	"CN": `^\d{6}$`,
	"TW": `^\d{3}(\d{2})?$`,
	"SG": `^\d{6}$`,
	"DZ": `^\d{5}$`,
	"AD": `^AD\d{3}$`,
	"AR": `^([A-HJ-NP-Z])?\d{4}([A-Z]{3})?$`,
	"AM": `^(37)?\d{4}$`,
	"AZ": `^\d{4}$`,
	"BH": `^((1[0-2]|[2-9])\d{2})?$`,
	"BD": `^\d{4}$`,
	"BB": `^(BB\d{5})?$`,
	"BY": `^\d{6}$`,
	"BM": `^[A-Z]{2}[ ]?[A-Z0-9]{2}$`,
	"BA": `^\d{5}$`,
	"IO": `^BBND 1ZZ$`,
	"BN": `^[A-Z]{2}[ ]?\d{4}$`,
	"BG": `^\d{4}$`,
	"KH": `^\d{5}$`,
	"CV": `^\d{4}$`,
	"CL": `^\d{7}$`,
	"CR": `^(?:\d{4,5}|\d{3}-\d{4})$`,
	"HR": `^\d{5}$`,
	"CY": `^\d{4}$`,
	"CZ": `^\d{3}[ ]?\d{2}$`,
	"DO": `^\d{5}$`,
	"EC": `^([A-Z]\d{4}[A-Z]|(?:[A-Z]{2})?\d{6})?$`,
	"EG": `^\d{5}$`,
	"EE": `^\d{5}$`,
	"FO": `^\d{3}$`,
	"GE": `^\d{4}$`,
	"GR": `^\d{3}[ ]?\d{2}$`,
	"GL": `^39\d{2}$`,
	"GT": `^\d{5}$`,
	"HT": `^\d{4}$`,
	"HN": `^(?:\d{5})?$`,
	"HU": `^\d{4}$`,
	"IS": `^\d{3}$`,
	"IN": `^\d{6}$`,
	"ID": `^\d{5}$`,
	"IL": `^\d{5}$`,
	"JO": `^\d{5}$`,
	"KZ": `^\d{6}$`,
	"KE": `^\d{5}$`,
	"KW": `^\d{5}$`,
	"LA": `^\d{5}$`,
	"LV": `^\d{4}$`,
	"LB": `^(\d{4}([ ]?\d{4})?)?$`,
	"LI": `^(?:(948[5-9])|(949[0-7]))$`,
	"LT": `^\d{5}$`,
	"LU": `^\d{4}$`,
	"MK": `^\d{4}$`,
	"MY": `^\d{5}$`,
	"MV": `^\d{5}$`,
	"MT": `^[A-Z]{3}[ ]?\d{2,4}$`,
	"MU": `^(\d{3}[A-Z]{2}\d{3})?$`,
	"MX": `^\d{5}$`,
	"MD": `^\d{4}$`,
	"MC": `^980\d{2}$`,
	"MA": `^\d{5}$`,
	"NP": `^\d{5}$`,
	"NZ": `^\d{4}$`,
	"NI": `^((\d{4}-)?\d{3}-\d{3}(-\d{1})?)?$`,
	"NG": `^(\d{6})?$`,
	"OM": `^(PC )?\d{3}$`,
	"PK": `^\d{5}$`,
	"PY": `^\d{4}$`,
	"PH": `^\d{4}$`,
	"PL": `^\d{2}-\d{3}$`,
	"PR": `^00[679]\d{2}([ \-]\d{4})?$`,
	"RO": `^\d{6}$`,
	"RU": `^\d{6}$`,
	"SM": `^4789\d$`,
	"SA": `^\d{5}$`,
	"SN": `^\d{5}$`,
	"SK": `^\d{3}[ ]?\d{2}$`,
	"SI": `^\d{4}$`,
	"ZA": `^\d{4}$`,
	"LK": `^\d{5}$`,
	"TJ": `^\d{6}$`,
	"TH": `^\d{5}$`,
	"TN": `^\d{4}$`,
	"TR": `^\d{5}$`,
	"TM": `^\d{6}$`,
	"UA": `^\d{5}$`,
	"UY": `^\d{5}$`,
	"UZ": `^\d{6}$`,
	"VA": `^00120$`,
	"VE": `^\d{4}$`,
	"ZM": `^\d{5}$`,
	"AS": `^96799$`,
	"CC": `^6799$`,
	"CK": `^\d{4}$`,
	"RS": `^\d{6}$`,
	"ME": `^8\d{4}$`,
	"CS": `^\d{5}$`,
	"YU": `^\d{5}$`,
	"CX": `^6798$`,
	"ET": `^\d{4}$`,
	"FK": `^FIQQ 1ZZ$`,
	"NF": `^2899$`,
	"FM": `^(9694[1-4])([ \-]\d{4})?$`,
	"GF": `^9[78]3\d{2}$`,
	"GN": `^\d{3}$`,
	"GP": `^9[78][01]\d{2}$`,
	"GS": `^SIQQ 1ZZ$`,
	"GU": `^969[123]\d([ \-]\d{4})?$`,
	"GW": `^\d{4}$`,
	"HM": `^\d{4}$`,
	"IQ": `^\d{5}$`,
	"KG": `^\d{6}$`,
	"LR": `^\d{4}$`,
	"LS": `^\d{3}$`,
	"MG": `^\d{3}$`,
	"MH": `^969[67]\d([ \-]\d{4})?$`,
	"MN": `^\d{6}$`,
	"MP": `^9695[012]([ \-]\d{4})?$`,
	"MQ": `^9[78]2\d{2}$`,
	"NC": `^988\d{2}$`,
	"NE": `^\d{4}$`,
	"VI": `^008(([0-4]\d)|(5[01]))([ \-]\d{4})?$`,
	"VN": `^[0-9]{1,6}$`,
	"PF": `^987\d{2}$`,
	"PG": `^\d{3}$`,
	"PM": `^9[78]5\d{2}$`,
	"PN": `^PCRN 1ZZ$`,
	"PW": `^96940$`,
	"RE": `^9[78]4\d{2}$`,
	"SH": `^(ASCN|STHL) 1ZZ$`,
	"SJ": `^\d{4}$`,
	"SO": `^\d{5}$`,
	"SZ": `^[HLMS]\d{3}$`,
	"TC": `^TKCA 1ZZ$`,
	"WF": `^986\d{2}$`,
	"XK": `^\d{5}$`,
	"YT": `^976\d{2}$`,
}

// postCodeRegexDict holds the compiled form of every pattern in
// postCodePatternDict, keyed by the same country code. It is filled once by
// init.
var postCodeRegexDict = make(map[string]*regexp.Regexp, len(postCodePatternDict))

// init compiles every postcode pattern up front; an invalid pattern panics
// at program start (regexp.MustCompile) instead of failing at validation
// time.
func init() {
	for countryCode, pattern := range postCodePatternDict {
		postCodeRegexDict[countryCode] = regexp.MustCompile(pattern)
	}
}
================================================
FILE: gossh/gin/validator/regexes.go
================================================
package validator
import "regexp"
// Pattern sources for the built-in validators. Each constant is compiled
// exactly once into the package-level regexp of the matching name below.
const (
	alphaRegexString = "^[a-zA-Z]+$"
	alphaNumericRegexString = "^[a-zA-Z0-9]+$"
	alphaUnicodeRegexString = "^[\\p{L}]+$"
	alphaUnicodeNumericRegexString = "^[\\p{L}\\p{N}]+$"
	numericRegexString = "^[-+]?[0-9]+(?:\\.[0-9]+)?$"
	numberRegexString = "^[0-9]+$"
	hexadecimalRegexString = "^(0[xX])?[0-9a-fA-F]+$"
	hexColorRegexString = "^#(?:[0-9a-fA-F]{3}|[0-9a-fA-F]{4}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})$"
	rgbRegexString = "^rgb\\(\\s*(?:(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])|(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%)\\s*\\)$"
	// The alpha component is "0." followed by decimal digits, or a bare 0 or 1.
	// The dot is escaped so it only matches a literal '.', and the fraction may
	// contain the digit 0 (e.g. 0.05).
	rgbaRegexString = "^rgba\\(\\s*(?:(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])|(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%\\s*,\\s*(?:0|[1-9]\\d?|1\\d\\d?|2[0-4]\\d|25[0-5])%)\\s*,\\s*(?:(?:0\\.[0-9]*)|[01])\\s*\\)$"
	hslRegexString = "^hsl\\(\\s*(?:0|[1-9]\\d?|[12]\\d\\d|3[0-5]\\d|360)\\s*,\\s*(?:(?:0|[1-9]\\d?|100)%)\\s*,\\s*(?:(?:0|[1-9]\\d?|100)%)\\s*\\)$"
	// Same alpha component rules as rgbaRegexString.
	hslaRegexString = "^hsla\\(\\s*(?:0|[1-9]\\d?|[12]\\d\\d|3[0-5]\\d|360)\\s*,\\s*(?:(?:0|[1-9]\\d?|100)%)\\s*,\\s*(?:(?:0|[1-9]\\d?|100)%)\\s*,\\s*(?:(?:0\\.[0-9]*)|[01])\\s*\\)$"
	emailRegexString = "^(?:(?:(?:(?:[a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(?:\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|(?:(?:\\x22)(?:(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(?:\\x20|\\x09)+)?(?:(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(\\x20|\\x09)+)?(?:\\x22))))@(?:(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$"
	e164RegexString = "^\\+[1-9]?[0-9]{7,14}$"
	base64RegexString = "^(?:[A-Za-z0-9+\\/]{4})*(?:[A-Za-z0-9+\\/]{2}==|[A-Za-z0-9+\\/]{3}=|[A-Za-z0-9+\\/]{4})$"
	base64URLRegexString = "^(?:[A-Za-z0-9-_]{4})*(?:[A-Za-z0-9-_]{2}==|[A-Za-z0-9-_]{3}=|[A-Za-z0-9-_]{4})$"
	base64RawURLRegexString = "^(?:[A-Za-z0-9-_]{4})*(?:[A-Za-z0-9-_]{2,4})$"
	iSBN10RegexString = "^(?:[0-9]{9}X|[0-9]{10})$"
	iSBN13RegexString = "^(?:(?:97(?:8|9))[0-9]{10})$"
	iSSNRegexString = "^(?:[0-9]{4}-[0-9]{3}[0-9X])$"
	uUID3RegexString = "^[0-9a-f]{8}-[0-9a-f]{4}-3[0-9a-f]{3}-[0-9a-f]{4}-[0-9a-f]{12}$"
	uUID4RegexString = "^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
	uUID5RegexString = "^[0-9a-f]{8}-[0-9a-f]{4}-5[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
	uUIDRegexString = "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
	uUID3RFC4122RegexString = "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-3[0-9a-fA-F]{3}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
	uUID4RFC4122RegexString = "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
	uUID5RFC4122RegexString = "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-5[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
	uUIDRFC4122RegexString = "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
	uLIDRegexString = "^[A-HJKMNP-TV-Z0-9]{26}$"
	md4RegexString = "^[0-9a-f]{32}$"
	md5RegexString = "^[0-9a-f]{32}$"
	sha256RegexString = "^[0-9a-f]{64}$"
	sha384RegexString = "^[0-9a-f]{96}$"
	sha512RegexString = "^[0-9a-f]{128}$"
	ripemd128RegexString = "^[0-9a-f]{32}$"
	ripemd160RegexString = "^[0-9a-f]{40}$"
	tiger128RegexString = "^[0-9a-f]{32}$"
	tiger160RegexString = "^[0-9a-f]{40}$"
	tiger192RegexString = "^[0-9a-f]{48}$"
	aSCIIRegexString = "^[\x00-\x7F]*$"
	printableASCIIRegexString = "^[\x20-\x7E]*$"
	multibyteRegexString = "[^\x00-\x7F]"
	dataURIRegexString = `^data:((?:\w+\/(?:([^;]|;[^;]).)+)?)`
	latitudeRegexString = "^[-+]?([1-8]?\\d(\\.\\d+)?|90(\\.0+)?)$"
	longitudeRegexString = "^[-+]?(180(\\.0+)?|((1[0-7]\\d)|([1-9]?\\d))(\\.\\d+)?)$"
	sSNRegexString = `^[0-9]{3}[ -]?(0[1-9]|[1-9][0-9])[ -]?([1-9][0-9]{3}|[0-9][1-9][0-9]{2}|[0-9]{2}[1-9][0-9]|[0-9]{3}[1-9])$`
	hostnameRegexStringRFC952 = `^[a-zA-Z]([a-zA-Z0-9\-]+[\.]?)*[a-zA-Z0-9]$` // https://tools.ietf.org/html/rfc952
	hostnameRegexStringRFC1123 = `^([a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62})*?$` // accepts hostname starting with a digit https://tools.ietf.org/html/rfc1123
	fqdnRegexStringRFC1123 = `^([a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62})(\.[a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62})*?(\.[a-zA-Z]{1}[a-zA-Z0-9]{0,62})\.?$` // same as hostnameRegexStringRFC1123 but must contain a non numerical TLD (possibly ending with '.')
	btcAddressRegexString = `^[13][a-km-zA-HJ-NP-Z1-9]{25,34}$` // bitcoin address
	btcAddressUpperRegexStringBech32 = `^BC1[02-9AC-HJ-NP-Z]{7,76}$` // bitcoin bech32 address https://en.bitcoin.it/wiki/Bech32
	btcAddressLowerRegexStringBech32 = `^bc1[02-9ac-hj-np-z]{7,76}$` // bitcoin bech32 address https://en.bitcoin.it/wiki/Bech32
	ethAddressRegexString = `^0x[0-9a-fA-F]{40}$`
	ethAddressUpperRegexString = `^0x[0-9A-F]{40}$`
	ethAddressLowerRegexString = `^0x[0-9a-f]{40}$`
	uRLEncodedRegexString = `^(?:[^%]|%[0-9A-Fa-f]{2})*$`
	// Matches an HTML character reference: a numeric reference introduced by
	// "&#" (optionally hexadecimal with 'x') or one of the named entities
	// &gt, &lt, &quot, &amp. The entity prefixes are required; without them any
	// two hex digits or a bare '<' would be treated as HTML-encoded.
	hTMLEncodedRegexString = `&#[x]?([0-9a-fA-F]{2})|(&gt)|(&lt)|(&quot)|(&amp)+[;]?`
	hTMLRegexString = `<[/]?([a-zA-Z]+).*?>`
	jWTRegexString = "^[A-Za-z0-9-_]+\\.[A-Za-z0-9-_]+\\.[A-Za-z0-9-_]*$"
	splitParamsRegexString = `'[^']*'|\S+`
	bicRegexString = `^[A-Za-z]{6}[A-Za-z0-9]{2}([A-Za-z0-9]{3})?$`
	semverRegexString = `^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$` // numbered capture groups https://semver.org/
	dnsRegexStringRFC1035Label = "^[a-z]([-a-z0-9]*[a-z0-9]){0,62}$"
	cveRegexString = `^CVE-(1999|2\d{3})-(0[^0]\d{2}|0\d[^0]\d{1}|0\d{2}[^0]|[1-9]{1}\d{3,})$` // CVE Format Id https://cve.mitre.org/cve/identifiers/syntaxchange.html
	mongodbRegexString = "^[a-f\\d]{24}$"
	cronRegexString = `(@(annually|yearly|monthly|weekly|daily|hourly|reboot))|(@every (\d+(ns|us|µs|ms|s|m|h))+)|((((\d+,)+\d+|(\d+(\/|-)\d+)|\d+|\*) ?){5,7})`
	spicedbIDRegexString = `^(([a-zA-Z0-9/_|\-=+]{1,})|\*)$`
	spicedbPermissionRegexString = "^([a-z][a-z0-9_]{1,62}[a-z0-9])?$"
	spicedbTypeRegexString = "^([a-z][a-z0-9_]{1,61}[a-z0-9]/)?[a-z][a-z0-9_]{1,62}[a-z0-9]$"
)

// Compiled forms of the patterns above. They are built once at package
// initialisation; regexp.MustCompile panics on an invalid pattern, so a
// broken pattern is caught at program start rather than at validation time.
var (
	alphaRegex = regexp.MustCompile(alphaRegexString)
	alphaNumericRegex = regexp.MustCompile(alphaNumericRegexString)
	alphaUnicodeRegex = regexp.MustCompile(alphaUnicodeRegexString)
	alphaUnicodeNumericRegex = regexp.MustCompile(alphaUnicodeNumericRegexString)
	numericRegex = regexp.MustCompile(numericRegexString)
	numberRegex = regexp.MustCompile(numberRegexString)
	hexadecimalRegex = regexp.MustCompile(hexadecimalRegexString)
	hexColorRegex = regexp.MustCompile(hexColorRegexString)
	rgbRegex = regexp.MustCompile(rgbRegexString)
	rgbaRegex = regexp.MustCompile(rgbaRegexString)
	hslRegex = regexp.MustCompile(hslRegexString)
	hslaRegex = regexp.MustCompile(hslaRegexString)
	e164Regex = regexp.MustCompile(e164RegexString)
	emailRegex = regexp.MustCompile(emailRegexString)
	base64Regex = regexp.MustCompile(base64RegexString)
	base64URLRegex = regexp.MustCompile(base64URLRegexString)
	base64RawURLRegex = regexp.MustCompile(base64RawURLRegexString)
	iSBN10Regex = regexp.MustCompile(iSBN10RegexString)
	iSBN13Regex = regexp.MustCompile(iSBN13RegexString)
	iSSNRegex = regexp.MustCompile(iSSNRegexString)
	uUID3Regex = regexp.MustCompile(uUID3RegexString)
	uUID4Regex = regexp.MustCompile(uUID4RegexString)
	uUID5Regex = regexp.MustCompile(uUID5RegexString)
	uUIDRegex = regexp.MustCompile(uUIDRegexString)
	uUID3RFC4122Regex = regexp.MustCompile(uUID3RFC4122RegexString)
	uUID4RFC4122Regex = regexp.MustCompile(uUID4RFC4122RegexString)
	uUID5RFC4122Regex = regexp.MustCompile(uUID5RFC4122RegexString)
	uUIDRFC4122Regex = regexp.MustCompile(uUIDRFC4122RegexString)
	uLIDRegex = regexp.MustCompile(uLIDRegexString)
	md4Regex = regexp.MustCompile(md4RegexString)
	md5Regex = regexp.MustCompile(md5RegexString)
	sha256Regex = regexp.MustCompile(sha256RegexString)
	sha384Regex = regexp.MustCompile(sha384RegexString)
	sha512Regex = regexp.MustCompile(sha512RegexString)
	ripemd128Regex = regexp.MustCompile(ripemd128RegexString)
	ripemd160Regex = regexp.MustCompile(ripemd160RegexString)
	tiger128Regex = regexp.MustCompile(tiger128RegexString)
	tiger160Regex = regexp.MustCompile(tiger160RegexString)
	tiger192Regex = regexp.MustCompile(tiger192RegexString)
	aSCIIRegex = regexp.MustCompile(aSCIIRegexString)
	printableASCIIRegex = regexp.MustCompile(printableASCIIRegexString)
	multibyteRegex = regexp.MustCompile(multibyteRegexString)
	dataURIRegex = regexp.MustCompile(dataURIRegexString)
	latitudeRegex = regexp.MustCompile(latitudeRegexString)
	longitudeRegex = regexp.MustCompile(longitudeRegexString)
	sSNRegex = regexp.MustCompile(sSNRegexString)
	hostnameRegexRFC952 = regexp.MustCompile(hostnameRegexStringRFC952)
	hostnameRegexRFC1123 = regexp.MustCompile(hostnameRegexStringRFC1123)
	fqdnRegexRFC1123 = regexp.MustCompile(fqdnRegexStringRFC1123)
	btcAddressRegex = regexp.MustCompile(btcAddressRegexString)
	btcUpperAddressRegexBech32 = regexp.MustCompile(btcAddressUpperRegexStringBech32)
	btcLowerAddressRegexBech32 = regexp.MustCompile(btcAddressLowerRegexStringBech32)
	ethAddressRegex = regexp.MustCompile(ethAddressRegexString)
	uRLEncodedRegex = regexp.MustCompile(uRLEncodedRegexString)
	hTMLEncodedRegex = regexp.MustCompile(hTMLEncodedRegexString)
	hTMLRegex = regexp.MustCompile(hTMLRegexString)
	jWTRegex = regexp.MustCompile(jWTRegexString)
	splitParamsRegex = regexp.MustCompile(splitParamsRegexString)
	bicRegex = regexp.MustCompile(bicRegexString)
	semverRegex = regexp.MustCompile(semverRegexString)
	dnsRegexRFC1035Label = regexp.MustCompile(dnsRegexStringRFC1035Label)
	cveRegex = regexp.MustCompile(cveRegexString)
	mongodbRegex = regexp.MustCompile(mongodbRegexString)
	cronRegex = regexp.MustCompile(cronRegexString)
	spicedbIDRegex = regexp.MustCompile(spicedbIDRegexString)
	spicedbPermissionRegex = regexp.MustCompile(spicedbPermissionRegexString)
	spicedbTypeRegex = regexp.MustCompile(spicedbTypeRegexString)
)
================================================
FILE: gossh/gin/validator/struct_level.go
================================================
package validator
import (
"context"
"reflect"
)
// StructLevelFunc is a struct level validation function; it receives
// everything it needs through the StructLevel argument.
type StructLevelFunc func(sl StructLevel)
// StructLevelFuncCtx is a struct level validation function that, in
// addition to the StructLevel argument, also allows passing contextual
// validation information via context.Context.
type StructLevelFuncCtx func(ctx context.Context, sl StructLevel)
// wrapStructLevelFunc adapts a plain StructLevelFunc to the
// StructLevelFuncCtx signature. The returned function ignores the context
// it is handed and simply invokes fn with the StructLevel.
func wrapStructLevelFunc(fn StructLevelFunc) StructLevelFuncCtx {
	return func(_ context.Context, sl StructLevel) {
		fn(sl)
	}
}
// StructLevel contains all the information and helper functions
// needed to validate a struct.
type StructLevel interface {
// Validator returns the main validation object, in case one wants to call validations internally.
// This is so you don't have to use anonymous functions to get access to the validate
// instance.
Validator() *Validate
// Top returns the top level struct, if any.
Top() reflect.Value
// Parent returns the current field's parent struct, if any.
Parent() reflect.Value
// Current returns the current struct.
Current() reflect.Value
// ExtractType gets the actual underlying type of the field value.
// It will dive into pointers and custom types and return the
// underlying value and its kind.
ExtractType(field reflect.Value) (value reflect.Value, kind reflect.Kind, nullable bool)
// ReportError reports an error just by passing the field and tag information.
//
// NOTES:
//
// fieldName and structFieldName get appended to the existing namespace that
// the validator is on, e.g. pass 'FirstName' or 'Names[0]' depending
// on the nesting.
//
// tag can be an existing validation tag or just something you make up
// and process on the flip side; it's up to you.
ReportError(field any, fieldName, structFieldName string, tag, param string)
// ReportValidationErrors reports errors just by passing ValidationErrors.
//
// NOTES:
//
// relativeNamespace and relativeActualNamespace get appended to the
// existing namespace that the validator is on,
// e.g. pass 'User.FirstName' or 'Users[0].FirstName' depending
// on the nesting. Most of the time they will be blank, unless you validate
// at a level lower than the current field depth.
ReportValidationErrors(relativeNamespace, relativeActualNamespace string, errs ValidationErrors)
}
// Compile-time assertion that *validate implements StructLevel. A typed nil
// pointer is used so the check performs no allocation at package init.
var _ StructLevel = (*validate)(nil)
// Top returns the top level struct.
//
// NOTE: this can be the same as the current struct being validated
// if it is not a nested struct.
//
// This is only meaningful when called within Struct and Field Level
// validation and should not be relied upon for an accurate value otherwise.
func (v *validate) Top() reflect.Value {
return v.top
}
// Parent returns the current struct's parent.
//
// NOTE: this can be the same as the current struct being validated
// if it is not a nested struct.
//
// This is only meaningful when called within Struct and Field Level
// validation and should not be relied upon for an accurate value otherwise.
func (v *validate) Parent() reflect.Value {
return v.slflParent
}
// Current returns the current struct (the slCurrent value held on the
// validate instance).
func (v *validate) Current() reflect.Value {
return v.slCurrent
}
// Validator returns the main validation object, in case one wants to call
// validations internally.
func (v *validate) Validator() *Validate {
return v.v
}
// ExtractType gets the actual underlying type of the field value. It
// delegates to extractTypeInternal, starting with nullable set to false.
func (v *validate) ExtractType(field reflect.Value) (reflect.Value, reflect.Kind, bool) {
return v.extractTypeInternal(field, false)
}
// ReportError records a validation error for field under the given tag and
// param.
//
// fieldName and structFieldName are appended to the namespaces the validator
// is currently on; an empty structFieldName falls back to fieldName. The
// error's value and type are only filled in when the extracted kind of field
// is valid.
func (v *validate) ReportError(field any, fieldName, structFieldName, tag, param string) {
	fv, kind, _ := v.extractTypeInternal(reflect.ValueOf(field), false)

	if structFieldName == "" {
		structFieldName = fieldName
	}

	// str1/str2 are scratch strings on the validator holding the tag-name
	// namespace and the struct-name namespace respectively. The struct
	// namespace is only built separately when it can differ from the first.
	v.str1 = string(append(v.ns, fieldName...))
	v.str2 = v.str1
	if v.v.hasTagNameFunc || fieldName != structFieldName {
		v.str2 = string(append(v.actualNs, structFieldName...))
	}

	fe := &fieldError{
		v:              v.v,
		tag:            tag,
		actualTag:      tag,
		ns:             v.str1,
		structNs:       v.str2,
		fieldLen:       uint8(len(fieldName)),
		structfieldLen: uint8(len(structFieldName)),
		param:          param,
		kind:           kind,
	}

	// An invalid value has neither an interface value nor a type to report.
	if kind != reflect.Invalid {
		fe.value = fv.Interface()
		fe.typ = fv.Type()
	}

	v.errs = append(v.errs, fe)
}
// ReportValidationErrors adds ValidationErrors obtained from running
// validations within the Struct Level validation to the validator's errors.
//
// NOTE: the current namespace followed by the relative one is prepended to
// each error's namespace (and likewise for the struct namespace). Every
// element of errs must be a *fieldError; the type assertion panics otherwise.
func (v *validate) ReportValidationErrors(relativeNamespace, relativeStructNamespace string, errs ValidationErrors) {
	for _, e := range errs {
		fe := e.(*fieldError)

		nsPrefix := append(v.ns, relativeNamespace...)
		fe.ns = string(append(nsPrefix, fe.ns...))

		structNsPrefix := append(v.actualNs, relativeStructNamespace...)
		fe.structNs = string(append(structNsPrefix, fe.structNs...))

		v.errs = append(v.errs, fe)
	}
}
================================================
FILE: gossh/gin/validator/util.go
================================================
package validator
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// extractTypeInternal resolves current to its underlying value and kind.
//
// Pointers and interfaces are dereferenced repeatedly (marking the result
// nullable); a nil pointer or interface is returned as-is with its own kind.
// When custom type functions are registered and one exists for the value's
// type, the value is replaced by that function's result and resolution
// starts over. An invalid value is returned with reflect.Invalid.
func (v *validate) extractTypeInternal(current reflect.Value, nullable bool) (reflect.Value, reflect.Kind, bool) {
	for {
		switch kind := current.Kind(); kind {
		case reflect.Ptr, reflect.Interface:
			nullable = true
			if current.IsNil() {
				return current, kind, nullable
			}
			// unwrap one level and look again
			current = current.Elem()

		case reflect.Invalid:
			return current, reflect.Invalid, nullable

		default:
			if v.v.hasCustomFuncs {
				if fn, ok := v.v.customFuncs[current.Type()]; ok {
					// the custom function's result may itself need unwrapping
					current = reflect.ValueOf(fn(current))
					continue
				}
			}
			return current, kind, nullable
		}
	}
}
// getStructFieldOKInternal traverses val following the provided namespace to
// retrieve a specific field, and returns the field, its kind, the nullable
// flag reported by ExtractType for that value, and whether it was successful
// in retrieving the field at all.
//
// The namespace is consumed piece by piece: struct field names separated by
// namespaceSeparator, and bracketed indexes/keys for arrays, slices and maps.
// Each step rewrites val and namespace and jumps back to BEGIN.
//
// NOTE: when not successful found will be false; this can happen when a
// nested struct is nil and so the field could not be retrieved because it
// didn't exist.
//
// It panics with "Invalid field namespace" when namespace still has content
// but the current value cannot be descended into any further.
func (v *validate) getStructFieldOKInternal(val reflect.Value, namespace string) (current reflect.Value, kind reflect.Kind, nullable bool, found bool) {
BEGIN:
current, kind, nullable = v.ExtractType(val)
// An invalid value cannot be inspected; found stays false.
if kind == reflect.Invalid {
return
}
// The whole namespace has been consumed: current is the requested field.
if namespace == "" {
found = true
return
}
switch kind {
// ExtractType only leaves these kinds in place for nil values, so the
// remaining namespace cannot be followed; found stays false.
case reflect.Ptr, reflect.Interface:
return
case reflect.Struct:
typ := current.Type()
fld := namespace
var ns string
// Structs convertible to timeType are not descended into; with namespace
// left over they fall out of the switch to the panic below.
if !typ.ConvertibleTo(timeType) {
// Split off the first path element at the separator, if there is one.
idx := strings.Index(namespace, namespaceSeparator)
if idx != -1 {
fld = namespace[:idx]
ns = namespace[idx+1:]
} else {
ns = ""
}
// A bracket inside the first element means the field is indexed: keep
// the bracketed part (and everything after it) for the next round.
bracketIdx := strings.Index(fld, leftBracket)
if bracketIdx != -1 {
fld = fld[:bracketIdx]
ns = namespace[bracketIdx:]
}
val = current.FieldByName(fld)
namespace = ns
goto BEGIN
}
case reflect.Array, reflect.Slice:
idx := strings.Index(namespace, leftBracket)
idx2 := strings.Index(namespace, rightBracket)
// NOTE(review): the Atoi error is discarded, so a non-numeric index is
// treated as whatever Atoi returned alongside the error — confirm intended.
arrIdx, _ := strconv.Atoi(namespace[idx+1 : idx2])
// Index past the end: report not found rather than panicking.
if arrIdx >= current.Len() {
return
}
// Continue after the closing bracket, skipping one separator if present.
startIdx := idx2 + 1
if startIdx < len(namespace) {
if namespace[startIdx:startIdx+1] == namespaceSeparator {
startIdx++
}
}
val = current.Index(arrIdx)
namespace = namespace[startIdx:]
goto BEGIN
case reflect.Map:
idx := strings.Index(namespace, leftBracket) + 1
idx2 := strings.Index(namespace, rightBracket)
// endIdx marks the last consumed byte: the closing bracket, or the
// separator directly following it.
endIdx := idx2
if endIdx+1 < len(namespace) {
if namespace[endIdx+1:endIdx+2] == namespaceSeparator {
endIdx++
}
}
key := namespace[idx:idx2]
// Convert the textual key to the map's key kind before the lookup.
// Parse errors are discarded in every case below.
switch current.Type().Key().Kind() {
case reflect.Int:
i, _ := strconv.Atoi(key)
val = current.MapIndex(reflect.ValueOf(i))
namespace = namespace[endIdx+1:]
case reflect.Int8:
i, _ := strconv.ParseInt(key, 10, 8)
val = current.MapIndex(reflect.ValueOf(int8(i)))
namespace = namespace[endIdx+1:]
case reflect.Int16:
i, _ := strconv.ParseInt(key, 10, 16)
val = current.MapIndex(reflect.ValueOf(int16(i)))
namespace = namespace[endIdx+1:]
case reflect.Int32:
i, _ := strconv.ParseInt(key, 10, 32)
val = current.MapIndex(reflect.ValueOf(int32(i)))
namespace = namespace[endIdx+1:]
case reflect.Int64:
i, _ := strconv.ParseInt(key, 10, 64)
val = current.MapIndex(reflect.ValueOf(i))
namespace = namespace[endIdx+1:]
case reflect.Uint:
i, _ := strconv.ParseUint(key, 10, 0)
val = current.MapIndex(reflect.ValueOf(uint(i)))
namespace = namespace[endIdx+1:]
case reflect.Uint8:
i, _ := strconv.ParseUint(key, 10, 8)
val = current.MapIndex(reflect.ValueOf(uint8(i)))
namespace = namespace[endIdx+1:]
case reflect.Uint16:
i, _ := strconv.ParseUint(key, 10, 16)
val = current.MapIndex(reflect.ValueOf(uint16(i)))
namespace = namespace[endIdx+1:]
case reflect.Uint32:
i, _ := strconv.ParseUint(key, 10, 32)
val = current.MapIndex(reflect.ValueOf(uint32(i)))
namespace = namespace[endIdx+1:]
case reflect.Uint64:
i, _ := strconv.ParseUint(key, 10, 64)
val = current.MapIndex(reflect.ValueOf(i))
namespace = namespace[endIdx+1:]
case reflect.Float32:
f, _ := strconv.ParseFloat(key, 32)
val = current.MapIndex(reflect.ValueOf(float32(f)))
namespace = namespace[endIdx+1:]
case reflect.Float64:
f, _ := strconv.ParseFloat(key, 64)
val = current.MapIndex(reflect.ValueOf(f))
namespace = namespace[endIdx+1:]
case reflect.Bool:
b, _ := strconv.ParseBool(key)
val = current.MapIndex(reflect.ValueOf(b))
namespace = namespace[endIdx+1:]
// Any other key kind: the key text is used as a string key.
default:
val = current.MapIndex(reflect.ValueOf(key))
namespace = namespace[endIdx+1:]
}
goto BEGIN
}
// If we got here there was more namespace but the value cannot be
// descended into any deeper.
panic("Invalid field namespace")
}
// asInt parses param as an int64 and panics if it cannot be converted.
// Base 0 is used, so the base is taken from the string's prefix
// (e.g. "0x10").
func asInt(param string) int64 {
	parsed, err := strconv.ParseInt(param, 0, 64)
	panicIf(err)

	return parsed
}
// asIntFromTimeDuration converts param to an int64 duration value.
//
// param is first parsed as a time.Duration string; if that fails it is
// parsed as a plain integer (assuming nanosecond precision) via asInt, which
// panics on error.
func asIntFromTimeDuration(param string) int64 {
	if d, err := time.ParseDuration(param); err == nil {
		return int64(d)
	}

	// not a duration string: fall back to a raw integer
	return asInt(param)
}
// asIntFromType parses param as an int64 using the parser appropriate for
// the field's type t: duration-aware parsing when t is timeDurationType,
// plain integer parsing for everything else.
func asIntFromType(t reflect.Type, param string) int64 {
	if t == timeDurationType {
		return asIntFromTimeDuration(param)
	}

	return asInt(param)
}
// asUint returns the parameter as a uint64
// or panics if it can't convert
func asUint(param string) uint64 {
i, err := strconv.ParseUint(param, 0, 64)
panicIf(err)
return i
}
// asFloat64 parses param as a 64-bit floating point number and panics when
// param cannot be parsed.
func asFloat64(param string) float64 {
	parsed, parseErr := strconv.ParseFloat(param, 64)
	panicIf(parseErr)

	return parsed
}
// asFloat32 parses param with 32-bit precision and returns the result
// widened to float64 (the value is exactly representable as a float32).
// It panics when param cannot be parsed or is out of float32 range.
func asFloat32(param string) float64 {
	parsed, parseErr := strconv.ParseFloat(param, 32)
	panicIf(parseErr)

	return parsed
}
// asBool parses param as a boolean using strconv.ParseBool and panics when
// param is not one of the spellings that function accepts.
func asBool(param string) bool {
	parsed, parseErr := strconv.ParseBool(param)
	panicIf(parseErr)

	return parsed
}
// panicIf panics with err's message (a string, not the error value) when
// err is non-nil and is a no-op otherwise.
func panicIf(err error) {
	if err == nil {
		return
	}

	panic(err.Error())
}
// fieldMatchesRegexByStringerValOrString reports whether the field's textual
// form matches regex. String-kind fields are matched directly; for any other
// kind the fmt.Stringer representation is used when the value implements it,
// and reflect's String() rendering of the value otherwise.
func fieldMatchesRegexByStringerValOrString(regex *regexp.Regexp, fl FieldLevel) bool {
	field := fl.Field()

	if field.Kind() != reflect.String {
		if stringer, ok := field.Interface().(fmt.Stringer); ok {
			return regex.MatchString(stringer.String())
		}
	}

	return regex.MatchString(field.String())
}
================================================
FILE: gossh/gin/validator/validator.go
================================================
package validator
import (
"context"
"fmt"
"reflect"
"strconv"
)
// validate holds the per-call validation state. Instances are obtained from
// and returned to Validate.pool by the Struct*/Var* entry points, so fields
// carry over between calls unless an entry point resets them.
type validate struct {
v *Validate // owning Validate instance (settings and caches)
top reflect.Value // value originally handed to the entry point
ns []byte // reusable namespace buffer; entry points pass ns[0:0]
actualNs []byte // reusable struct-namespace buffer; entry points pass actualNs[0:0]
errs ValidationErrors // errors collected during the current run; cleared by the entry point
includeExclude map[string]struct{} // reset only if StructPartial or StructExcept are called, no need otherwise
ffn FilterFunc // filter used by StructFiltered; nil for StructPartial/StructExcept
slflParent reflect.Value // StructLevel & FieldLevel
slCurrent reflect.Value // StructLevel & FieldLevel
flField reflect.Value // StructLevel & FieldLevel
cf *cField // StructLevel & FieldLevel
ct *cTag // StructLevel & FieldLevel
misc []byte // misc reusable
str1 string // misc reusable
str2 string // misc reusable
fldIsPointer bool // StructLevel & FieldLevel
isPartial bool // true for StructFiltered/StructPartial/StructExcept runs
hasExcludes bool // true when includeExclude lists fields to skip (StructExcept)
}
// validateStruct validates every cached field of the struct value current
// and then runs the struct-level validation registered for its type, if any.
//
// parent and current are the same on the first (top-level) call. ns and
// structNs are the namespace prefixes accumulated so far; when ns is empty
// and the struct type is named, "<TypeName>." is used as the prefix. ct is
// nil for the top-level struct and for struct fields without tag info; when
// it carries the structonly tag the fields are skipped and only the
// struct-level function runs.
func (v *validate) validateStruct(ctx context.Context, parent reflect.Value, current reflect.Value, typ reflect.Type, ns []byte, structNs []byte, ct *cTag) {
	cs, cached := v.v.structCache.Get(typ)
	if !cached {
		cs = v.v.extractStructCache(current, typ.Name())
	}

	if len(ns) == 0 && len(cs.name) != 0 {
		ns = append(append(ns, cs.name...), '.')
		structNs = append(append(structNs, cs.name...), '.')
	}

	fieldsWanted := ct == nil || ct.typeof != typeStructOnly
	if fieldsWanted {
		for _, f := range cs.fields {
			if v.isPartial {
				if v.ffn != nil {
					// StructFiltered: the filter returns true for fields to skip
					if v.ffn(append(structNs, f.name...)) {
						continue
					}
				} else {
					// StructPartial (listed == validate) & StructExcept (listed == skip)
					_, listed := v.includeExclude[string(append(structNs, f.name...))]
					if listed == v.hasExcludes {
						continue
					}
				}
			}

			v.traverseField(ctx, current, current.Field(f.idx), ns, structNs, f, f.cTags)
		}
	}

	// Struct-level validation runs after all field validations. The first
	// iteration has no info about the nostructlevel tag; that tag is checked
	// in traverseField before it recurses into validateStruct.
	if cs.fn == nil {
		return
	}

	v.slflParent = parent
	v.slCurrent = current
	v.ns = ns
	v.actualNs = structNs

	cs.fn(ctx, v)
}
// traverseField validates any field, be it a struct or single field, ensures its validity and passes it along to be
// validated via its tag options.
//
// parent is the struct (or comparison value) the field belongs to, current is the field value, ns/structNs are the
// namespace prefixes built so far, cf describes the field and ct is the head of its parsed tag chain. Failures are
// appended to v.errs; validation of a field stops at its first failing tag.
func (v *validate) traverseField(ctx context.Context, parent reflect.Value, current reflect.Value, ns []byte, structNs []byte, cf *cField, ct *cTag) {
var typ reflect.Type
var kind reflect.Kind
// resolve the value and kind to validate; the third result is recorded in v.fldIsPointer
current, kind, v.fldIsPointer = v.extractTypeInternal(current, false)
var isNestedStruct bool
switch kind {
case reflect.Ptr, reflect.Interface, reflect.Invalid:
// no tag chain: nothing to validate
if ct == nil {
return
}
// omitempty / isdefault: a value of these kinds is accepted without further checks
if ct.typeof == typeOmitEmpty || ct.typeof == typeIsDefault {
return
}
// omitnil: skip when the (valid) value is nil
if ct.typeof == typeOmitNil && (kind != reflect.Invalid && current.IsNil()) {
return
}
if ct.hasTag {
if kind == reflect.Invalid {
// invalid value: report the first tag as failed; no value or type can be attached
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: ct.aliasTag,
actualTag: ct.tag,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
param: ct.param,
kind: kind,
},
)
return
}
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
// tags not flagged to run on nil fail immediately; the others fall through to the tag loop below
if !ct.runValidationWhenNil {
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: ct.aliasTag,
actualTag: ct.tag,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
value: current.Interface(),
param: ct.param,
kind: kind,
typ: current.Type(),
},
)
return
}
}
// an invalid value has no type, so it cannot go through the tag loop
if kind == reflect.Invalid {
return
}
case reflect.Struct:
// structs convertible to time.Time are validated as plain values, not traversed
isNestedStruct = !current.Type().ConvertibleTo(timeType)
// For backward compatibility before struct level validation tags were supported
// as there were a number of projects relying on `required` not failing on non-pointer
// structs. Since it's basically nonsensical to use `required` with a non-pointer struct
// are explicitly skipping the required validation for it. This WILL be removed in the
// next major version.
if isNestedStruct && !v.v.requiredStructEnabled && ct != nil && ct.tag == requiredTag {
ct = ct.next
}
}
typ = current.Type()
// walk the tag chain; each iteration handles the tag ct currently points at
OUTER:
for {
if ct == nil || !ct.hasTag || (isNestedStruct && len(cf.name) == 0) {
// isNestedStruct check here
if isNestedStruct {
// if len == 0 then validating using 'Var' or 'VarWithValue'
// Var - doesn't make much sense to do it that way, should call 'Struct', but no harm...
// VarWithField - this allows for validating against each field within the struct against a specific value
// pretty handy in certain situations
if len(cf.name) > 0 {
ns = append(append(ns, cf.altName...), '.')
structNs = append(append(structNs, cf.name...), '.')
}
v.validateStruct(ctx, parent, current, typ, ns, structNs, ct)
}
return
}
switch ct.typeof {
case typeNoStructLevel:
return
case typeStructOnly:
if isNestedStruct {
// if len == 0 then validating using 'Var' or 'VarWithValue'
// Var - doesn't make much sense to do it that way, should call 'Struct', but no harm...
// VarWithField - this allows for validating against each field within the struct against a specific value
// pretty handy in certain situations
if len(cf.name) > 0 {
ns = append(append(ns, cf.altName...), '.')
structNs = append(append(structNs, cf.name...), '.')
}
v.validateStruct(ctx, parent, current, typ, ns, structNs, ct)
}
return
case typeOmitEmpty:
// set Field Level fields
v.slflParent = parent
v.flField = current
v.cf = cf
v.ct = ct
// no value: the remaining tags are skipped
if !hasValue(v) {
return
}
ct = ct.next
continue
case typeOmitNil:
v.slflParent = parent
v.flField = current
v.cf = cf
v.ct = ct
// nil value: the remaining tags are skipped
switch field := v.Field(); field.Kind() {
case reflect.Slice, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Chan, reflect.Func:
if field.IsNil() {
return
}
default:
if v.fldIsPointer && field.Interface() == nil {
return
}
}
ct = ct.next
continue
case typeEndKeys:
return
case typeDive:
ct = ct.next
// traverse slice or map here
// or panic ;)
switch kind {
case reflect.Slice, reflect.Array:
var i64 int64
// one cField is reused for every element; only its names change ("name[i]")
reusableCF := &cField{}
for i := 0; i < current.Len(); i++ {
i64 = int64(i)
v.misc = append(v.misc[0:0], cf.name...)
v.misc = append(v.misc, '[')
v.misc = strconv.AppendInt(v.misc, i64, 10)
v.misc = append(v.misc, ']')
reusableCF.name = string(v.misc)
if cf.namesEqual {
reusableCF.altName = reusableCF.name
} else {
v.misc = append(v.misc[0:0], cf.altName...)
v.misc = append(v.misc, '[')
v.misc = strconv.AppendInt(v.misc, i64, 10)
v.misc = append(v.misc, ']')
reusableCF.altName = string(v.misc)
}
v.traverseField(ctx, parent, current.Index(i), ns, structNs, reusableCF, ct)
}
case reflect.Map:
var pv string
// one cField is reused for every entry; only its names change ("name[key]")
reusableCF := &cField{}
for _, key := range current.MapKeys() {
pv = fmt.Sprintf("%v", key.Interface())
v.misc = append(v.misc[0:0], cf.name...)
v.misc = append(v.misc, '[')
v.misc = append(v.misc, pv...)
v.misc = append(v.misc, ']')
reusableCF.name = string(v.misc)
if cf.namesEqual {
reusableCF.altName = reusableCF.name
} else {
v.misc = append(v.misc[0:0], cf.altName...)
v.misc = append(v.misc, '[')
v.misc = append(v.misc, pv...)
v.misc = append(v.misc, ']')
reusableCF.altName = string(v.misc)
}
// a keys tag validates the map key with ct.keys and the value with the tags after it
if ct != nil && ct.typeof == typeKeys && ct.keys != nil {
v.traverseField(ctx, parent, key, ns, structNs, reusableCF, ct.keys)
// can be nil when just keys being validated
if ct.next != nil {
v.traverseField(ctx, parent, current.MapIndex(key), ns, structNs, reusableCF, ct.next)
}
} else {
v.traverseField(ctx, parent, current.MapIndex(key), ns, structNs, reusableCF, ct)
}
}
default:
// throw error, if not a slice or map then should not have gotten here
// bad dive tag
panic("dive error! can't dive on a non slice or map")
}
return
case typeOr:
// v.misc accumulates "|tag[=param]" for every failed alternative, used as the error tag
v.misc = v.misc[0:0]
for {
// set Field Level fields
v.slflParent = parent
v.flField = current
v.cf = cf
v.ct = ct
if ct.fn(ctx, v) {
if ct.isBlockEnd {
ct = ct.next
continue OUTER
}
// drain rest of the 'or' values, then continue or leave
for {
ct = ct.next
if ct == nil {
continue OUTER
}
if ct.typeof != typeOr {
continue OUTER
}
if ct.isBlockEnd {
ct = ct.next
continue OUTER
}
}
}
v.misc = append(v.misc, '|')
v.misc = append(v.misc, ct.tag...)
if ct.hasParam {
v.misc = append(v.misc, '=')
v.misc = append(v.misc, ct.param...)
}
if ct.isBlockEnd || ct.next == nil {
// if we get here, no valid 'or' value and no more tags
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
if ct.hasAlias {
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: ct.aliasTag,
actualTag: ct.actualAliasTag,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
value: current.Interface(),
param: ct.param,
kind: kind,
typ: typ,
},
)
} else {
// drop the leading '|' from the accumulated alternatives
tVal := string(v.misc)[1:]
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: tVal,
actualTag: tVal,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
value: current.Interface(),
param: ct.param,
kind: kind,
typ: typ,
},
)
}
return
}
ct = ct.next
}
default:
// ordinary validation tag: run it and record an error on failure
// set Field Level fields
v.slflParent = parent
v.flField = current
v.cf = cf
v.ct = ct
if !ct.fn(ctx, v) {
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: ct.aliasTag,
actualTag: ct.tag,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
value: current.Interface(),
param: ct.param,
kind: kind,
typ: typ,
},
)
return
}
ct = ct.next
}
}
}
================================================
FILE: gossh/gin/validator/validator_instance.go
================================================
package validator
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
// Tag names, separators and message templates used when parsing and
// registering validation tags.
const (
defaultTagName = "validate" // struct tag key used unless SetTagName changes it
utf8HexComma = "0x2C"
utf8Pipe = "0x7C"
tagSeparator = ","
orSeparator = "|"
tagKeySeparator = "="
structOnlyTag = "structonly"
noStructLevelTag = "nostructlevel"
omitempty = "omitempty"
omitnil = "omitnil"
isdefault = "isdefault"
requiredWithoutAllTag = "required_without_all"
requiredWithoutTag = "required_without"
requiredWithTag = "required_with"
requiredWithAllTag = "required_with_all"
requiredIfTag = "required_if"
requiredUnlessTag = "required_unless"
skipUnlessTag = "skip_unless"
excludedWithoutAllTag = "excluded_without_all"
excludedWithoutTag = "excluded_without"
excludedWithTag = "excluded_with"
excludedWithAllTag = "excluded_with_all"
excludedIfTag = "excluded_if"
excludedUnlessTag = "excluded_unless"
skipValidationTag = "-" // Var/VarWithValue return nil immediately for this tag
diveTag = "dive"
keysTag = "keys"
endKeysTag = "endkeys"
requiredTag = "required"
namespaceSeparator = "."
leftBracket = "["
rightBracket = "]"
restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}" // characters rejected in custom tag and alias names
restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
restrictedTagErr = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
// Reflect types computed once for comparisons during validation, plus the
// field descriptor used when validating a bare variable.
var (
timeDurationType = reflect.TypeOf(time.Duration(0)) // reflect.Type of time.Duration
timeType = reflect.TypeOf(time.Time{}) // reflect.Type of time.Time
byteSliceType = reflect.TypeOf([]byte{}) // reflect.Type of []byte
defaultCField = &cField{namesEqual: true} // nameless field passed by Var/VarWithValue
)
// FilterFunc is the type used to filter fields with the
// StructFiltered(...) function. It receives the field's struct namespace
// (ns); returning true results in the field being filtered/skipped from
// validation.
type FilterFunc func(ns []byte) bool
// CustomTypeFunc allows for overriding or adding custom field type handler functions.
// field = field value of the type; the function returns the value to be validated.
// Example: Valuer from the sql driver, see https://golang.org/src/database/sql/driver/types.go?s=1210:1293#L29
type CustomTypeFunc func(field reflect.Value) any
// TagNameFunc allows for adding a custom tag name parser: given a struct
// field it returns the alternate name to use for that field.
type TagNameFunc func(field reflect.StructField) string
// internalValidationFuncWrapper pairs a registered validation function with
// the flag given at registration time that says whether it is nil-checkable.
type internalValidationFuncWrapper struct {
fn FuncCtx // the validation function
runValidatinOnNil bool // set from the nilCheckable argument of registerValidation
}
// Validate contains the validator settings and cache.
type Validate struct {
tagName string // struct tag key; defaults to defaultTagName, changed by SetTagName
pool *sync.Pool // pool of *validate state objects, created in New
tagNameFunc TagNameFunc // set by RegisterTagNameFunc
structLevelFuncs map[reflect.Type]StructLevelFuncCtx // set by RegisterStructValidationCtx
customFuncs map[reflect.Type]CustomTypeFunc // set by RegisterCustomTypeFunc
aliases map[string]string // alias -> tags, set by RegisterAlias
validations map[string]internalValidationFuncWrapper // tag -> validation function
rules map[reflect.Type]map[string]string // set by RegisterStructValidationMapRules
tagCache *tagCache
structCache *structCache
hasCustomFuncs bool // true once RegisterCustomTypeFunc has been called
hasTagNameFunc bool // true once RegisterTagNameFunc has been called
requiredStructEnabled bool // when false, `required` on a non-pointer nested struct is skipped
}
// New returns a new Validate instance with sane defaults and then applies
// the given options in order.
//
// Validate is designed to be thread-safe and used as a singleton instance.
// It caches information about your structs and validations, in essence only
// parsing your validation tags once per struct type. Using multiple
// instances neglects the benefit of caching.
func New(options ...Option) *Validate {
	tagStore := new(tagCache)
	tagStore.m.Store(make(map[string]*cTag))

	structStore := new(structCache)
	structStore.m.Store(make(map[reflect.Type]*cStruct))

	v := &Validate{
		tagName:     defaultTagName,
		aliases:     make(map[string]string, len(bakedInAliases)),
		validations: make(map[string]internalValidationFuncWrapper, len(bakedInValidators)),
		tagCache:    tagStore,
		structCache: structStore,
	}

	// each instance gets its own copy of the baked-in aliases
	for alias, tags := range bakedInAliases {
		v.RegisterAlias(alias, tags)
	}

	// each instance gets its own copy of the baked-in validators
	for tag, fn := range bakedInValidators {
		// these must run even when the value is nil; omitempty still
		// overrides this behaviour
		runOnNil := false
		switch tag {
		case requiredIfTag, requiredUnlessTag, requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag,
			excludedIfTag, excludedUnlessTag, excludedWithTag, excludedWithAllTag, excludedWithoutTag, excludedWithoutAllTag,
			skipUnlessTag:
			runOnNil = true
		}

		// no need to error check here, baked in will always be valid
		_ = v.registerValidation(tag, wrapFunc(fn), true, runOnNil)
	}

	v.pool = &sync.Pool{
		New: func() any {
			return &validate{
				v:        v,
				ns:       make([]byte, 0, 64),
				actualNs: make([]byte, 0, 64),
				// NOTE(review): length 32 rather than capacity 32; the uses
				// visible in this file re-slice misc to [0:0] first.
				misc: make([]byte, 32),
			}
		},
	}

	for _, apply := range options {
		apply(v)
	}

	return v
}
// SetTagName changes the struct tag key read for validation rules from the
// default 'validate' to name.
func (v *Validate) SetTagName(name string) {
v.tagName = name
}
// ValidateMapCtx validates data against rules and allows passing of
// contextual validation information via context.Context.
//
// Each rule is either a tag string, applied to data[field] with VarCtx, or a
// nested map[string]any of rules, applied recursively to data[field] when it
// is a map[string]any or to each element when it is a []map[string]any.
// Rules of any other type are ignored.
//
// The returned map has one entry per failing field: an error for a tag rule
// (or for a nested rule whose data is not a map), or a nested error map. For
// a slice of maps only the errors of the last failing element are kept.
func (v Validate) ValidateMapCtx(ctx context.Context, data map[string]any, rules map[string]any) map[string]any {
	errs := make(map[string]any)

	for field, rule := range rules {
		switch r := rule.(type) {
		case map[string]any:
			switch d := data[field].(type) {
			case map[string]any:
				if nested := v.ValidateMapCtx(ctx, d, r); len(nested) > 0 {
					errs[field] = nested
				}
			case []map[string]any:
				for _, item := range d {
					if nested := v.ValidateMapCtx(ctx, item, r); len(nested) > 0 {
						errs[field] = nested
					}
				}
			default:
				errs[field] = errors.New("The field: '" + field + "' is not a map to dive")
			}
		case string:
			if err := v.VarCtx(ctx, data[field], r); err != nil {
				errs[field] = err
			}
		}
	}

	return errs
}
// ValidateMap validates map data from a map of tags. It calls
// ValidateMapCtx with context.Background() and returns its error map.
func (v *Validate) ValidateMap(data map[string]any, rules map[string]any) map[string]any {
return v.ValidateMapCtx(context.Background(), data, rules)
}
// RegisterTagNameFunc registers a function to get alternate names for StructFields
// and marks the instance as having a tag name function.
//
// eg. to use the names which have been specified for JSON representations of structs, rather than normal Go field names:
//
// validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
// name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
// // skip if tag key says it should be ignored
// if name == "-" {
// return ""
// }
// return name
// })
func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
v.tagNameFunc = fn
v.hasTagNameFunc = true
}
// RegisterValidation adds a validation with the given tag. fn is wrapped
// with wrapFunc and passed to RegisterValidationCtx.
//
// NOTES:
// - if the key already exists, the previous validation function will be replaced.
// - this method is not thread-safe; it is intended that these all be registered prior to any validation
func (v *Validate) RegisterValidation(tag string, fn Func, callValidationEvenIfNull ...bool) error {
return v.RegisterValidationCtx(tag, wrapFunc(fn), callValidationEvenIfNull...)
}
// RegisterValidationCtx does the same as RegisterValidation but accepts a
// FuncCtx validation, allowing context.Context validation support. Only the
// first callValidationEvenIfNull value, if any, is used; it defaults to
// false.
func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx, callValidationEvenIfNull ...bool) error {
	runOnNil := len(callValidationEvenIfNull) > 0 && callValidationEvenIfNull[0]

	return v.registerValidation(tag, fn, false, runOnNil)
}
// registerValidation stores fn under tag, replacing any previous entry.
//
// It returns an error when tag is empty or fn is nil. For validations that
// are not baked in, it panics when tag is a restricted tag or contains any
// restricted character. nilCheckable is stored alongside fn.
func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error {
	switch {
	case len(tag) == 0:
		return errors.New("function Key cannot be empty")
	case fn == nil:
		return errors.New("function cannot be empty")
	}

	if !bakedIn {
		if _, restricted := restrictedTags[tag]; restricted || strings.ContainsAny(tag, restrictedTagChars) {
			panic(fmt.Sprintf(restrictedTagErr, tag))
		}
	}

	v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable}

	return nil
}
// RegisterAlias registers alias as a shorthand for tags, a common or complex
// set of validation(s), to simplify adding validation to structs. It panics
// when alias is a restricted tag or contains any restricted character.
//
// NOTE: this function is not thread-safe; it is intended that these all be
// registered prior to any validation
func (v *Validate) RegisterAlias(alias, tags string) {
	if _, restricted := restrictedTags[alias]; restricted || strings.ContainsAny(alias, restrictedTagChars) {
		panic(fmt.Sprintf(restrictedAliasErr, alias))
	}

	v.aliases[alias] = tags
}
// RegisterStructValidation registers a StructLevelFunc against a number of types.
// fn is wrapped with wrapStructLevelFunc and passed to RegisterStructValidationCtx.
//
// NOTE:
// - this method is not thread-safe; it is intended that these all be registered prior to any validation
func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...any) {
v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...)
}
// RegisterStructValidationCtx registers a StructLevelFuncCtx against a
// number of types and allows passing of contextual validation information
// via context.Context. A pointer in types is dereferenced once, so the
// function is keyed by the type of the pointed-to value.
//
// NOTE:
// - this method is not thread-safe; it is intended that these all be registered prior to any validation
func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...any) {
	if v.structLevelFuncs == nil {
		v.structLevelFuncs = make(map[reflect.Type]StructLevelFuncCtx)
	}

	for _, sample := range types {
		key := reflect.TypeOf(sample)
		if rv := reflect.ValueOf(sample); rv.Kind() == reflect.Ptr {
			key = reflect.TypeOf(reflect.Indirect(rv).Interface())
		}

		v.structLevelFuncs[key] = fn
	}
}
// RegisterStructValidationMapRules registers validate map rules for the
// given struct types (or pointers to them); entries of any other kind are
// ignored. The rules are copied once and the same copy is shared by every
// registered type.
// Be aware that map validation rules supersede those defined on a/the struct if present.
//
// NOTE: this method is not thread-safe; it is intended that these all be registered prior to any validation
func (v *Validate) RegisterStructValidationMapRules(rules map[string]string, types ...any) {
	if v.rules == nil {
		v.rules = make(map[reflect.Type]map[string]string)
	}

	// copy so later changes to the caller's map do not affect registered rules
	rulesCopy := make(map[string]string)
	for field, tags := range rules {
		rulesCopy[field] = tags
	}

	for _, sample := range types {
		structType := reflect.TypeOf(sample)
		if structType.Kind() == reflect.Ptr {
			structType = structType.Elem()
		}

		if structType.Kind() == reflect.Struct {
			v.rules[structType] = rulesCopy
		}
	}
}
// RegisterCustomTypeFunc registers fn as the custom type handler for the
// dynamic type of every value in types and marks the instance as having
// custom type functions.
//
// NOTE: this method is not thread-safe; it is intended that these all be
// registered prior to any validation
func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...any) {
	if v.customFuncs == nil {
		v.customFuncs = make(map[reflect.Type]CustomTypeFunc)
	}

	for _, sample := range types {
		key := reflect.TypeOf(sample)
		v.customFuncs[key] = fn
	}

	v.hasCustomFuncs = true
}
// Struct validates a struct's exposed fields, and automatically validates nested structs, unless otherwise specified.
// It calls StructCtx with context.Background().
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) Struct(s any) error {
return v.StructCtx(context.Background(), s)
}
// StructCtx validates a struct's exposed fields, and automatically validates
// nested structs, unless otherwise specified, and also allows passing of
// context.Context for contextual validation information. s may be a struct
// or a non-nil pointer to one.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructCtx(ctx context.Context, s any) (err error) {
	top := reflect.ValueOf(s)

	target := top
	if top.Kind() == reflect.Ptr && !top.IsNil() {
		target = top.Elem()
	}

	// only real structs are accepted; time.Time-convertible types are not
	if target.Kind() != reflect.Struct || target.Type().ConvertibleTo(timeType) {
		return &InvalidValidationError{Type: reflect.TypeOf(s)}
	}

	// good to validate
	state := v.pool.Get().(*validate)
	state.top = top
	state.isPartial = false
	// state.hasExcludes = false // only need to reset in StructPartial and StructExcept

	state.validateStruct(ctx, top, target, target.Type(), state.ns[0:0], state.actualNs[0:0], nil)

	if len(state.errs) > 0 {
		err, state.errs = state.errs, nil
	}

	v.pool.Put(state)

	return err
}
// StructFiltered validates a struct's exposed fields that pass the FilterFunc check and automatically validates
// nested structs, unless otherwise specified. It calls StructFilteredCtx with context.Background().
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructFiltered(s any, fn FilterFunc) error {
return v.StructFilteredCtx(context.Background(), s, fn)
}
// StructFilteredCtx validates a struct's exposed fields that pass the
// FilterFunc check (fn returning true skips the field) and automatically
// validates nested structs, unless otherwise specified. It also allows
// passing of contextual validation information via context.Context.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructFilteredCtx(ctx context.Context, s any, fn FilterFunc) (err error) {
	top := reflect.ValueOf(s)

	target := top
	if top.Kind() == reflect.Ptr && !top.IsNil() {
		target = top.Elem()
	}

	// only real structs are accepted; time.Time-convertible types are not
	if target.Kind() != reflect.Struct || target.Type().ConvertibleTo(timeType) {
		return &InvalidValidationError{Type: reflect.TypeOf(s)}
	}

	// good to validate
	state := v.pool.Get().(*validate)
	state.top = top
	state.isPartial = true
	state.ffn = fn
	// state.hasExcludes = false // only need to reset in StructPartial and StructExcept

	state.validateStruct(ctx, top, target, target.Type(), state.ns[0:0], state.actualNs[0:0], nil)

	if len(state.errs) > 0 {
		err, state.errs = state.errs, nil
	}

	v.pool.Put(state)

	return err
}
// StructPartial validates the fields passed in only, ignoring all others.
// Fields may be provided in a namespaced fashion relative to the struct provided
// eg. NestedStruct.Field or NestedArrayField[0].Struct.Name
// It calls StructPartialCtx with context.Background().
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructPartial(s any, fields ...string) error {
return v.StructPartialCtx(context.Background(), s, fields...)
}
// StructPartialCtx validates the fields passed in only, ignoring all others and allows passing of contextual
// validation information via context.Context
// Fields may be provided in a namespaced fashion relative to the struct provided
// eg. NestedStruct.Field or NestedArrayField[0].Struct.Name
//
// Every prefix of a namespaced field is registered as well (e.g. "A.B[0].C"
// registers "A", "A.B", "A.B[0]" and "A.B[0].C", each prefixed with the
// struct's type name when it has one) so that the parents of the requested
// field are traversed. A segment with an opening bracket but no closing
// bracket after it is registered literally instead of being split.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructPartialCtx(ctx context.Context, s any, fields ...string) (err error) {
	val := reflect.ValueOf(s)
	top := val

	if val.Kind() == reflect.Ptr && !val.IsNil() {
		val = val.Elem()
	}

	if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
		return &InvalidValidationError{Type: reflect.TypeOf(s)}
	}

	// good to validate
	vd := v.pool.Get().(*validate)
	vd.top = top
	vd.isPartial = true
	vd.ffn = nil
	vd.hasExcludes = false
	vd.includeExclude = make(map[string]struct{})

	typ := val.Type()
	name := typ.Name()

	for _, k := range fields {
		vd.misc = append(vd.misc[0:0], name...)
		// Don't append empty name for unnamed structs
		if len(vd.misc) != 0 {
			vd.misc = append(vd.misc, '.')
		}

		for _, seg := range strings.Split(k, namespaceSeparator) {
			rest := seg
			indexed := false

			for open := strings.Index(rest, leftBracket); open != -1; open = strings.Index(rest, leftBracket) {
				// Look for the closing bracket only after the opening one. The
				// previous code searched from the start of the segment and
				// either panicked or looped forever on input such as "Foo[",
				// "[Foo" or "a]b[0]".
				rel := strings.Index(rest[open:], rightBracket)
				if rel == -1 {
					break
				}
				end := open + rel + 1

				// register "Field" and then "Field[idx]"
				vd.misc = append(vd.misc, rest[:open]...)
				vd.includeExclude[string(vd.misc)] = struct{}{}

				vd.misc = append(vd.misc, rest[open:end]...)
				vd.includeExclude[string(vd.misc)] = struct{}{}

				rest = rest[end:]
				indexed = true
			}

			// plain segment, or an unbalanced bracket with nothing split off yet
			if !indexed {
				vd.misc = append(vd.misc, rest...)
				vd.includeExclude[string(vd.misc)] = struct{}{}
			}

			vd.misc = append(vd.misc, '.')
		}
	}

	vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)

	if len(vd.errs) > 0 {
		err = vd.errs
		vd.errs = nil
	}

	v.pool.Put(vd)

	return err
}
// StructExcept validates all fields except the ones passed in.
// Fields may be provided in a namespaced fashion relative to the struct provided
// i.e. NestedStruct.Field or NestedArrayField[0].Struct.Name
// It calls StructExceptCtx with context.Background().
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructExcept(s any, fields ...string) error {
return v.StructExceptCtx(context.Background(), s, fields...)
}
// StructExceptCtx validates all fields except the ones passed in and allows
// passing of contextual validation information via context.Context.
// Fields may be provided in a namespaced fashion relative to the struct
// provided, i.e. NestedStruct.Field or NestedArrayField[0].Struct.Name; each
// one is registered verbatim, prefixed with the struct's type name when it
// has one.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
func (v *Validate) StructExceptCtx(ctx context.Context, s any, fields ...string) (err error) {
	top := reflect.ValueOf(s)

	target := top
	if top.Kind() == reflect.Ptr && !top.IsNil() {
		target = top.Elem()
	}

	// only real structs are accepted; time.Time-convertible types are not
	if target.Kind() != reflect.Struct || target.Type().ConvertibleTo(timeType) {
		return &InvalidValidationError{Type: reflect.TypeOf(s)}
	}

	// good to validate
	state := v.pool.Get().(*validate)
	state.top = top
	state.isPartial = true
	state.ffn = nil
	state.hasExcludes = true
	state.includeExclude = make(map[string]struct{})

	typ := target.Type()
	typeName := typ.Name()

	for _, excluded := range fields {
		state.misc = state.misc[0:0]
		if len(typeName) > 0 {
			state.misc = append(append(state.misc, typeName...), '.')
		}
		state.misc = append(state.misc, excluded...)

		state.includeExclude[string(state.misc)] = struct{}{}
	}

	state.validateStruct(ctx, top, target, typ, state.ns[0:0], state.actualNs[0:0], nil)

	if len(state.errs) > 0 {
		err, state.errs = state.errs, nil
	}

	v.pool.Put(state)

	return err
}
// Var validates a single variable using tag style validation.
// It calls VarCtx with context.Background().
// eg.
// var i int
// validate.Var(i, "gt=1,lt=10")
//
// WARNING: a struct can be passed for validation eg. time.Time is a struct or
// if you have a custom type and have registered a custom type handler, so must
// allow it; however unforeseen validations will occur if trying to validate a
// struct that is meant to be passed to 'validate.Struct'
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) Var(field any, tag string) error {
return v.VarCtx(context.Background(), field, tag)
}
// VarCtx validates a single variable using tag style validation and allows
// passing of contextual validation information via context.Context.
// An empty tag or the skip tag "-" validates nothing and returns nil.
// eg.
//
//	var i int
//	validate.Var(i, "gt=1,lt=10")
//
// WARNING: a struct can be passed for validation eg. time.Time is a struct or
// if you have a custom type and have registered a custom type handler, so must
// allow it; however unforeseen validations will occur if trying to validate a
// struct that is meant to be passed to 'validate.Struct'
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) VarCtx(ctx context.Context, field any, tag string) (err error) {
	if tag == skipValidationTag || len(tag) == 0 {
		return nil
	}

	parsedTag := v.fetchCacheTag(tag)
	fieldVal := reflect.ValueOf(field)

	state := v.pool.Get().(*validate)
	state.top = fieldVal
	state.isPartial = false

	// the variable acts as both parent and current value
	state.traverseField(ctx, fieldVal, fieldVal, state.ns[0:0], state.actualNs[0:0], defaultCField, parsedTag)

	if len(state.errs) > 0 {
		err, state.errs = state.errs, nil
	}

	v.pool.Put(state)

	return err
}
// VarWithValue validates a single variable against another variable/field's value using tag style validation.
// It calls VarWithValueCtx with context.Background().
// eg.
// s1 := "abcd"
// s2 := "abcd"
// validate.VarWithValue(s1, s2, "eqcsfield") // returns true
//
// WARNING: a struct can be passed for validation eg. time.Time is a struct or
// if you have a custom type and have registered a custom type handler, so must
// allow it; however unforeseen validations will occur if trying to validate a
// struct that is meant to be passed to 'validate.Struct'
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
// validate Array, Slice and maps fields which may contain more than one error
func (v *Validate) VarWithValue(field any, other any, tag string) error {
return v.VarWithValueCtx(context.Background(), field, other, tag)
}
// VarWithValueCtx validates a single variable against another variable/field's
// value using tag style validation, and allows contextual validation
// information to be passed via context.Context.
// eg.
// s1 := "abcd"
// s2 := "abcd"
// validate.VarWithValueCtx(ctx, s1, s2, "eqcsfield")
//
// An empty tag or the skip tag yields nil without validating anything.
//
// WARNING: a struct can be passed for validation (time.Time is a struct, and
// custom types with a registered custom type handler must be allowed too);
// however unforeseen validations will occur when validating a struct that is
// meant to be passed to 'validate.Struct'.
//
// It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
// Assert a non-nil error, eg. err.(validator.ValidationErrors), to access the array of errors.
// Array, Slice and map fields are validated too and may yield more than one error.
func (v *Validate) VarWithValueCtx(ctx context.Context, field any, other any, tag string) (err error) {
	if tag == "" || tag == skipValidationTag {
		return nil
	}

	cachedTag := v.fetchCacheTag(tag)
	// The "other" value acts as both the top-level and the parent value for the traversal.
	otherValue := reflect.ValueOf(other)

	state := v.pool.Get().(*validate)
	state.top = otherValue
	state.isPartial = false
	state.traverseField(ctx, otherValue, reflect.ValueOf(field), state.ns[:0], state.actualNs[:0], defaultCField, cachedTag)

	// Move collected errors out of the pooled state before it is returned to the pool.
	if collected := state.errs; len(collected) > 0 {
		err = collected
		state.errs = nil
	}
	v.pool.Put(state)
	return err
}
================================================
FILE: gossh/gin/version.go
================================================
// Copyright 2018 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package gin
// Version is the current gin framework's version string.
const Version = "v1.9.1"
================================================
FILE: gossh/go.mod
================================================
module gossh
go 1.23
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/glebarez/sqlite v1.11.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/sys v0.7.0 // indirect
gorm.io/gorm v1.25.7 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)
================================================
FILE: gossh/gorm/association.go
================================================
package gorm
import (
"fmt"
"reflect"
"strings"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// Association holds the state of association mode, which provides helper
// methods to handle relationship things easily.
type Association struct {
	// DB is the database handle the association was created from.
	DB *DB
	// Relationship is the relationship looked up by column name; it is nil when
	// the column is not a known relation (Error is set in that case).
	Relationship *schema.Relationship
	// Unscope is set to true by Unscoped. Replace and Delete consult it to
	// delete related records instead of only clearing their foreign keys.
	Unscope bool
	// Error holds the error recorded by the last operation; the methods of
	// Association do no work while it is non-nil.
	Error error
}
// Association returns an Association for the relationship named column on
// db.Statement.Model. If the model cannot be parsed, or column is not a known
// relation, the returned Association carries the failure in its Error field.
func (db *DB) Association(column string) *Association {
	result := &Association{DB: db}
	// Remember the table so it can be restored after parsing the model.
	originalTable := db.Statement.Table

	if parseErr := db.Statement.Parse(db.Statement.Model); parseErr != nil {
		result.Error = parseErr
		return result
	}

	db.Statement.Table = originalTable
	result.Relationship = db.Statement.Schema.Relationships.Relations[column]
	if result.Relationship == nil {
		result.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
	}

	// Dereference the model until a non-pointer value is reached.
	modelValue := reflect.ValueOf(db.Statement.Model)
	for modelValue.Kind() == reflect.Ptr {
		modelValue = modelValue.Elem()
	}
	db.Statement.ReflectValue = modelValue
	return result
}
// Unscoped returns a copy of the association with Unscope enabled; the
// receiver itself is left unchanged.
func (association *Association) Unscoped() *Association {
	unscoped := *association
	unscoped.Unscope = true
	return &unscoped
}
// Find loads the associated records matching conds into out. It does nothing
// when the association already carries an error, and returns the association's
// current error either way.
func (association *Association) Find(out any, conds ...any) error {
	if association.Error != nil {
		return association.Error
	}
	query := association.buildCondition()
	association.Error = query.Find(out, conds...).Error
	return association.Error
}
// Append adds values to the association. For HasOne and BelongsTo relations a
// non-empty values list replaces the current association; for every other
// relation type the values are saved without clearing existing ones.
func (association *Association) Append(values ...any) error {
	if association.Error != nil {
		return association.Error
	}

	relType := association.Relationship.Type
	if relType == schema.HasOne || relType == schema.BelongsTo {
		// Single-valued relations: appending is the same as replacing.
		if len(values) > 0 {
			association.Error = association.Replace(values...)
		}
	} else {
		association.saveAssociation( /*clear*/ false, values...)
	}
	return association.Error
}
// Replace makes values the new content of the association.
//
// It first saves values through saveAssociation with clear=true, then detaches
// what was associated before, depending on the relationship type:
//   - BelongsTo: with no values, the relation field is zeroed and the foreign
//     key columns are updated to nil; when Unscope is set, the previously
//     referenced records are deleted.
//   - HasOne/HasMany: related rows that are not among the new values get their
//     foreign key columns set to nil, or are deleted when Unscope is set.
//   - Many2Many: join-table rows of the source that do not point at the new
//     values are deleted; ErrPrimaryKeyRequired is returned when no primary
//     key values of the source are available.
//
// It does nothing when the association already carries an error and returns
// the association's current error.
func (association *Association) Replace(values ...any) error {
	if association.Error == nil {
		reflectValue := association.DB.Statement.ReflectValue
		rel := association.Relationship
		var oldBelongsToExpr clause.Expression
		// we have to record the old BelongsTo value
		// (it is read before saveAssociation below changes the source value)
		if association.Unscope && rel.Type == schema.BelongsTo {
			var foreignFields []*schema.Field
			for _, ref := range rel.References {
				if !ref.OwnPrimaryKey {
					foreignFields = append(foreignFields, ref.ForeignKey)
				}
			}
			if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
				oldBelongsToExpr = clause.IN{Column: column, Values: values}
			}
		}
		// save associations
		if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
			return association.Error
		}
		// set old associations's foreign key to null
		switch rel.Type {
		case schema.BelongsTo:
			if len(values) == 0 {
				updateMap := map[string]any{}
				// NOTE(review): each assignment to association.Error below overwrites the
				// previous one, so an earlier Set failure can be lost - confirm intended.
				switch reflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < reflectValue.Len(); i++ {
						association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
					}
				case reflect.Struct:
					association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
				}
				for _, ref := range rel.References {
					updateMap[ref.ForeignKey.DBName] = nil
				}
				association.Error = association.DB.UpdateColumns(updateMap).Error
			}
			// In unscoped mode, delete the records that were referenced before the replace.
			if association.Unscope && oldBelongsToExpr != nil {
				association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
			}
		case schema.HasOne, schema.HasMany:
			var (
				primaryFields []*schema.Field
				foreignKeys   []string
				updateMap     = map[string]any{}
				relValues     = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface()
				tx            = association.DB.Model(modelValue)
			)
			// Exclude the rows that are currently held in the relation field from the cleanup.
			if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
				if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
					tx.Not(clause.IN{Column: column, Values: values})
				}
			}
			for _, ref := range rel.References {
				if ref.OwnPrimaryKey {
					primaryFields = append(primaryFields, ref.PrimaryKey)
					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
					updateMap[ref.ForeignKey.DBName] = nil
				} else if ref.PrimaryValue != "" {
					tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			// Only touch rows whose foreign keys point at the source's primary key values.
			if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
				if association.Unscope {
					association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
				} else {
					association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
				}
			}
		case schema.Many2Many:
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				modelValue                          = reflect.New(rel.JoinTable.ModelType).Interface()
				tx                                  = association.DB.Model(modelValue)
			)
			// Split references into those keyed by the source's own primary key and
			// those keyed by the related model; fixed-value references become conditions.
			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
				tx.Where(clause.IN{Column: column, Values: values})
			} else {
				return ErrPrimaryKeyRequired
			}
			// Keep the join rows that point at the new values; delete the rest.
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
			}
			association.Error = tx.Delete(modelValue).Error
		}
	}
	return association.Error
}
// Delete removes the relationship between the source model(s) and values.
//
// Depending on the relationship type it either sets the foreign key columns to
// nil (BelongsTo, HasOne, HasMany), deletes the related rows when Unscope is
// set, or deletes the matching join-table rows (Many2Many). When the database
// step succeeds, the in-memory relation fields of the source are cleaned up so
// they no longer contain the removed values.
//
// It returns ErrPrimaryKeyRequired when the source has no usable primary key
// values, does nothing when the association already carries an error, and
// returns the association's current error.
func (association *Association) Delete(values ...any) error {
	if association.Error == nil {
		var (
			reflectValue  = association.DB.Statement.ReflectValue
			rel           = association.Relationship
			primaryFields []*schema.Field
			foreignKeys   []string
			updateAttrs   = map[string]any{}
			conds         []clause.Expression
		)
		// References without a fixed PrimaryValue contribute key fields and a nil
		// assignment for their foreign key column; the others become equality conditions.
		for _, ref := range rel.References {
			if ref.PrimaryValue == "" {
				primaryFields = append(primaryFields, ref.PrimaryKey)
				foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
				updateAttrs[ref.ForeignKey.DBName] = nil
			} else {
				conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			}
		}
		switch rel.Type {
		case schema.BelongsTo:
			associationDB := association.DB.Session(&Session{})
			tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
			// Restrict the update to the source rows identified by their primary keys.
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				return ErrPrimaryKeyRequired
			}
			// ...and to those whose foreign keys reference one of the given values.
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
			if association.Unscope {
				var foreignFields []*schema.Field
				for _, ref := range rel.References {
					if !ref.OwnPrimaryKey {
						foreignFields = append(foreignFields, ref.ForeignKey)
					}
				}
				// NOTE(review): this overwrites the UpdateColumns error assigned above - confirm intended.
				if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
					column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
					association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
				}
			}
		case schema.HasOne, schema.HasMany:
			model := reflect.New(rel.FieldSchema.ModelType).Interface()
			tx := association.DB.Model(model)
			// Match related rows whose foreign keys point at the source's key values.
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				return ErrPrimaryKeyRequired
			}
			// ...and whose own primary keys are among the given values.
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			if association.Unscope {
				association.Error = tx.Clauses(conds...).Delete(model).Error
			} else {
				association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
			}
		case schema.Many2Many:
			// These declarations shadow the outer primaryFields; the join table needs
			// the references split by which side of the relation they key.
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				joinValue                           = reflect.New(rel.JoinTable.ModelType).Interface()
			)
			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				return ErrPrimaryKeyRequired
			}
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
			association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
		}
		if association.Error == nil {
			// clean up deleted values's foreign key
			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
			// cleanUpDeletedRelations removes the deleted values from the relation
			// field of one source struct (data).
			cleanUpDeletedRelations := func(data reflect.Value) {
				if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
					primaryValues := make([]any, len(rel.FieldSchema.PrimaryFields))
					switch fieldValue.Kind() {
					case reflect.Slice, reflect.Array:
						// Rebuild the slice, keeping only elements whose primary key is not among the deleted values.
						validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
						for i := 0; i < fieldValue.Len(); i++ {
							for idx, field := range rel.FieldSchema.PrimaryFields {
								primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
							}
							if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
								validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
							}
						}
						association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
					case reflect.Struct:
						for idx, field := range rel.FieldSchema.PrimaryFields {
							primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
						}
						// The single related value was deleted: zero the relation field and,
						// without a join table, the foreign key fields as well.
						if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
							if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
								break
							}
							if rel.JoinTable == nil {
								for _, ref := range rel.References {
									if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									} else {
										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									}
								}
							}
						}
					}
				}
			}
			switch reflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < reflectValue.Len(); i++ {
					cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
				}
			case reflect.Struct:
				cleanUpDeletedRelations(reflectValue)
			}
		}
	}
	return association.Error
}
// Clear removes all current associations by calling Replace with no values,
// and returns the error Replace reports.
func (association *Association) Clear() error {
	return association.Replace()
}
// Count returns the number of associated records. When the association already
// carries an error no query is run and 0 is returned; a query failure is
// recorded in association.Error.
func (association *Association) Count() (count int64) {
	if association.Error != nil {
		return 0
	}
	query := association.buildCondition()
	association.Error = query.Count(&count).Error
	return count
}
// assignBack records a value passed to saveAssociation that should be
// overwritten with the saved relation value once saving is done.
type assignBack struct {
	// Source is the model value whose relation field is read back.
	Source reflect.Value
	// Index is the 1-based position of the element inside the relation slice;
	// 0 means the whole relation field value is assigned.
	Index int
	// Dest is the caller-supplied value that receives the saved data.
	Dest reflect.Value
}
// saveAssociation stores values into the relation field of the source model(s)
// and persists the change through Updates. Failures are reported through
// association.Error.
//
// clear selects replace semantics: existing HasMany/Many2Many field values are
// discarded before values are appended, and with no values at all the relation
// field (and, without a join table, the foreign keys held by the source) is
// reset.
//
// When the source is a slice or array, values must hold exactly one entry per
// source element (ErrInvalidValueOfLength otherwise, except for the
// clear-with-no-values case). When the source is a struct, every value is
// added to that single source.
//
// Processing stops as soon as assigning a value to the relation field fails,
// so that the error is not overwritten by the result of a later Updates call.
func (association *Association) saveAssociation(clear bool, values ...any) {
	var (
		reflectValue = association.DB.Statement.ReflectValue
		assignBacks  []assignBack // assign association values back to arguments after save
	)
	// appendToRelations writes rv into the relation field of source.
	appendToRelations := func(source, rv reflect.Value, clear bool) {
		switch association.Relationship.Type {
		case schema.HasOne, schema.BelongsTo:
			switch rv.Kind() {
			case reflect.Slice, reflect.Array:
				// A single-valued relation only takes the first element.
				if rv.Len() > 0 {
					association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
					if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
						assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
					}
				}
			case reflect.Struct:
				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
				if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
					assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
				}
			}
		case schema.HasMany, schema.Many2Many:
			elemType := association.Relationship.Field.IndirectFieldType.Elem()
			oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
			var fieldValue reflect.Value
			if clear {
				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
			} else {
				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
				reflect.Copy(fieldValue, oldFieldValue)
			}
			// appendToFieldValues adds one value (passed as a pointer) to the new slice,
			// as the pointer itself or as the pointed-to value, whichever fits elemType.
			appendToFieldValues := func(ev reflect.Value) {
				if ev.Type().AssignableTo(elemType) {
					fieldValue = reflect.Append(fieldValue, ev)
				} else if ev.Type().Elem().AssignableTo(elemType) {
					fieldValue = reflect.Append(fieldValue, ev.Elem())
				} else {
					association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
				}
				if elemType.Kind() == reflect.Struct {
					assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
				}
			}
			switch rv.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < rv.Len(); i++ {
					appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
				}
			case reflect.Struct:
				appendToFieldValues(rv.Addr())
			}
			if association.Error == nil {
				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
			}
		}
	}
	// Work out which columns the save may touch: the relation itself, nested
	// selections/omissions that concern it, and the foreign keys held by the source.
	selectedSaveColumns := []string{association.Relationship.Name}
	omitColumns := []string{}
	selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
	for name, ok := range selectColumns {
		columnName := ""
		if strings.HasPrefix(name, association.Relationship.Name) {
			if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
				columnName = name
			}
		} else if strings.HasPrefix(name, clause.Associations) {
			columnName = name
		}
		if columnName != "" {
			if ok {
				selectedSaveColumns = append(selectedSaveColumns, columnName)
			} else {
				omitColumns = append(omitColumns, columnName)
			}
		}
	}
	for _, ref := range association.Relationship.References {
		if !ref.OwnPrimaryKey {
			selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
		}
	}
	associationDB := association.DB.Session(&Session{}).Model(nil)
	if !association.DB.FullSaveAssociations {
		associationDB.Select(selectedSaveColumns)
	}
	if len(omitColumns) > 0 {
		associationDB.Omit(omitColumns...)
	}
	associationDB = associationDB.Session(&Session{})
	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		if len(values) != reflectValue.Len() {
			// clear old data
			if clear && len(values) == 0 {
				for i := 0; i < reflectValue.Len(); i++ {
					if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
						association.Error = err
						break
					}
					if association.Relationship.JoinTable == nil {
						for _, ref := range association.Relationship.References {
							if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
								if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
									association.Error = err
									break
								}
							}
						}
					}
				}
				// Leaves the switch: there is nothing to append or update.
				break
			}
			association.Error = ErrInvalidValueOfLength
			return
		}
		for i := 0; i < reflectValue.Len(); i++ {
			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
			// Stop here so the assignment error is not replaced by the Updates result.
			if association.Error != nil {
				return
			}
			// TODO support save slice data, sql with case?
			association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
		}
	case reflect.Struct:
		// clear old data
		if clear && len(values) == 0 {
			association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
			if association.Relationship.JoinTable == nil && association.Error == nil {
				for _, ref := range association.Relationship.References {
					if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
						association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
					}
				}
			}
		}
		for idx, value := range values {
			rv := reflect.Indirect(reflect.ValueOf(value))
			// Only the first value clears the existing field content.
			appendToRelations(reflectValue, rv, clear && idx == 0)
			// Stop here so the assignment error is not replaced by the Updates result.
			if association.Error != nil {
				return
			}
		}
		if len(values) > 0 {
			association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
		}
	}
	// Copy the saved relation values back into the caller-supplied arguments.
	for _, assignBack := range assignBacks {
		fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
		if assignBack.Index > 0 {
			reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
		} else {
			reflect.Indirect(assignBack.Dest).Set(fieldValue)
		}
	}
}
// buildCondition returns a *DB scoped to the related model and restricted to
// the rows associated with the current source value. Relations without a join
// table are filtered with a WHERE clause; relations with a join table are
// joined to it, and the join table's own query clauses are applied unless the
// statement is unscoped.
func (association *Association) buildCondition() *DB {
	rel := association.Relationship
	queryConds := rel.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
	modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
	tx := association.DB.Model(modelValue)

	joinTable := rel.JoinTable
	if joinTable == nil {
		tx.Clauses(clause.Where{Exprs: queryConds})
		return tx
	}

	if !tx.Statement.Unscoped && len(joinTable.QueryClauses) > 0 {
		// Render the join table's query clauses as a WHERE fragment and add it
		// to the query without the leading keyword.
		joinStmt := Statement{
			DB:      tx,
			Context: tx.Statement.Context,
			Schema:  joinTable,
			Table:   joinTable.Table,
			Clauses: map[string]clause.Clause{},
		}
		for _, queryClause := range joinTable.QueryClauses {
			joinStmt.AddClause(queryClause)
		}
		joinStmt.Build("WHERE")
		if whereSQL := joinStmt.SQL.String(); len(whereSQL) > 0 {
			tx.Clauses(clause.Expr{SQL: strings.Replace(whereSQL, "WHERE ", "", 1), Vars: joinStmt.Vars})
		}
	}

	join := clause.Join{
		Table: clause.Table{Name: joinTable.Table},
		ON:    clause.Where{Exprs: queryConds},
	}
	return tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{join}})
}
================================================
FILE: gossh/gorm/callbacks/associations.go
================================================
package callbacks
import (
"reflect"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// SaveBeforeAssociations returns a callback that saves the BelongsTo
// associations of the statement's value and then copies the referenced primary
// key values into the owner's foreign key fields.
//
// create is forwarded to Statement.SelectAndOmitColumns (as create, !create)
// to decide which relations are selected for saving.
// NOTE(review): presumably registered to run before the owner row itself is
// created/updated - confirm against the callback registration.
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		if db.Error == nil && db.Statement.Schema != nil {
			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
			// Save Belongs To associations
			for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
				// Skip relations that are explicitly omitted, or not selected while selection is restricted.
				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
					continue
				}
				// setupReferences copies the primary key values of elem into the foreign
				// key fields of obj, and mirrors them into Dest when Dest is a map.
				setupReferences := func(obj reflect.Value, elem reflect.Value) {
					for _, ref := range rel.References {
						if !ref.OwnPrimaryKey {
							pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
							db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
							if dest, ok := db.Statement.Dest.(map[string]any); ok {
								dest[ref.ForeignKey.DBName] = pv
								if _, ok := dest[rel.Name]; ok {
									dest[rel.Name] = elem.Interface()
								}
							}
						}
					}
				}
				switch db.Statement.ReflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					var (
						rValLen   = db.Statement.ReflectValue.Len()
						objs      = make([]reflect.Value, 0, rValLen)
						fieldType = rel.Field.FieldType
						isPtr     = fieldType.Kind() == reflect.Ptr
					)
					// Related values are always collected as pointers.
					if !isPtr {
						fieldType = reflect.PtrTo(fieldType)
					}
					// elems pairs index-for-index with objs; distinctElems holds each related
					// record once (by primary key) and is what actually gets saved.
					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
					distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
					identityMap := map[string]bool{}
					for i := 0; i < rValLen; i++ {
						obj := db.Statement.ReflectValue.Index(i)
						// Stop at the first element that is not a struct.
						if reflect.Indirect(obj).Kind() != reflect.Struct {
							break
						}
						if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
							rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
							if !isPtr {
								rv = rv.Addr()
							}
							objs = append(objs, obj)
							elems = reflect.Append(elems, rv)
							relPrimaryValues := make([]any, 0, len(rel.FieldSchema.PrimaryFields))
							for _, pf := range rel.FieldSchema.PrimaryFields {
								// NOTE(review): the second result is named ok here but is read as a
								// "zero" flag elsewhere in this file, so !ok presumably means the
								// primary key value is non-zero - confirm.
								if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
									relPrimaryValues = append(relPrimaryValues, pfv)
								}
							}
							cacheKey := utils.ToStringKey(relPrimaryValues...)
							// Keep the element when some primary key value was skipped above, or
							// when this key has not been seen before.
							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
								if cacheKey != "" { // has primary fields
									identityMap[cacheKey] = true
								}
								distinctElems = reflect.Append(distinctElems, rv)
							}
						}
					}
					if elems.Len() > 0 {
						// Foreign keys are only set up when saving the related records succeeded.
						if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
							for i := 0; i < elems.Len(); i++ {
								setupReferences(objs[i], elems.Index(i))
							}
						}
					}
				case reflect.Struct:
					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
						rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
						if rv.Kind() != reflect.Ptr {
							rv = rv.Addr()
						}
						if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
							setupReferences(db.Statement.ReflectValue, rv)
						}
					}
				}
			}
		}
	}
}
// SaveAfterAssociations returns a callback that saves the HasOne, HasMany and
// Many2Many associations of the statement's value. Before saving, the owner's
// primary key values (or the fixed PrimaryValue of a reference) are copied
// into the related records' foreign key fields; for Many2Many the join-table
// rows are created as well.
//
// create is forwarded to Statement.SelectAndOmitColumns (as create, !create)
// to decide which relations are selected for saving.
// NOTE(review): presumably registered to run after the owner row itself has
// been created/updated - confirm against the callback registration.
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		if db.Error == nil && db.Statement.Schema != nil {
			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
			// Save Has One associations
			for _, rel := range db.Statement.Schema.Relationships.HasOne {
				// Skip relations that are explicitly omitted, or not selected while selection is restricted.
				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
					continue
				}
				switch db.Statement.ReflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					var (
						fieldType = rel.Field.FieldType
						isPtr     = fieldType.Kind() == reflect.Ptr
					)
					// Related values are always collected as pointers.
					if !isPtr {
						fieldType = reflect.PtrTo(fieldType)
					}
					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
						obj := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(obj).Kind() == reflect.Struct {
							if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
								rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
								if rv.Kind() != reflect.Ptr {
									rv = rv.Addr()
								}
								// Point the related record's foreign keys at its owner.
								for _, ref := range rel.References {
									if ref.OwnPrimaryKey {
										fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
										db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
									} else if ref.PrimaryValue != "" {
										db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
									}
								}
								elems = reflect.Append(elems, rv)
							}
						}
					}
					if elems.Len() > 0 {
						// The foreign key columns are passed to saveAssociations as the default updating columns.
						assignmentColumns := make([]string, 0, len(rel.References))
						for _, ref := range rel.References {
							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
						}
						saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
					}
				case reflect.Struct:
					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
						f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
						if f.Kind() != reflect.Ptr {
							f = f.Addr()
						}
						assignmentColumns := make([]string, 0, len(rel.References))
						for _, ref := range rel.References {
							if ref.OwnPrimaryKey {
								fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
								db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
							} else if ref.PrimaryValue != "" {
								db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
							}
							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
						}
						saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
					}
				}
			}
			// Save Has Many associations
			for _, rel := range db.Statement.Schema.Relationships.HasMany {
				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
					continue
				}
				fieldType := rel.Field.IndirectFieldType.Elem()
				isPtr := fieldType.Kind() == reflect.Ptr
				if !isPtr {
					fieldType = reflect.PtrTo(fieldType)
				}
				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
				identityMap := map[string]bool{}
				// appendToElems sets the foreign keys on every related record of owner v
				// and collects the records to save, once per primary key.
				appendToElems := func(v reflect.Value) {
					if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
						f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
						for i := 0; i < f.Len(); i++ {
							elem := f.Index(i)
							for _, ref := range rel.References {
								if ref.OwnPrimaryKey {
									pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
									db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
								} else if ref.PrimaryValue != "" {
									db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
								}
							}
							relPrimaryValues := make([]any, 0, len(rel.FieldSchema.PrimaryFields))
							for _, pf := range rel.FieldSchema.PrimaryFields {
								// NOTE(review): the second result is named ok here but is read as a
								// "zero" flag elsewhere in this file, so !ok presumably means the
								// primary key value is non-zero - confirm.
								if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
									relPrimaryValues = append(relPrimaryValues, pfv)
								}
							}
							cacheKey := utils.ToStringKey(relPrimaryValues...)
							// Keep the element when some primary key value was skipped above, or
							// when this key has not been seen before.
							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
								if cacheKey != "" { // has primary fields
									identityMap[cacheKey] = true
								}
								if isPtr {
									elems = reflect.Append(elems, elem)
								} else {
									elems = reflect.Append(elems, elem.Addr())
								}
							}
						}
					}
				}
				switch db.Statement.ReflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
						obj := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(obj).Kind() == reflect.Struct {
							appendToElems(obj)
						}
					}
				case reflect.Struct:
					appendToElems(db.Statement.ReflectValue)
				}
				if elems.Len() > 0 {
					assignmentColumns := make([]string, 0, len(rel.References))
					for _, ref := range rel.References {
						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
					}
					saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
				}
			}
			// Save Many2Many associations
			for _, rel := range db.Statement.Schema.Relationships.Many2Many {
				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
					continue
				}
				fieldType := rel.Field.IndirectFieldType.Elem()
				isPtr := fieldType.Kind() == reflect.Ptr
				if !isPtr {
					fieldType = reflect.PtrTo(fieldType)
				}
				// elems pairs index-for-index with objs (one entry per owner/related pair);
				// distinctElems holds each related record once and is what gets saved.
				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
				distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
				joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
				objs := []reflect.Value{}
				// appendToJoins builds the join-table row linking owner obj with related elem.
				appendToJoins := func(obj reflect.Value, elem reflect.Value) {
					joinValue := reflect.New(rel.JoinTable.ModelType)
					for _, ref := range rel.References {
						if ref.OwnPrimaryKey {
							fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
						} else if ref.PrimaryValue != "" {
							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
						} else {
							fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
						}
					}
					joins = reflect.Append(joins, joinValue)
				}
				identityMap := map[string]bool{}
				// appendToElems records every related record of owner v together with its owner.
				appendToElems := func(v reflect.Value) {
					if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
						f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
						for i := 0; i < f.Len(); i++ {
							elem := f.Index(i)
							if !isPtr {
								elem = elem.Addr()
							}
							objs = append(objs, v)
							elems = reflect.Append(elems, elem)
							relPrimaryValues := make([]any, 0, len(rel.FieldSchema.PrimaryFields))
							for _, pf := range rel.FieldSchema.PrimaryFields {
								// NOTE(review): as above, !ok presumably means the primary key
								// value is non-zero - confirm.
								if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
									relPrimaryValues = append(relPrimaryValues, pfv)
								}
							}
							cacheKey := utils.ToStringKey(relPrimaryValues...)
							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
								if cacheKey != "" { // has primary fields
									identityMap[cacheKey] = true
								}
								distinctElems = reflect.Append(distinctElems, elem)
							}
						}
					}
				}
				switch db.Statement.ReflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
						obj := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(obj).Kind() == reflect.Struct {
							appendToElems(obj)
						}
					}
				case reflect.Struct:
					appendToElems(db.Statement.ReflectValue)
				}
				// optimize elems of reflect value length
				if elemLen := elems.Len(); elemLen > 0 {
					// The related records are saved unless "<relation>.*" is explicitly omitted.
					if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
						saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
					}
					// Join rows are built after the save call above, from the related records' key values.
					for i := 0; i < elemLen; i++ {
						appendToJoins(objs[i], elems.Index(i))
					}
				}
				if joins.Len() > 0 {
					// Create the join rows in a fresh session with an ON CONFLICT DO NOTHING clause.
					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
						SkipHooks:                db.Statement.SkipHooks,
						DisableNestedTransaction: true,
					}).Create(joins.Interface()).Error)
				}
			}
		}
	}
}
// onConflictOption builds the ON CONFLICT clause used when saving the
// associated records described by schema s.
//
// When there are no default updating columns and full association saving is
// off, conflicting rows are left untouched (DoNothing). Otherwise the
// conflict target is the schema's primary-key columns, and the rows are either
// fully updated (FullSaveAssociations) or updated only on
// defaultUpdatingColumns.
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
	fullSave := stmt.DB.FullSaveAssociations
	if len(defaultUpdatingColumns) == 0 && !fullSave {
		return clause.OnConflict{DoNothing: true}
	}

	conflictColumns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
	for i := range s.PrimaryFieldDBNames {
		conflictColumns = append(conflictColumns, clause.Column{Name: s.PrimaryFieldDBNames[i]})
	}

	onConflict = clause.OnConflict{Columns: conflictColumns, UpdateAll: fullSave}
	if !fullSave {
		onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
	}
	return onConflict
}
// saveAssociations creates (or upserts) the associated records held in
// rValues for relationship rel and records any resulting error on db.
//
// selectColumns entries of the form "<rel.Name>.<column>" are translated into
// Select/Omit lists for the nested create. When restricted is true and no
// nested column is selected or omitted, nested associations are omitted
// altogether. It returns nil without doing anything when every value was
// already saved during this operation.
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
	// Break association save loops: values seen before are not saved again.
	if checkAssociationsSaved(db, rValues) {
		return nil
	}

	onConflict := onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
	prefix := rel.Name + "."
	records := rValues.Interface()

	// Split the "<relation>.<column>" entries into included and excluded columns.
	var included, excluded []string
	for key, enabled := range selectColumns {
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		nested := key[len(prefix):]
		switch {
		case nested == "":
			// Bare "<relation>." carries no column name; ignore it.
		case enabled:
			included = append(included, nested)
		default:
			excluded = append(excluded, nested)
		}
	}

	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
		FullSaveAssociations:     db.FullSaveAssociations,
		SkipHooks:                db.Statement.SkipHooks,
		DisableNestedTransaction: true,
	})

	// Carry the parent's settings over to the nested statement.
	db.Statement.Settings.Range(func(key, value any) bool {
		tx.Statement.Settings.Store(key, value)
		return true
	})

	if tx.Statement.FullSaveAssociations {
		tx = tx.Set("gorm:update_track_time", true)
	}

	switch {
	case len(included) > 0:
		tx = tx.Select(included)
	case restricted && len(excluded) == 0:
		tx = tx.Omit(clause.Associations)
	}
	if len(excluded) > 0 {
		tx = tx.Omit(excluded...)
	}

	return db.AddError(tx.Create(records).Error)
}
// visitMapStoreKey is the settings key under which the map of already-saved
// association values is stored on the DB.
var visitMapStoreKey = "gorm:saved_association_map"

// checkAssociationsSaved reports whether values has already been saved.
//
// For a struct it checks that struct; for a slice or array it reports true
// only when every item has been saved. Values not seen before are recorded.
// On the first call no map exists yet: one is created, seeded with values and
// stored on db, and the result is false.
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
	stored, found := db.Get(visitMapStoreKey)
	if !found {
		seen := make(visitMap)
		loadOrStoreVisitMap(&seen, values)
		db.Set(visitMapStoreKey, &seen)
		return false
	}

	// A stored value of an unexpected type is treated as "not saved".
	seen, isVisitMap := stored.(*visitMap)
	return isVisitMap && loadOrStoreVisitMap(seen, values)
}
================================================
FILE: gossh/gorm/callbacks/callbacks.go
================================================
package callbacks
import (
"gossh/gorm"
)
// Default clause name lists for each statement kind. RegisterDefaultCallbacks
// uses them whenever the corresponding list in Config is empty.
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
// Config customises the callbacks installed by RegisterDefaultCallbacks.
type Config struct {
// LastInsertIDReversed changes how Create back-fills primary keys for batch
// inserts: slices are walked from the last element to the first while the ID
// is decremented, instead of first to last while it is incremented.
// NOTE(review): presumably for drivers whose LastInsertId reports the last
// row of a batch rather than the first — confirm against the dialector.
LastInsertIDReversed bool
// CreateClauses is assigned to the create processor's Clauses. Listing
// "RETURNING" here enables RETURNING handling in Create. Defaults to the
// package list when empty.
CreateClauses []string
// QueryClauses is assigned to the query, row and raw processors' Clauses.
// Defaults to the package list when empty.
QueryClauses []string
// UpdateClauses is assigned to the update processor's Clauses. Defaults to
// the package list when empty.
UpdateClauses []string
// DeleteClauses is assigned to the delete processor's Clauses. Listing
// "RETURNING" here enables RETURNING handling in Delete. Defaults to the
// package list when empty.
DeleteClauses []string
}
// RegisterDefaultCallbacks installs the built-in create, query, delete,
// update, row and raw callbacks on db.
//
// Every clause list left empty in config is replaced in place with the package
// default before registration, so config must be non-nil. The begin and
// commit/rollback transaction callbacks are registered behind a Match on
// !SkipDefaultTransaction.
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
	withTransaction := func(tx *gorm.DB) bool {
		return !tx.SkipDefaultTransaction
	}

	// Fill in the package defaults for each clause list that was left empty.
	for _, fallback := range []struct {
		target   *[]string
		defaults []string
	}{
		{&config.CreateClauses, createClauses},
		{&config.QueryClauses, queryClauses},
		{&config.DeleteClauses, deleteClauses},
		{&config.UpdateClauses, updateClauses},
	} {
		if len(*fallback.target) == 0 {
			*fallback.target = fallback.defaults
		}
	}

	// Create pipeline.
	creates := db.Callback().Create()
	creates.Match(withTransaction).Register("gorm:begin_transaction", BeginTransaction)
	creates.Register("gorm:before_create", BeforeCreate)
	creates.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
	creates.Register("gorm:create", Create(config))
	creates.Register("gorm:save_after_associations", SaveAfterAssociations(true))
	creates.Register("gorm:after_create", AfterCreate)
	creates.Match(withTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	creates.Clauses = config.CreateClauses

	// Query pipeline.
	queries := db.Callback().Query()
	queries.Register("gorm:query", Query)
	queries.Register("gorm:preload", Preload)
	queries.Register("gorm:after_query", AfterQuery)
	queries.Clauses = config.QueryClauses

	// Delete pipeline.
	deletes := db.Callback().Delete()
	deletes.Match(withTransaction).Register("gorm:begin_transaction", BeginTransaction)
	deletes.Register("gorm:before_delete", BeforeDelete)
	deletes.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
	deletes.Register("gorm:delete", Delete(config))
	deletes.Register("gorm:after_delete", AfterDelete)
	deletes.Match(withTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	deletes.Clauses = config.DeleteClauses

	// Update pipeline.
	updates := db.Callback().Update()
	updates.Match(withTransaction).Register("gorm:begin_transaction", BeginTransaction)
	updates.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
	updates.Register("gorm:before_update", BeforeUpdate)
	updates.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
	updates.Register("gorm:update", Update(config))
	updates.Register("gorm:save_after_associations", SaveAfterAssociations(false))
	updates.Register("gorm:after_update", AfterUpdate)
	updates.Match(withTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	updates.Clauses = config.UpdateClauses

	// Row and raw statements reuse the query clause list.
	rowQueries := db.Callback().Row()
	rowQueries.Register("gorm:row", RowQuery)
	rowQueries.Clauses = config.QueryClauses

	rawQueries := db.Callback().Raw()
	rawQueries.Register("gorm:raw", RawExec)
	rawQueries.Clauses = config.QueryClauses
}
================================================
FILE: gossh/gorm/callbacks/callmethod.go
================================================
package callbacks
import (
"reflect"
"gossh/gorm"
)
// callMethod invokes the hook adapter fc for the statement's destination.
//
// fc is first tried on the whole reflect value. If it reports that nothing was
// called, it is applied to the address of each element (slice/array) or of the
// struct itself. A value that is not addressable records
// gorm.ErrInvalidValue on db and stops the iteration. For slices,
// db.Statement.CurDestIndex tracks the index of the element being processed.
func callMethod(db *gorm.DB, fc func(value any, tx *gorm.DB) bool) {
	tx := db.Session(&gorm.Session{NewDB: true})
	if fc(db.Statement.ReflectValue.Interface(), tx) {
		return
	}

	switch db.Statement.ReflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		db.Statement.CurDestIndex = 0
		for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
			elem := reflect.Indirect(db.Statement.ReflectValue.Index(i))
			if !elem.CanAddr() {
				db.AddError(gorm.ErrInvalidValue)
				return
			}
			fc(elem.Addr().Interface(), tx)
			db.Statement.CurDestIndex++
		}
	case reflect.Struct:
		if !db.Statement.ReflectValue.CanAddr() {
			db.AddError(gorm.ErrInvalidValue)
			return
		}
		fc(db.Statement.ReflectValue.Addr().Interface(), tx)
	}
}
================================================
FILE: gossh/gorm/callbacks/create.go
================================================
package callbacks
import (
"fmt"
"reflect"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// BeforeCreate runs the model's BeforeSave and BeforeCreate hooks, in that
// order, ahead of an insert. It does nothing when db already has an error,
// when no schema is parsed, when hooks are skipped, or when the schema
// declares neither hook.
func BeforeCreate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if !db.Statement.Schema.BeforeSave && !db.Statement.Schema.BeforeCreate {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		invoked := false
		if db.Statement.Schema.BeforeSave {
			if hook, ok := value.(BeforeSaveInterface); ok {
				invoked = true
				db.AddError(hook.BeforeSave(tx))
			}
		}
		if db.Statement.Schema.BeforeCreate {
			if hook, ok := value.(BeforeCreateInterface); ok {
				invoked = true
				db.AddError(hook.BeforeCreate(tx))
			}
		}
		return invoked
	})
}
// Create returns the callback that builds and executes the INSERT statement
// for db.Statement.
//
// The callback is a no-op when db already carries an error. It builds the SQL
// only if none has been built yet, stops after building when db.DryRun is set,
// and then either scans a RETURNING result set into the destination or
// executes the statement and back-fills the primary key from LastInsertId.
func Create(config *Config) func(db *gorm.DB) {
// RETURNING is only used when the configured create clauses list it.
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
// Schema-level create clauses are skipped for unscoped statements.
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
// Unless the caller already supplied a RETURNING clause, ask for every
// column whose value is produced by a database default.
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
// Build the SQL only when nothing has been written to the statement yet.
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
// NOTE: despite its name, isDryRun is true when the statement should really
// be executed (not a dry run and no error so far).
isDryRun := !db.DryRun && db.Error == nil
if !isDryRun {
return
}
ok, mode := hasReturning(db, supportReturning)
if ok {
// RETURNING path: run as a query and scan the returned rows.
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
// Close the rows when the callback returns and record any close error.
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
}
return
}
// No RETURNING: execute the statement and derive IDs from the result.
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
// The error from RowsAffected is deliberately ignored.
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
// Key used for map destinations when no schema is available.
pkFieldName = "@id"
)
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
// The LastInsertId error is only recorded when RETURNING is unsupported.
if !supportReturning {
db.AddError(err)
}
return
}
if db.Statement.Schema != nil {
// Only a prioritized primary field that has a default value is back-filled.
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]any:
values[pkFieldName] = insertID
case *map[string]any:
(*values)[pkFieldName] = insertID
case []map[string]any, *[]map[string]any:
// Normalise both the slice and pointer-to-slice forms to a plain slice.
mapValues, ok := values.([]map[string]any)
if !ok {
if v, ok := values.(*[]map[string]any); ok {
if *v != nil {
mapValues = *v
}
}
}
// With a reversed last-insert ID, step back to the ID of the first row.
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
// The ID advances for every row, including nil maps.
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
// Struct destinations need a schema-resolved primary key field.
if pkField == nil {
return
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
// Walk from the last element backwards, decrementing the ID.
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
// Only elements with a zero primary key receive (and consume) an ID.
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
// Walk from the first element forwards, incrementing the ID.
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
// Only elements with a zero primary key receive (and consume) an ID.
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
// A single struct gets the ID only if its primary key is still zero.
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
}
}
// AfterCreate runs the model's AfterCreate and AfterSave hooks, in that order,
// following an insert. It does nothing when db already has an error, when no
// schema is parsed, when hooks are skipped, or when the schema declares
// neither hook.
func AfterCreate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if !db.Statement.Schema.AfterSave && !db.Statement.Schema.AfterCreate {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		invoked := false
		if db.Statement.Schema.AfterCreate {
			if hook, ok := value.(AfterCreateInterface); ok {
				invoked = true
				db.AddError(hook.AfterCreate(tx))
			}
		}
		if db.Statement.Schema.AfterSave {
			if hook, ok := value.(AfterSaveInterface); ok {
				invoked = true
				db.AddError(hook.AfterSave(tx))
			}
		}
		return invoked
	})
}
// ConvertToCreateValues converts the statement's destination into the
// column/value matrix of an INSERT.
//
// Map destinations (and slices of maps) are delegated to the map helpers.
// Struct and slice-of-struct destinations are read through the schema; while
// doing so, zero fields that have an interface default or an auto
// create/update time are written back into the destination. If the statement
// carries an ON CONFLICT clause with UpdateAll set, the clause is expanded into
// explicit assignments and re-added to the statement. Errors are recorded on
// stmt rather than returned.
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
// One timestamp is shared by every auto time field of this statement.
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) {
case map[string]any:
values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]any:
values = ConvertMapToValuesForCreate(stmt, *value)
case []map[string]any:
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]any:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
// Only the presence of the setting is checked; its stored value is ignored.
_, updateTrackTime = stmt.Get("gorm:update_track_time")
isZero bool
)
// The setting is consumed here so it is removed from the statement.
stmt.Settings.Delete("gorm:update_track_time")
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
// Base columns: fields without a default value, or with a default known as
// a Go value, that are selected. Auto create/update time fields are kept
// even under a restricted selection unless explicitly listed otherwise.
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
values.Columns = append(values.Columns, clause.Column{Name: db})
}
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len()
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]any, 0, rValLen*len(values.Columns))
values.Values = make([][]any, rValLen)
// Per field with a database default: the explicit values supplied by rows
// (nil where a row left the field zero). Allocated on first use.
defaultValueFieldsHavingValue := map[*schema.Field][]any{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]any, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
// Zero value: substitute the interface default or the current time,
// and write it back into the row.
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
// Non-zero update time is refreshed only when track-time is requested.
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
// Remember explicit values given for fields that have a database default.
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]any, rValLen)
}
defaultValueFieldsHavingValue[field][i] = rvOfvalue
}
}
}
}
// A database-default column is added once any row supplied a value; rows
// that did not supply one get the dialector's default placeholder.
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
case reflect.Struct:
values.Values = [][]any{make([]any, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
// Zero value: substitute the interface default or the current time,
// and write it back into the struct.
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
// Non-zero update time is refreshed only when track-time is requested.
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
// Database-default columns are included only when explicitly set.
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
// Expand "ON CONFLICT ... UpdateAll" into explicit per-column assignments.
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
// Never update primary keys or auto create-time fields; fields with a
// database default are updated only if that default is NULL or is
// known as a Go value.
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
// Auto update-time fields are assigned the current time, converted
// to the field's configured unix precision where applicable.
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
}
}
}
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
// With nothing left to update, fall back to DO NOTHING.
if len(onConflict.DoUpdates) == 0 {
onConflict.DoNothing = true
}
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {
for _, field := range stmt.Schema.PrimaryFields {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
}
}
stmt.AddClause(onConflict)
}
}
}
return values
}
================================================
FILE: gossh/gorm/callbacks/delete.go
================================================
package callbacks
import (
"reflect"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// BeforeDelete runs the model's BeforeDelete hook ahead of a delete. It does
// nothing when db already has an error, when no schema is parsed, when hooks
// are skipped, or when the schema does not declare the hook.
func BeforeDelete(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks || !db.Statement.Schema.BeforeDelete {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		hook, ok := value.(BeforeDeleteInterface)
		if !ok {
			return false
		}
		db.AddError(hook.BeforeDelete(tx))
		return true
	})
}
// DeleteBeforeAssociations deletes the associated records of every relation
// named in the statement's selected columns. RegisterDefaultCallbacks
// registers it ahead of the "gorm:delete" callback.
//
// It only acts when the statement has a restricted column selection. HasOne
// and HasMany relations delete rows of the related model; Many2Many relations
// delete rows of the join table. The first error is recorded on db and stops
// processing.
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
// Without an explicit selection no association is deleted.
if !restricted {
return
}
for column, v := range selectColumns {
// Skip omitted columns and names that are not relations of the schema.
if !v {
continue
}
rel, ok := db.Statement.Schema.Relationships.Relations[column]
if !ok {
continue
}
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
// Forward nested selections ("<relation>.<name>") and the associations
// marker so the nested delete can cascade further.
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
// An IN condition with no values means there is nothing to match, so the
// delete is skipped entirely.
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
// Owner-side references become the IN condition below; references with a
// fixed primary value become equality conditions on the join table.
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
// Match join rows by the owner key values taken from the records being deleted.
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
// Delete returns the callback that builds and executes the DELETE statement
// for db.Statement.
//
// The callback is a no-op when db already carries an error. When no SQL has
// been built yet it adds the schema's delete clauses, restricts the statement
// to the primary-key values found on the destination (and on the model, when
// that is a different value), and builds the SQL. The statement is then
// checked for a missing WHERE condition and, unless this is a dry run or an
// error occurred, executed. With a RETURNING clause the result rows are
// scanned into the destination; otherwise RowsAffected is recorded.
func Delete(config *Config) func(db *gorm.DB) {
	// RETURNING is only used when the configured delete clauses list it.
	supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")

	return func(db *gorm.DB) {
		if db.Error != nil {
			return
		}

		if db.Statement.Schema != nil {
			for _, c := range db.Statement.Schema.DeleteClauses {
				db.Statement.AddClause(c)
			}
		}

		if db.Statement.SQL.Len() == 0 {
			db.Statement.SQL.Grow(100)
			db.Statement.AddClauseIfNotExists(clause.Delete{})

			if db.Statement.Schema != nil {
				// Restrict the delete to the primary keys present on the destination.
				_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
				column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
				if len(values) > 0 {
					db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
				}

				// When a separate model was supplied, its primary keys apply as well.
				if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
					_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
					column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
					if len(values) > 0 {
						db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
					}
				}
			}

			db.Statement.AddClauseIfNotExists(clause.From{})
			db.Statement.Build(db.Statement.BuildClauses...)
		}

		checkMissingWhereConditions(db)

		if db.DryRun || db.Error != nil {
			return
		}

		ok, mode := hasReturning(db, supportReturning)
		if !ok {
			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
			if db.AddError(err) == nil {
				// The error from RowsAffected is deliberately ignored.
				db.RowsAffected, _ = result.RowsAffected()
			}
			return
		}

		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
		if db.AddError(err) != nil {
			return
		}
		// Deferred, as in Create, so the rows are released even if scanning panics.
		defer func() {
			db.AddError(rows.Close())
		}()
		gorm.Scan(rows, db, mode)
	}
}
// AfterDelete runs the model's AfterDelete hook following a delete. It does
// nothing when db already has an error, when no schema is parsed, when hooks
// are skipped, or when the schema does not declare the hook.
func AfterDelete(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks || !db.Statement.Schema.AfterDelete {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		hook, ok := value.(AfterDeleteInterface)
		if !ok {
			return false
		}
		db.AddError(hook.AfterDelete(tx))
		return true
	})
}
================================================
FILE: gossh/gorm/callbacks/helper.go
================================================
package callbacks
import (
"reflect"
"sort"
"gossh/gorm"
"gossh/gorm/clause"
)
// ConvertMapToValuesForCreate turns a single map into a one-row INSERT value
// set. Keys are processed in sorted order, translated to the schema's DB column
// name when the schema knows the field, and filtered by the statement's
// select/omit columns.
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]any) (values clause.Values) {
	values.Columns = make([]clause.Column, 0, len(mapValue))
	selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)

	// Sort the keys so the generated column order is deterministic.
	sortedKeys := make([]string, 0, len(mapValue))
	for key := range mapValue {
		sortedKeys = append(sortedKeys, key)
	}
	sort.Strings(sortedKeys)

	for _, key := range sortedKeys {
		name := key
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(key); field != nil {
				name = field.DBName
			}
		}

		// Explicitly listed columns follow their flag; unlisted ones are kept
		// only when the selection is unrestricted.
		include, listed := selectColumns[name]
		if !listed {
			include = !restricted
		}
		if !include {
			continue
		}

		values.Columns = append(values.Columns, clause.Column{Name: name})
		if len(values.Values) == 0 {
			values.Values = [][]any{{}}
		}
		values.Values[0] = append(values.Values[0], mapValue[key])
	}
	return values
}
// ConvertSliceOfMapToValuesForCreate turns a slice of maps into a multi-row
// INSERT value set. The column list is the sorted union of the (translated and
// selected) keys of all maps; a row that lacks a column gets nil for it. An
// empty slice records gorm.ErrEmptySlice on stmt and yields zero values.
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]any) (values clause.Values) {
	rowCount := len(mapValues)
	// Return early for an empty slice; there is no need to resolve the
	// select/omit columns in that case.
	if rowCount == 0 {
		stmt.AddError(gorm.ErrEmptySlice)
		return
	}

	selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
	columnNames := make([]string, 0, rowCount)
	valuesByColumn := make(map[string][]any, rowCount)

	for rowIdx, row := range mapValues {
		for key, value := range row {
			name := key
			if stmt.Schema != nil {
				if field := stmt.Schema.LookUpField(key); field != nil {
					name = field.DBName
				}
			}

			columnValues, known := valuesByColumn[name]
			if !known {
				// First sighting of this column: admit it only if selected.
				include, listed := selectColumns[name]
				if !listed {
					include = !restricted
				}
				if !include {
					continue
				}
				columnValues = make([]any, rowCount)
				valuesByColumn[name] = columnValues
				columnNames = append(columnNames, name)
			}
			columnValues[rowIdx] = value
		}
	}

	sort.Strings(columnNames)
	values.Values = make([][]any, rowCount)
	values.Columns = make([]clause.Column, len(columnNames))

	// Transpose the per-column slices into per-row slices.
	for colIdx, name := range columnNames {
		values.Columns[colIdx] = clause.Column{Name: name}
		for rowIdx, value := range valuesByColumn[name] {
			if len(values.Values[rowIdx]) == 0 {
				values.Values[rowIdx] = make([]any, len(columnNames))
			}
			values.Values[rowIdx][colIdx] = value
		}
	}
	return values
}
// hasReturning reports whether the statement should be run as a query that
// scans a RETURNING result, together with the scan mode to use. It is true
// only when RETURNING is supported and the statement has a RETURNING clause.
// The mode is 0 when the clause lists no columns or only "*", and
// gorm.ScanUpdate otherwise.
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
	if !supportReturning {
		return false, 0
	}

	c, found := tx.Statement.Clauses["RETURNING"]
	if !found {
		return false, 0
	}

	returning, _ := c.Expression.(clause.Returning)
	cols := returning.Columns
	allColumns := len(cols) == 0 || (len(cols) == 1 && cols[0].Name == "*")
	if allColumns {
		return true, 0
	}
	return true, gorm.ScanUpdate
}
// checkMissingWhereConditions records gorm.ErrMissingWhereClause on db when the
// statement has no effective WHERE condition. The check is skipped when global
// updates are allowed or db already has an error. If the
// "soft_delete_enabled" clause is present, the WHERE clause must hold more
// than one expression to count as a condition.
func checkMissingWhereConditions(db *gorm.DB) {
	if db.AllowGlobalUpdate || db.Error != nil {
		return
	}

	where, hasCondition := db.Statement.Clauses["WHERE"]
	if hasCondition {
		if _, softDelete := db.Statement.Clauses["soft_delete_enabled"]; softDelete {
			whereClause, _ := where.Expression.(clause.Where)
			hasCondition = len(whereClause.Exprs) > 1
		}
	}

	if !hasCondition {
		db.AddError(gorm.ErrMissingWhereClause)
	}
}
// visitMap records, keyed by the reflect.Value of their address, the values
// that have already been visited.
type visitMap = map[reflect.Value]bool

// loadOrStoreVisitMap reports whether v was already present in visitMap,
// recording it when it was not.
//
// A pointer is dereferenced once. For a slice or array, every element is
// visited (and recorded) and the result is true only if all of them had been
// seen before, which is vacuously true for an empty one. Addressable struct
// and interface values are looked up by address. Anything else, including
// non-addressable structs, is never recorded and yields false.
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	kind := v.Kind()
	if kind == reflect.Slice || kind == reflect.Array {
		allSeen := true
		for i, n := 0, v.Len(); i < n; i++ {
			// No short-circuit: unseen elements must still be recorded.
			if seen := loadOrStoreVisitMap(visitMap, v.Index(i)); !seen {
				allSeen = false
			}
		}
		return allSeen
	}

	if kind != reflect.Struct && kind != reflect.Interface {
		return false
	}
	if !v.CanAddr() {
		return false
	}

	addr := v.Addr()
	if _, seen := (*visitMap)[addr]; seen {
		return true
	}
	(*visitMap)[addr] = true
	return false
}
================================================
FILE: gossh/gorm/callbacks/interfaces.go
================================================
package callbacks
import "gossh/gorm"
// Hook interfaces recognised by the callbacks in this package. A model
// implements one of them to have the corresponding method invoked by the
// matching callback; a non-nil returned error is recorded on the DB that is
// running the statement.
type (
	// BeforeCreateInterface is implemented by models with a BeforeCreate hook.
	BeforeCreateInterface interface {
		BeforeCreate(*gorm.DB) error
	}

	// AfterCreateInterface is implemented by models with an AfterCreate hook.
	AfterCreateInterface interface {
		AfterCreate(*gorm.DB) error
	}

	// BeforeUpdateInterface is implemented by models with a BeforeUpdate hook.
	BeforeUpdateInterface interface {
		BeforeUpdate(*gorm.DB) error
	}

	// AfterUpdateInterface is implemented by models with an AfterUpdate hook.
	AfterUpdateInterface interface {
		AfterUpdate(*gorm.DB) error
	}

	// BeforeSaveInterface is implemented by models with a BeforeSave hook.
	BeforeSaveInterface interface {
		BeforeSave(*gorm.DB) error
	}

	// AfterSaveInterface is implemented by models with an AfterSave hook.
	AfterSaveInterface interface {
		AfterSave(*gorm.DB) error
	}

	// BeforeDeleteInterface is implemented by models with a BeforeDelete hook.
	BeforeDeleteInterface interface {
		BeforeDelete(*gorm.DB) error
	}

	// AfterDeleteInterface is implemented by models with an AfterDelete hook.
	AfterDeleteInterface interface {
		AfterDelete(*gorm.DB) error
	}

	// AfterFindInterface is implemented by models with an AfterFind hook.
	AfterFindInterface interface {
		AfterFind(*gorm.DB) error
	}
)
================================================
FILE: gossh/gorm/callbacks/preload.go
================================================
package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// parsePreloadMap groups preload paths by their first segment. The result maps
// each top-level name to the nested preloads (the remainder of the path) and
// their arguments; a path with no remainder produces an empty inner map.
//
// clause.Associations as the first segment expands to every relation owned by
// s plus every embedded relation of s. For example:
//
//	// schema has a "k0" relation and a "k7.k8" embedded relation
//	parsePreloadMap(schema, map[string][]any{
//		clause.Associations: {"arg1"},
//		"k1":                {"arg2"},
//		"k2.k3":             {"arg3"},
//		"k4.k5.k6":          {"arg4"},
//	})
//	// yields
//	map[string]map[string][]any{
//		"k0": {},
//		"k7": {"k8": {}},
//		"k1": {},
//		"k2": {"k3": {"arg3"}},
//		"k4": {"k5.k6": {"arg4"}},
//	}
func parsePreloadMap(s *schema.Schema, preloads map[string][]any) map[string]map[string][]any {
	grouped := map[string]map[string][]any{}

	// add makes sure head has an entry and, when rest is non-empty, stores
	// args under it.
	add := func(head, rest string, args []any) {
		nested, exists := grouped[head]
		if !exists {
			nested = map[string][]any{}
			grouped[head] = nested
		}
		if rest != "" {
			nested[rest] = args
		}
	}

	for path, args := range preloads {
		// Split at the first dot: head is the top-level name, rest the nested path.
		head, rest := path, ""
		if dot := strings.IndexByte(path, '.'); dot >= 0 {
			head, rest = path[:dot], path[dot+1:]
		}

		if head != clause.Associations {
			add(head, rest, args)
			continue
		}

		for _, relation := range s.Relationships.Relations {
			if relation.Schema == s {
				add(relation.Name, rest, args)
			}
		}
		for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
			for _, nestedPath := range embeddedValues(embeddedRelations) {
				add(embedded, nestedPath, args)
			}
		}
	}
	return grouped
}
// embeddedValues returns the dotted bind-name path, minus the first segment, of
// every relation reachable from embeddedRelations, descending into nested
// embedded relations. A nil argument yields nil.
func embeddedValues(embeddedRelations *schema.Relationships) []string {
	if embeddedRelations == nil {
		return nil
	}

	capacity := len(embeddedRelations.Relations) + len(embeddedRelations.EmbeddedRelations)
	paths := make([]string, 0, capacity)

	for _, rel := range embeddedRelations.Relations {
		// Drop the leading struct name from the bind path.
		tail := rel.Field.BindNames[1:]
		paths = append(paths, strings.Join(tail, "."))
	}
	for _, nested := range embeddedRelations.EmbeddedRelations {
		paths = append(paths, embeddedValues(nested)...)
	}
	return paths
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
// joins lists the join paths of the current level, relationships is the set of
// relations to resolve preload names against, preloads maps preload paths to
// their arguments, and associationsConds are appended to the arguments of
// every preload that is actually executed. It returns the first error
// encountered, gorm.ErrInvalidData for a joined relation on an unsupported
// value kind, or a wrapped gorm.ErrUnsupportedRelation for a name that is
// neither a relation nor an embedded relation.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]any, associationsConds []any) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
// isJoined reports whether relation name is covered by a join, either directly
// or as the first segment of a dotted join path; the remainders of such dotted
// paths are returned as the joins of the next level.
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
// Embedded struct: descend into its relationships on the same statement.
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
// Joined relation: no query is issued for it here (presumably the JOIN
// already populated the field); only its nested preloads are processed,
// once per loaded value.
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
// Plain relation: run the real preload on a session that carries over the
// current reflect value and unscoped flag.
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
// preloadDB derives a new session from db for running preload queries.
//
// The session carries over db's context, SkipHooks flag, Unscoped flag and
// every entry of db.Statement.Settings. dest is parsed into the new
// statement; if parsing fails the error is added to the returned session and
// ReflectValue/Unscoped are left untouched. On success the statement's
// ReflectValue is set to reflectValue.
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest any) *gorm.DB {
	session := &gorm.Session{
		Context:     db.Statement.Context,
		NewDB:       true,
		SkipHooks:   db.Statement.SkipHooks,
		Initialized: true,
	}
	tx := db.Session(session)

	// Mirror the parent's settings onto the new statement.
	copySetting := func(key, value any) bool {
		tx.Statement.Settings.Store(key, value)
		return true
	}
	db.Statement.Settings.Range(copySetting)

	err := tx.Statement.Parse(dest)
	if err != nil {
		tx.AddError(err)
		return tx
	}

	tx.Statement.Unscoped = db.Statement.Unscoped
	tx.Statement.ReflectValue = reflectValue
	return tx
}
// preload loads the records of relationship rel for the value(s) held in
// tx.Statement.ReflectValue and assigns them to rel.Field on each owner.
//
// conds are extra query conditions: values of type func(*gorm.DB) *gorm.DB are
// applied to tx, everything else is passed to Find as inline conditions.
// preloads holds nested preloads, which are registered on tx via tx.Preload.
//
// It returns nil early when no foreign key values can be collected from the
// owners, an error when a loaded record cannot be matched back to an owner,
// and otherwise tx.Error.
func preload(tx *gorm.DB, rel *schema.Relationship, conds []any, preloads map[string][]any) error {
var (
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]any
// identityMap maps a key built from the related record's foreign field
// values to the owner values that should receive that record.
identityMap = map[string][]reflect.Value{}
inlineConds []any
)
if rel.JoinTable != nil {
// Relation goes through a join table: query the join rows first, then use
// them to translate owner keys into related-record keys.
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
// Fixed-value reference: constrain the column to that constant.
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return nil
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map
fieldValues := make([]any, len(joinForeignFields))
joinFieldValues := make([]any, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else {
// Direct relation: collect the key values straight from the owners.
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 {
return nil
}
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
// The related records are only queried when there is at least one key value.
if len(values) != 0 {
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
}
fieldValues := make([]any, len(relForeignFields))
// clean up old values before preloading
// HasMany/Many2Many fields are reset to an empty slice (capacity 10); any
// other relation type is reset through Field.Set with a freshly allocated value.
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
// Distribute every loaded record to the owners registered under its key.
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
// Allocate a nil pointer field so it can be inspected below.
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
// Append elem itself when the slice holds pointers, otherwise append
// the value elem points to.
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
}
}
}
return tx.Error
}
================================================
FILE: gossh/gorm/callbacks/query.go
================================================
package callbacks
import (
"fmt"
"reflect"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// Query builds the query SQL for db and, unless db is in dry-run mode or
// already carries an error, executes it on the statement's connection pool
// and scans the resulting rows into the statement via gorm.Scan.
//
// Errors from running the query and from closing the rows are added to db.
func Query(db *gorm.DB) {
	if db.Error != nil {
		return
	}

	BuildQuerySQL(db)
	if db.DryRun || db.Error != nil {
		return
	}

	stmt := db.Statement
	rows, err := stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
	if err != nil {
		db.AddError(err)
		return
	}
	// Close after scanning and surface a close failure on db.
	defer func() {
		db.AddError(rows.Close())
	}()

	gorm.Scan(rows, db, 0)
}
// BuildQuerySQL prepares the query statement of db.
//
// The schema's QueryClauses (if a schema is set) are always added. When no SQL
// has been written to the statement yet, it additionally derives the SELECT
// columns, adds primary-key conditions taken from a struct destination,
// expands db.Statement.Joins into FROM/JOIN clauses and finally builds the
// statement using db.Statement.BuildClauses.
func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
// Destination is a single model struct: use its non-zero primary key
// values as WHERE conditions.
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
}
}
if len(db.Statement.Selects) > 0 {
// Explicit selects: names that resolve to a schema field become that
// field's DB column, everything else is emitted as raw SQL.
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
} else {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
}
}
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
// Omits only: select every schema column that is not explicitly excluded.
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
// Neither selects nor omits: list columns explicitly when QueryFields is
// set or the destination type differs from the model type.
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
}
}
// inline joins
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
// With joins and no explicit select/omit, qualify all model columns with
// the statement's table.
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
// specifiedRelationsName records the aliases already joined so the same
// relation is not joined twice.
specifiedRelationsName := make(map[string]any)
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}
if isRelations {
// genJoinClause builds the JOIN for one relation. As a side effect it
// appends the relation's selected columns (aliased with the table
// alias) to clauseSelect.
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
// One ON expression per reference of the relation.
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
// Render the related schema's QueryClauses plus join.On with a scratch
// statement and append the result as an additional ON expression.
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
// Replace each dialect-specific bind variable in the rendered SQL
// with "?" so the expression can be re-bound by the outer statement.
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
// The relation just processed becomes the parent of the next one.
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
// Not a known relation: treat the join name as raw SQL.
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
// No schema available: the join can only be raw SQL.
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build(db.Statement.BuildClauses...)
}
}
// Preload runs the preloads registered on db.Statement.Preloads.
//
// It does nothing when db already has an error or no preloads are set. A
// missing schema is reported as gorm.ErrModelValueRequired. The names of the
// statement's joins are handed to preloadEntryPoint together with the
// preloads and the conditions stored under clause.Associations; its error is
// added to db.
func Preload(db *gorm.DB) {
	if db.Error != nil || len(db.Statement.Preloads) == 0 {
		return
	}
	if db.Statement.Schema == nil {
		db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
		return
	}

	joinNames := make([]string, len(db.Statement.Joins))
	for i := range db.Statement.Joins {
		joinNames[i] = db.Statement.Joins[i].Name
	}

	tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
	if tx.Error != nil {
		// The preload session could not be set up; nothing is added to db here.
		return
	}

	preloads := db.Statement.Preloads
	err := preloadEntryPoint(tx, joinNames, &tx.Statement.Schema.Relationships, preloads, preloads[clause.Associations])
	db.AddError(err)
}
// AfterQuery clears the statement's joins and then invokes the AfterFind hook
// on the queried values, provided there is no error, a schema is set, hooks
// are not skipped, the schema declares AfterFind and at least one row was
// affected.
func AfterQuery(db *gorm.DB) {
	// clear the joins after query because preload need it
	db.Statement.Joins = nil

	if db.Error != nil {
		return
	}
	sch := db.Statement.Schema
	if sch == nil || db.Statement.SkipHooks || !sch.AfterFind || db.RowsAffected <= 0 {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		hook, ok := value.(AfterFindInterface)
		if !ok {
			return false
		}
		db.AddError(hook.AfterFind(tx))
		return true
	})
}
================================================
FILE: gossh/gorm/callbacks/raw.go
================================================
package callbacks
import (
"gossh/gorm"
)
// RawExec executes the statement's SQL on its connection pool unless db
// already has an error or is in dry-run mode. An execution error is added to
// db; on success db.RowsAffected is taken from the result.
func RawExec(db *gorm.DB) {
	if db.Error != nil || db.DryRun {
		return
	}

	stmt := db.Statement
	result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
	if err != nil {
		db.AddError(err)
		return
	}

	// The error from RowsAffected is discarded, as in the original code.
	affected, _ := result.RowsAffected()
	db.RowsAffected = affected
}
================================================
FILE: gossh/gorm/callbacks/row.go
================================================
package callbacks
import (
"gossh/gorm"
)
// RowQuery builds the query SQL and stores the raw query result in
// db.Statement.Dest.
//
// When the "rows" setting is present and true, the setting is removed and the
// result of QueryContext (together with its error, stored in db.Error) is
// used; otherwise the result of QueryRowContext is used. db.RowsAffected is
// set to -1 afterwards. Nothing is executed when db has an error or is in
// dry-run mode.
func RowQuery(db *gorm.DB) {
	if db.Error != nil {
		return
	}

	BuildQuerySQL(db)
	if db.DryRun || db.Error != nil {
		return
	}

	wantRows := false
	if flag, found := db.Get("rows"); found {
		wantRows = flag.(bool)
	}

	stmt := db.Statement
	if wantRows {
		stmt.Settings.Delete("rows")
		stmt.Dest, db.Error = stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
	} else {
		stmt.Dest = stmt.ConnPool.QueryRowContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
	}
	db.RowsAffected = -1
}
================================================
FILE: gossh/gorm/callbacks/transaction.go
================================================
package callbacks
import (
"gossh/gorm"
)
// BeginTransaction starts a transaction for the current operation unless
// default transactions are disabled or db already has an error.
//
// On success the statement's connection pool is switched to the
// transaction's and the instance flag "gorm:started_transaction" is set.
// gorm.ErrInvalidTransaction from Begin is cleared on the transaction handle
// and not reported on db; any other error is copied to db.Error.
func BeginTransaction(db *gorm.DB) {
	if db.Config.SkipDefaultTransaction || db.Error != nil {
		return
	}

	tx := db.Begin()
	switch tx.Error {
	case nil:
		db.Statement.ConnPool = tx.Statement.ConnPool
		db.InstanceSet("gorm:started_transaction", true)
	case gorm.ErrInvalidTransaction:
		tx.Error = nil
	default:
		db.Error = tx.Error
	}
}
// CommitOrRollbackTransaction finishes a transaction marked with the instance
// flag "gorm:started_transaction": it rolls back when db has an error and
// commits otherwise, then restores the statement's connection pool to
// db.ConnPool. It does nothing when default transactions are disabled or the
// flag is absent.
func CommitOrRollbackTransaction(db *gorm.DB) {
	if db.Config.SkipDefaultTransaction {
		return
	}
	if _, started := db.InstanceGet("gorm:started_transaction"); !started {
		return
	}

	if db.Error == nil {
		db.Commit()
	} else {
		db.Rollback()
	}
	db.Statement.ConnPool = db.ConnPool
}
================================================
FILE: gossh/gorm/callbacks/update.go
================================================
package callbacks
import (
"reflect"
"sort"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// SetupUpdateReflectValue points the statement's ReflectValue at the model
// when the current ReflectValue is not addressable or Model and Dest differ.
//
// The model value is dereferenced through all pointer levels. If Dest is a
// map[string]any, every BelongsTo relationship whose name is a key of that
// map has its field set on the model from the map value; errors from setting
// are added to db. Nothing happens when db has an error or no schema.
func SetupUpdateReflectValue(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}

	stmt := db.Statement
	if stmt.ReflectValue.CanAddr() && stmt.Model == stmt.Dest {
		return
	}

	rv := reflect.ValueOf(stmt.Model)
	for rv.Kind() == reflect.Ptr {
		rv = rv.Elem()
	}
	stmt.ReflectValue = rv

	dest, isMap := stmt.Dest.(map[string]any)
	if !isMap {
		return
	}
	for _, rel := range stmt.Schema.Relationships.BelongsTo {
		if val, present := dest[rel.Name]; present {
			db.AddError(rel.Field.Set(stmt.Context, stmt.ReflectValue, val))
		}
	}
}
// BeforeUpdate runs the BeforeSave and BeforeUpdate hooks, in that order, on
// the values being updated. Hooks only run when db has no error, a schema is
// set, hooks are not skipped and the schema declares at least one of the two.
// Hook errors are added to db.
func BeforeUpdate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if sch := db.Statement.Schema; !sch.BeforeSave && !sch.BeforeUpdate {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		invoked := false
		if db.Statement.Schema.BeforeSave {
			if hook, ok := value.(BeforeSaveInterface); ok {
				invoked = true
				db.AddError(hook.BeforeSave(tx))
			}
		}
		if db.Statement.Schema.BeforeUpdate {
			if hook, ok := value.(BeforeUpdateInterface); ok {
				invoked = true
				db.AddError(hook.BeforeUpdate(tx))
			}
		}
		return invoked
	})
}
// Update returns the callback that builds and executes an UPDATE statement.
//
// config.UpdateClauses decides whether "RETURNING" is considered supported.
// The returned callback adds the schema's UpdateClauses, builds the SQL when
// none exists yet (returning without executing if no assignments can be
// derived), checks for missing WHERE conditions and then, unless in dry-run
// mode or in error, either queries and scans the returned rows or executes
// the statement and records RowsAffected.
func Update(config *Config) func(db *gorm.DB) {
	supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")

	return func(db *gorm.DB) {
		if db.Error != nil {
			return
		}

		stmt := db.Statement
		if stmt.Schema != nil {
			for _, c := range stmt.Schema.UpdateClauses {
				stmt.AddClause(c)
			}
		}

		if stmt.SQL.Len() == 0 {
			stmt.SQL.Grow(180)
			stmt.AddClauseIfNotExists(clause.Update{})
			if _, hasSet := stmt.Clauses["SET"]; !hasSet {
				set := ConvertToAssignments(stmt)
				if len(set) == 0 {
					return
				}
				// The generated SET clause is removed again when the callback returns.
				defer delete(stmt.Clauses, "SET")
				stmt.AddClause(set)
			}
			stmt.Build(stmt.BuildClauses...)
		}

		checkMissingWhereConditions(db)
		if db.DryRun || db.Error != nil {
			return
		}

		returning, mode := hasReturning(db, supportReturning)
		if !returning {
			result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
			if db.AddError(err) == nil {
				db.RowsAffected, _ = result.RowsAffected()
			}
			return
		}

		rows, err := stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
		if db.AddError(err) != nil {
			return
		}
		// Scan into the model value, then put the caller's Dest back.
		savedDest := stmt.Dest
		stmt.Dest = stmt.ReflectValue.Addr().Interface()
		gorm.Scan(rows, db, mode)
		stmt.Dest = savedDest
		db.AddError(rows.Close())
	}
}
// AfterUpdate runs the AfterUpdate and AfterSave hooks, in that order, on the
// updated values. Hooks only run when db has no error, a schema is set, hooks
// are not skipped and the schema declares at least one of the two. Hook
// errors are added to db.
func AfterUpdate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.SkipHooks {
		return
	}
	if sch := db.Statement.Schema; !sch.AfterSave && !sch.AfterUpdate {
		return
	}

	callMethod(db, func(value any, tx *gorm.DB) bool {
		invoked := false
		if db.Statement.Schema.AfterUpdate {
			if hook, ok := value.(AfterUpdateInterface); ok {
				invoked = true
				db.AddError(hook.AfterUpdate(tx))
			}
		}
		if db.Statement.Schema.AfterSave {
			if hook, ok := value.(AfterSaveInterface); ok {
				invoked = true
				db.AddError(hook.AfterSave(tx))
			}
		}
		return invoked
	})
}
// ConvertToAssignments converts the update payload in stmt.Dest into the SET
// assignments of an UPDATE statement.
//
// stmt.Dest may be a map[string]any or a struct (after pointer indirection);
// any other kind adds gorm.ErrInvalidData to stmt. Select/omit restrictions
// from stmt are honoured. Besides returning the assignments, the function may
// add primary-key WHERE clauses to stmt and write the assigned values back
// into stmt.ReflectValue when that value is addressable.
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
// assignValue mirrors an assigned value into the in-memory model.
assignValue func(field *schema.Field, value any)
)
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value any) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value any) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue, value)
}
}
default:
// Nothing to write back for other kinds.
assignValue = func(field *schema.Field, value any) {
}
}
updatingValue := reflect.ValueOf(stmt.Dest)
for updatingValue.Kind() == reflect.Ptr {
updatingValue = updatingValue.Elem()
}
// The payload is not the model itself: restrict the update to the model's
// primary key value(s).
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var isZero bool
for i := 0; i < size; i++ {
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
}
// NOTE(review): isZero here holds the result of the last element
// inspected by the loop above — confirm this is the intended check.
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
switch value := updatingValue.Interface().(type) {
case map[string]any:
set = make([]clause.Assignment, 0, len(value))
// Iterate the keys in sorted order so the generated assignments are
// deterministic.
keys := make([]string, 0, len(value))
for k := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
kv := value[k]
// A *gorm.DB value is wrapped in a single-element slice before being
// used as the assignment value.
if _, ok := kv.(*gorm.DB); ok {
kv = []any{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
// Field without a DB column: only update the in-memory value.
assignValue(field, value[k])
}
continue
}
}
// Key does not match a schema field: use it as the column name directly.
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
// Add auto-update-time columns that the map did not set explicitly.
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
now := stmt.DB.NowFunc()
assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
}
}
}
default:
updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
}
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
if field := updatingSchema.LookUpField(dbName); field != nil {
// Primary keys are only assigned when the payload is not the model
// itself; otherwise they become WHERE conditions (else branch below).
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(stmt.Context, updatingValue)
// Auto-update-time fields always get the current time in the unit
// configured on the field.
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
}
isZero = false
}
// Zero values are only written when the column was explicitly selected.
if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
}
}
} else {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
return
}
================================================
FILE: gossh/gorm/callbacks.go
================================================
package gorm
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"time"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// initializeCallbacks creates the callbacks manager for db with one empty
// processor per operation: create, query, update, delete, row and raw.
func initializeCallbacks(db *DB) *callbacks {
	operations := [...]string{"create", "query", "update", "delete", "row", "raw"}

	processors := make(map[string]*processor, len(operations))
	for _, op := range operations {
		processors[op] = &processor{db: db}
	}
	return &callbacks{processors: processors}
}
// callbacks is the gorm callbacks manager. It groups one processor per
// operation kind.
type callbacks struct {
// processors maps an operation name ("create", "query", "update",
// "delete", "row", "raw") to the processor handling it.
processors map[string]*processor
}
// processor holds the callbacks registered for one operation and the ordered
// list of handlers compiled from them.
type processor struct {
// db is the DB the processor belongs to; it is passed to match functions
// and its Logger is used for diagnostics.
db *DB
// Clauses is used by Execute as the statement's BuildClauses when the
// statement has none of its own.
Clauses []string
// fns is the sorted handler list produced by compile and run by Execute.
fns []func(*DB)
// callbacks are the registered callback definitions.
callbacks []*callback
}
// callback describes a single callback registration, removal or replacement
// together with its ordering constraints.
type callback struct {
// name identifies the callback.
name string
// before and after name the callback this one must run before/after;
// the value "*" is handled specially by sortCallbacks.
before string
after string
// remove marks this entry as a removal request for name.
remove bool
// replace marks this entry as a replacement for name.
replace bool
// match, when non-nil, must return true for the entry to be kept by compile.
match func(*DB) bool
// handler is the function to run.
handler func(*DB)
// processor is the processor this callback is attached to.
processor *processor
}
// Create returns the processor stored under the "create" key.
func (cs *callbacks) Create() *processor {
return cs.processors["create"]
}
// Query returns the processor stored under the "query" key.
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
// Update returns the processor stored under the "update" key.
func (cs *callbacks) Update() *processor {
return cs.processors["update"]
}
// Delete returns the processor stored under the "delete" key.
func (cs *callbacks) Delete() *processor {
return cs.processors["delete"]
}
// Row returns the processor stored under the "row" key.
func (cs *callbacks) Row() *processor {
return cs.processors["row"]
}
// Raw returns the processor stored under the "raw" key.
func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}
// Execute runs the processor's compiled handlers against db and returns the
// (possibly scope-expanded) db.
//
// Before the handlers run it applies pending scopes, falls back to the
// processor's Clauses when the statement has no BuildClauses, lets a Dest that
// implements StatementModifier adjust the statement, fills in Model/Dest from
// each other, parses the model and derives ReflectValue from Dest. Afterwards
// it traces the executed SQL through the logger, resets SQL and Vars unless
// in dry-run mode, and clears BuildClauses again if it had set them.
func (p *processor) Execute(db *DB) *DB {
	// call scopes
	for len(db.Statement.scopes) > 0 {
		db = db.executeScopes()
	}

	startedAt := time.Now()
	stmt := db.Statement

	usingDefaultClauses := len(stmt.BuildClauses) == 0
	if usingDefaultClauses {
		stmt.BuildClauses = p.Clauses
	}

	if modifier, ok := stmt.Dest.(StatementModifier); ok {
		modifier.ModifyStatement(stmt)
	}

	// assign model values
	switch {
	case stmt.Model == nil:
		stmt.Model = stmt.Dest
	case stmt.Dest == nil:
		stmt.Dest = stmt.Model
	}

	// parse model values
	if stmt.Model != nil {
		if err := stmt.Parse(stmt.Model); err != nil {
			unsupported := errors.Is(err, schema.ErrUnsupportedDataType)
			tableUnset := stmt.Table == "" && stmt.TableExpr == nil
			switch {
			case !unsupported:
				db.AddError(err)
			case tableUnset && stmt.SQL.Len() == 0:
				// An unsupported model type is only an error when there is no
				// table and no SQL to fall back on.
				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
			}
		}
	}

	// assign stmt.ReflectValue
	if stmt.Dest != nil {
		stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
		for stmt.ReflectValue.Kind() == reflect.Ptr {
			if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
				stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
			}
			stmt.ReflectValue = stmt.ReflectValue.Elem()
		}
		if !stmt.ReflectValue.IsValid() {
			db.AddError(ErrInvalidValue)
		}
	}

	for i := range p.fns {
		p.fns[i](db)
	}

	if stmt.SQL.Len() > 0 {
		explain := func() (string, int64) {
			sql, vars := stmt.SQL.String(), stmt.Vars
			if filter, ok := db.Logger.(ParamsFilter); ok {
				sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
			}
			return db.Dialector.Explain(sql, vars...), db.RowsAffected
		}
		db.Logger.Trace(stmt.Context, startedAt, explain, db.Error)
	}

	if !stmt.DB.DryRun {
		stmt.SQL.Reset()
		stmt.Vars = nil
	}
	if usingDefaultClauses {
		stmt.BuildClauses = nil
	}
	return db
}
// Get returns the handler of the most recently registered callback called
// name that is not a removal entry, or nil when there is none.
func (p *processor) Get(name string) func(*DB) {
	last := len(p.callbacks) - 1
	for offset := range p.callbacks {
		cb := p.callbacks[last-offset]
		if cb.remove || cb.name != name {
			continue
		}
		return cb.handler
	}
	return nil
}
// Before starts a callback definition on p that is ordered before the
// callback called name.
func (p *processor) Before(name string) *callback {
	cb := &callback{processor: p}
	cb.before = name
	return cb
}
// After starts a callback definition on p that is ordered after the callback
// called name.
func (p *processor) After(name string) *callback {
	cb := &callback{processor: p}
	cb.after = name
	return cb
}
// Match starts a callback definition on p that is only kept by compile when
// fc returns true for the processor's db.
func (p *processor) Match(fc func(*DB) bool) *callback {
	cb := &callback{processor: p}
	cb.match = fc
	return cb
}
// Register adds fn as a callback called name, without ordering constraints,
// and recompiles the processor.
func (p *processor) Register(name string, fn func(*DB)) error {
	cb := &callback{processor: p}
	return cb.Register(name, fn)
}
// Remove records a removal of the callback called name and recompiles the
// processor.
func (p *processor) Remove(name string) error {
	cb := &callback{processor: p}
	return cb.Remove(name)
}
// Replace registers fn as a replacement for the callback called name and
// recompiles the processor.
func (p *processor) Replace(name string, fn func(*DB)) error {
	cb := &callback{processor: p}
	return cb.Replace(name, fn)
}
// compile rebuilds the processor's callback list and its sorted handler list.
//
// Callbacks whose match function rejects the processor's db are dropped, and
// every callback sharing a name with a removal entry is dropped as well. The
// remainder is sorted with sortCallbacks; a sorting error is logged and
// returned.
func (p *processor) compile() (err error) {
	var kept []*callback
	removed := map[string]bool{}

	for _, cb := range p.callbacks {
		if cb.match == nil || cb.match(p.db) {
			kept = append(kept, cb)
		}
		if cb.remove {
			removed[cb.name] = true
		}
	}
	if len(removed) > 0 {
		kept = removeCallbacks(kept, removed)
	}
	p.callbacks = kept

	p.fns, err = sortCallbacks(p.callbacks)
	if err != nil {
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
	}
	return err
}
// Before sets the name of the callback that c must run before and returns c
// so further calls can be chained.
func (c *callback) Before(name string) *callback {
c.before = name
return c
}
// After sets the name of the callback that c must run after and returns c so
// further calls can be chained.
func (c *callback) After(name string) *callback {
c.after = name
return c
}
// Register names c, attaches fn as its handler, appends it to its processor's
// callbacks and recompiles the processor.
func (c *callback) Register(name string, fn func(*DB)) error {
	c.name, c.handler = name, fn

	owner := c.processor
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// Remove logs a warning, turns c into a removal entry for the callback called
// name, appends it to its processor's callbacks and recompiles the processor.
func (c *callback) Remove(name string) error {
	owner := c.processor
	owner.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())

	c.name, c.remove = name, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// Replace logs an info message, turns c into a replacement entry that runs fn
// under the given name, appends it to its processor's callbacks and
// recompiles the processor.
func (c *callback) Replace(name string, fn func(*DB)) error {
	owner := c.processor
	owner.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())

	c.name, c.handler, c.replace = name, fn, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
// getRIndex returns the index of the last element of strs equal to str, or -1
// when str does not occur in strs.
func getRIndex(strs []string, str string) int {
	idx := len(strs)
	for idx > 0 {
		idx--
		if strs[idx] == str {
			return idx
		}
	}
	return -1
}
// sortCallbacks orders the callbacks in cs according to their before/after
// constraints and returns the handlers of the non-removed callbacks in that
// order.
//
// cs itself is reordered in place (stable) and the before/after fields of its
// elements may be adjusted while resolving constraints. A warning is logged
// for a duplicated callback name. An error is returned when a callback is
// already placed on the wrong side of the callback it must precede or follow.
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
	var (
		// names mirrors cs: names[i] is cs[i].name. sorted is the resolved order.
		names, sorted []string
		sortCallback  func(*callback) error
	)

	sort.SliceStable(cs, func(i, j int) bool {
		if cs[j].before == "*" && cs[i].before != "*" {
			return true
		}
		if cs[j].after == "*" && cs[i].after != "*" {
			return true
		}
		return false
	})

	for _, c := range cs {
		// show warning message the callback name already exists
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
		}
		names = append(names, c.name)
	}

	sortCallback = func(c *callback) error {
		if c.before != "" { // if defined before callback
			if c.before == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append([]string{c.name}, sorted...)
				}
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if before callback already sorted, append current callback just after it
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
				} else if curIdx > sortedIdx {
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
				}
			} else if idx := getRIndex(names, c.before); idx != -1 {
				// if before callback exists
				cs[idx].after = c.name
			}
		}

		if c.after != "" { // if defined after callback
			if c.after == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append(sorted, c.name)
				}
			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if after callback sorted, append current callback to last
					sorted = append(sorted, c.name)
				} else if curIdx < sortedIdx {
					// The violated constraint here is the "after" one, so the
					// message names it as such.
					return fmt.Errorf("conflicting callback %s with after %s", c.name, c.after)
				}
			} else if idx := getRIndex(names, c.after); idx != -1 {
				// if after callback exists but haven't sorted
				// set after callback's before callback to current callback
				after := cs[idx]
				if after.before == "" {
					after.before = c.name
				}
				if err := sortCallback(after); err != nil {
					return err
				}
				if err := sortCallback(c); err != nil {
					return err
				}
			}
		}

		// if current callback haven't been sorted, append it to last
		if getRIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}
		return nil
	}

	for _, c := range cs {
		if err = sortCallback(c); err != nil {
			return
		}
	}

	// Every name in sorted comes from cs, so the lookup in names always hits.
	for _, name := range sorted {
		if idx := getRIndex(names, name); !cs[idx].remove {
			fns = append(fns, cs[idx].handler)
		}
	}
	return
}
// removeCallbacks returns a new slice holding the callbacks of cs, in their
// original order, whose name is not flagged true in nameMap.
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
	kept := make([]*callback, 0, len(cs))
	for i := range cs {
		if cb := cs[i]; !nameMap[cb.name] {
			kept = append(kept, cb)
		}
	}
	return kept
}
================================================
FILE: gossh/gorm/chainable_api.go
================================================
package gorm
import (
"fmt"
"regexp"
"strings"
"gossh/gorm/clause"
"gossh/gorm/utils"
)
// Model sets the model value the following operations run against.
//
//	// update every user's name to `hello`
//	db.Model(&User{}).Update("name", "hello")
//	// if user's primary key is non-blank it is used as a condition, so only that user's name is updated
//	db.Model(&user).Update("name", "hello")
func (db *DB) Model(value any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Model = value
	return tx
}
// Clauses adds clauses to the statement.
//
// Each condition is handled by the first rule that matches: a
// clause.Interface is added as a clause, a StatementModifier is asked to
// modify the statement, and anything else is collected and added as a single
// WHERE clause. This covers standard clauses (clause.OrderBy, clause.Limit,
// clause.Where) as well as lock strength and optimizer hints; see the [docs].
//
//	// add a simple limit clause
//	db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
//	// tell the optimizer to use the `idx_user_name` index
//	db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
//	// specify the lock strength to UPDATE
//	db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	tx = db.getInstance()

	var whereConds []any
	for _, cond := range conds {
		switch v := cond.(type) {
		case clause.Interface:
			tx.Statement.AddClause(v)
		case StatementModifier:
			v.ModifyStatement(tx.Statement)
		default:
			whereConds = append(whereConds, cond)
		}
	}

	if len(whereConds) > 0 {
		exprs := tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)
		tx.Statement.AddClause(clause.Where{Exprs: exprs})
	}
	return tx
}
// tableRegexp extracts the table alias from expressions such as
// "... AS alias" or "table alias".
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)

// Table sets the table the following operations run against.
//
// A name containing a space or a backtick, or any name accompanied by args,
// is kept as a raw expression and the alias (if one can be extracted) becomes
// the statement's table. A "schema.table" name is quoted as a whole and
// "table" becomes the statement's table. An empty name clears both.
//
//	// Get a user
//	db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...any) (tx *DB) {
	tx = db.getInstance()

	parts := strings.Split(name, ".")
	switch {
	case strings.ContainsAny(name, " `") || len(args) > 0:
		tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
		if m := tableRegexp.FindStringSubmatch(name); len(m) == 3 {
			alias := m[2]
			if m[1] != "" {
				alias = m[1]
			}
			tx.Statement.Table = alias
		}
	case len(parts) == 2:
		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
		tx.Statement.Table = parts[1]
	case name != "":
		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
		tx.Statement.Table = name
	default:
		tx.Statement.TableExpr = nil
		tx.Statement.Table = ""
	}
	return tx
}
// Distinct marks the statement as DISTINCT and, when args are given, selects
// them the same way Select would.
//
//	// Select distinct names of users
//	db.Distinct("name").Find(&results)
//	// Select distinct name/age pairs from users
//	db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Distinct = true
	if len(args) == 0 {
		return tx
	}
	return tx.Select(args[0], args[1:]...)
}
// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// When query is a string and args are given, the pair is treated as a SQL
// expression if query holds at least as many `?` as there are args, or as a
// named expression if query contains `@`. Otherwise query and args are taken
// as field names. A query that is neither a string nor a []string records an
// "unsupported select args" error on the returned DB.
//
//	// Select name and age of user using multiple arguments
//	db.Select("name", "age").Find(&users)
//	// Select name and age of user using an array
//	db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query any, args ...any) (tx *DB) {
tx = db.getInstance()
switch v := query.(type) {
case []string:
// Field list: start from v and append every string / []string argument.
tx.Statement.Selects = v
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
return
}
}
// Drop the expression of a previously added SELECT clause, if any.
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
// Enough positional placeholders for all args: raw expression.
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
// Named placeholders: named expression.
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
// Plain field names: v followed by every string / []string argument.
tx.Statement.Selects = []string{v}
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
// Any other argument type: fall back to a raw expression with all
// args. Selects keeps whatever was appended before this point.
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
return
}
}
// Drop the expression of a previously added SELECT clause, if any.
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
}
return
}
// Omit sets the fields to ignore when creating, updating and querying.
//
// A single argument containing a comma is split into several names with
// strings.FieldsFunc and utils.IsValidDBNameChar; otherwise columns is used
// as given.
func (db *DB) Omit(columns ...string) (tx *DB) {
	tx = db.getInstance()

	omits := columns
	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
		omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
	}
	tx.Statement.Omits = omits
	return tx
}
// Where adds conditions to the statement.
//
// See the [docs] for the formats a where clause can take. By default, where
// clauses chain with AND. Nothing is added when the arguments build no
// condition.
//
//	// Find the first user with name jinzhu
//	db.Where("name = ?", "jinzhu").First(&user)
//	// Find the first user with name jinzhu and age 20
//	db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
//	// Find the first user with name jinzhu and age not equal to 20
//	db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query any, args ...any) (tx *DB) {
	tx = db.getInstance()
	conds := tx.Statement.BuildCondition(query, args...)
	if len(conds) == 0 {
		return tx
	}
	tx.Statement.AddClause(clause.Where{Exprs: conds})
	return tx
}
// Not adds negated conditions to the statement.
//
// Not takes the same arguments as Where. Nothing is added when the arguments
// build no condition.
//
//	// Find the first user with name not equal to jinzhu
//	db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query any, args ...any) (tx *DB) {
	tx = db.getInstance()
	conds := tx.Statement.BuildCondition(query, args...)
	if len(conds) == 0 {
		return tx
	}
	negated := clause.Not(conds...)
	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{negated}})
	return tx
}
// Or adds conditions that are chained to the existing ones with OR.
//
// The conditions built from the arguments are grouped with AND before being
// wrapped in OR. Nothing is added when the arguments build no condition.
//
//	// Find the first user with name equal to jinzhu or john
//	db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query any, args ...any) (tx *DB) {
	tx = db.getInstance()
	conds := tx.Statement.BuildCondition(query, args...)
	if len(conds) == 0 {
		return tx
	}
	alternative := clause.Or(clause.And(conds...))
	tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{alternative}})
	return tx
}
// Joins records a LEFT join on the statement.
//
//	db.Joins("Account").Find(&user)
//	db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
//	db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...any) (tx *DB) {
	tx = joins(db, clause.LeftJoin, query, args...)
	return tx
}
// InnerJoins records an INNER join on the statement.
//
//	db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...any) (tx *DB) {
	tx = joins(db, clause.InnerJoin, query, args...)
	return tx
}
// joins appends a join of the given type to the statement's join list.
//
// When the only argument is a *DB, that DB's Selects and Omits are copied onto
// the join, and its WHERE clause (if it holds a clause.Where) becomes the
// join's On condition.
func joins(db *DB, joinType clause.JoinType, query string, args ...any) (tx *DB) {
	tx = db.getInstance()

	j := join{Name: query, Conds: args, JoinType: joinType}
	if len(args) == 1 {
		if sub, isDB := args[0].(*DB); isDB {
			j.Selects = sub.Statement.Selects
			j.Omits = sub.Statement.Omits
			if where, ok := sub.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
				j.On = &where
			}
		}
	}

	tx.Statement.Joins = append(tx.Statement.Joins, j)
	return tx
}
// Group adds a GROUP BY column to the statement.
//
// The column is flagged Raw unless splitting name with strings.FieldsFunc and
// utils.IsValidDBNameChar yields exactly one field.
//
//	// Select the sum age of users with given names
//	db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) {
	tx = db.getInstance()

	raw := len(strings.FieldsFunc(name, utils.IsValidDBNameChar)) != 1
	column := clause.Column{Name: name, Raw: raw}
	tx.Statement.AddClause(clause.GroupBy{Columns: []clause.Column{column}})
	return tx
}
// Having adds HAVING conditions for GROUP BY.
//
//	// Select the sum age of users with name jinzhu
//	db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query any, args ...any) (tx *DB) {
	tx = db.getInstance()
	having := tx.Statement.BuildCondition(query, args...)
	tx.Statement.AddClause(clause.GroupBy{Having: having})
	return tx
}
// Order adds an ORDER BY column to the statement.
//
// value may be a clause.OrderByColumn or a non-empty string, which is used as
// a raw column expression. Any other value, and the empty string, add nothing.
//
//	db.Order("name DESC")
//	db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value any) (tx *DB) {
	tx = db.getInstance()

	var column clause.OrderByColumn
	switch v := value.(type) {
	case clause.OrderByColumn:
		column = v
	case string:
		if v == "" {
			return tx
		}
		column = clause.OrderByColumn{Column: clause.Column{Name: v, Raw: true}}
	default:
		return tx
	}

	tx.Statement.AddClause(clause.OrderBy{Columns: []clause.OrderByColumn{column}})
	return tx
}
// Limit sets the number of records to retrieve.
//
// A limit can be cancelled with `Limit(-1)`.
//
//	// retrieve 3 users
//	db.Limit(3).Find(&users)
//	// retrieve 3 users into users1, and all users into users2
//	db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Limit: &limit})
	return tx
}
// Offset sets the number of records to skip before records are returned.
//
// An offset can be cancelled with `Offset(-1)`.
//
//	// select the third user
//	db.Offset(2).First(&user)
//	// select the first user by cancelling an earlier chained offset
//	db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Offset: offset})
	return tx
}
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// The functions are only queued on the statement here; they are invoked, in
// the order given, when executeScopes runs.
//
//	func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
//		return db.Where("amount > ?", 1000)
//	}
//
//	func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
//		return func (db *gorm.DB) *gorm.DB {
//			return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
//		}
//	}
//
//	db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
// executeScopes clears the scopes queued on the statement and applies them in
// order, feeding each scope the DB returned by the previous one. It returns
// the DB produced by the last scope, or db itself when none were queued.
func (db *DB) executeScopes() (tx *DB) {
	pending := db.Statement.scopes
	db.Statement.scopes = nil

	tx = db
	for i := range pending {
		tx = pending[i](tx)
	}
	return tx
}
// Preload registers an association to preload together with its conditions.
// Registering the same query again replaces the earlier conditions.
//
//	// get all users, and preload all non-cancelled orders
//	db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...any) (tx *DB) {
	tx = db.getInstance()
	if tx.Statement.Preloads == nil {
		tx.Statement.Preloads = make(map[string][]any)
	}
	tx.Statement.Preloads[query] = args
	return tx
}
// Attrs provides attributes used by [FirstOrCreate] or [FirstOrInit].
//
// Attrs only adds attributes if the record is not found. Each call replaces
// the attributes set by an earlier call.
//
//	// assign an email if the record is not found
//	db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign an email if the record is not found, otherwise ignore provided email
//	db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.attrs = attrs
	return tx
}
// Assign provides attributes used by [FirstOrCreate] or [FirstOrInit].
//
// Assign adds attributes even if the record is found. With FirstOrCreate this
// means records are updated even if they are found. Each call replaces the
// attributes set by an earlier call.
//
//	// assign an email regardless of if the record is not found
//	db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign email regardless of if record is found
//	db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.assigns = attrs
	return tx
}
// Unscoped sets the Unscoped flag on the statement.
func (db *DB) Unscoped() (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Unscoped = true
	return tx
}
// Raw resets the statement's SQL buffer and builds sql with values into it.
// SQL containing `@` is built as a named expression, anything else as a
// positional expression.
func (db *DB) Raw(sql string, values ...any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.SQL = strings.Builder{}

	var expr clause.Expression
	if strings.Contains(sql, "@") {
		expr = clause.NamedExpr{SQL: sql, Vars: values}
	} else {
		expr = clause.Expr{SQL: sql, Vars: values}
	}
	expr.Build(tx.Statement)
	return tx
}
================================================
FILE: gossh/gorm/clause/clause.go
================================================
package clause
// Interface is implemented by every clause (for example From, GroupBy, Limit).
type Interface interface {
// Name returns the clause keyword, e.g. "FROM".
Name() string
// Build writes the clause body to the builder.
Build(Builder)
// MergeClause folds this clause into the Clause already held for its name.
MergeClause(*Clause)
}
// ClauseBuilder clause builder, allows to customize how to build clause.
// When set on a Clause it is called by Clause.Build in place of the default
// rendering.
type ClauseBuilder func(Clause, Builder)
// Writer is the minimal byte/string sink that clauses write SQL text to.
type Writer interface {
WriteByte(byte) error
WriteString(string) (int, error)
}
// Builder builder interface
//
// Besides raw text output (Writer), a Builder can write a quoted identifier,
// add bound variables and record errors.
type Builder interface {
Writer
// WriteQuoted writes field; clauses in this package pass it Column, Table
// and plain column values.
WriteQuoted(field any)
// AddVar adds the given values to the statement, writing to the Writer.
// NOTE(review): presumably emits placeholders and binds the values; the
// implementation is outside this file.
AddVar(Writer, ...any)
// AddError records an error on the builder.
AddError(error) error
}
// Clause is a named SQL clause together with the expressions rendered around
// it. See Build for the order in which the parts are written.
type Clause struct {
Name string // WHERE
BeforeExpression Expression // written before Name
AfterNameExpression Expression // written between Name and Expression
AfterExpression Expression // written after Expression
Expression Expression // the clause body; nothing is written when nil and Builder is nil
Builder ClauseBuilder // optional custom renderer that replaces the default layout
}
// Build writes the clause to builder.
//
// A custom Builder, when set, takes over completely. Otherwise nothing is
// written unless Expression is set, in which case the parts appear as:
// BeforeExpression, Name, AfterNameExpression, Expression, AfterExpression,
// separated by single spaces, with unset parts skipped.
func (c Clause) Build(builder Builder) {
	if c.Builder != nil {
		c.Builder(c, builder)
		return
	}
	if c.Expression == nil {
		return
	}

	if c.BeforeExpression != nil {
		c.BeforeExpression.Build(builder)
		builder.WriteByte(' ')
	}

	if c.Name != "" {
		builder.WriteString(c.Name)
		builder.WriteByte(' ')
	}

	if c.AfterNameExpression != nil {
		c.AfterNameExpression.Build(builder)
		builder.WriteByte(' ')
	}

	c.Expression.Build(builder)

	if c.AfterExpression != nil {
		builder.WriteByte(' ')
		c.AfterExpression.Build(builder)
	}
}
// Placeholder names used in Column and Table values.
// NOTE(review): presumably replaced with the real primary key / table name
// when the statement is quoted; that substitution is outside this file.
const (
PrimaryKey string = "~~~py~~~" // primary key
CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
)
var (
// currentTable is the table written by From.Build and Insert.Build when no
// explicit table is given.
currentTable = Table{Name: CurrentTable}
// PrimaryColumn is the PrimaryKey placeholder column on the CurrentTable
// placeholder table.
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
)
// Column quote with name
type Column struct {
Table string // table the column belongs to; may be the CurrentTable placeholder
Name string // column name; may be the PrimaryKey placeholder
Alias string
Raw bool // NOTE(review): presumably makes the builder write Name unquoted; confirm in the Builder implementation
}
// Table quote with name
type Table struct {
Name string // table name; may be the CurrentTable placeholder
Alias string
Raw bool // NOTE(review): presumably makes the builder write Name unquoted; confirm in the Builder implementation
}
================================================
FILE: gossh/gorm/clause/delete.go
================================================
package clause
// Delete is the DELETE clause. Modifier, when non-empty, is written right
// after the keyword.
type Delete struct {
	Modifier string
}

// Name returns the clause name "DELETE".
func (d Delete) Name() string {
	return "DELETE"
}

// Build writes "DELETE", followed by a space and Modifier when Modifier is set.
func (d Delete) Build(builder Builder) {
	builder.WriteString("DELETE")
	if d.Modifier == "" {
		return
	}
	builder.WriteByte(' ')
	builder.WriteString(d.Modifier)
}

// MergeClause makes d the clause expression and clears the clause name, since
// Build already writes the DELETE keyword itself.
func (d Delete) MergeClause(clause *Clause) {
	clause.Name, clause.Expression = "", d
}
================================================
FILE: gossh/gorm/clause/expression.go
================================================
package clause
import (
"database/sql"
"database/sql/driver"
"go/ast"
"reflect"
)
// Expression is anything that can write itself as SQL to a Builder.
type Expression interface {
Build(builder Builder)
}
// NegationExpressionBuilder is implemented by expressions that can write
// their own negated form (for example Eq writes itself as Neq).
type NegationExpressionBuilder interface {
NegationBuild(builder Builder)
}
// Expr raw expression
type Expr struct {
SQL string // SQL text; each `?` consumes the next entry of Vars
Vars []any // values for the `?` placeholders
WithoutParentheses bool // expand slice/array values even when `?` does not directly follow `(`
}
// Build build raw expression
//
// SQL is scanned byte by byte. Every `?` consumes the next entry of Vars while
// entries remain; all other bytes (including `?` once Vars is exhausted) are
// written unchanged. Vars left over after the scan are added as sql.NamedArg
// values.
func (expr Expr) Build(builder Builder) {
var (
afterParenthesis bool
idx int
)
for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx {
if afterParenthesis || expr.WithoutParentheses {
// Directly after `(` (or when WithoutParentheses is set) a slice or
// array value is expanded into comma-separated variables. A
// driver.Valuer is always passed through as a single variable.
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
// An empty list is added as a single nil variable.
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else {
// afterParenthesis is only updated when a literal byte is written.
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
// Variables without a matching `?` are still handed to the builder.
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
}
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string // SQL text with `@name` (and optionally `?`) placeholders
Vars []any // sql.NamedArg, map[string]any, or structs whose exported fields provide the names
}
// Build build raw expression
//
// First a name -> value map is collected from Vars: sql.NamedArg entries by
// their Name, map[string]any entries by key, and for any other value the
// exported fields of the (dereferenced) struct, descending into embedded
// fields. Then SQL is scanned byte by byte: `@name` is replaced by the named
// value when one exists (otherwise written back unchanged), and `?` consumes
// positional entries of Vars like Expr.Build does.
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]any, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]any:
for k, v := range value {
namedMap[k] = v
}
default:
// Struct (or pointer to struct): expose exported fields by field name.
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
// Embedded field: also expose its own fields.
appendFieldsToMap(reflectValue.Field(i))
}
}
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
// Start collecting a placeholder name.
inName = true
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
// A delimiter ends the current name, if any, and is then written as is.
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
// Unknown name: write the placeholder text back.
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
afterParenthesis = false
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
// Positional placeholder; slices/arrays are expanded only directly
// after `(`, and never for a driver.Valuer.
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else if inName {
name = append(name, v)
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
// A name that runs to the end of the SQL has not been flushed yet.
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
}
// IN tests whether a column's value is within a set of values.
type IN struct {
	Column any
	Values []any
}

// Build writes the condition: "IN (NULL)" for no values, "= value" for a
// single value that is not a []any, and "IN (...)" otherwise.
func (in IN) Build(builder Builder) {
	builder.WriteQuoted(in.Column)

	if len(in.Values) == 0 {
		builder.WriteString(" IN (NULL)")
		return
	}

	if len(in.Values) == 1 {
		if _, isList := in.Values[0].([]any); !isList {
			builder.WriteString(" = ")
			builder.AddVar(builder, in.Values[0])
			return
		}
	}

	builder.WriteString(" IN (")
	builder.AddVar(builder, in.Values...)
	builder.WriteByte(')')
}

// NegationBuild writes the negated condition: "IS NOT NULL" for no values,
// "<> value" for a single value that is not a []any, and "NOT IN (...)"
// otherwise.
func (in IN) NegationBuild(builder Builder) {
	builder.WriteQuoted(in.Column)

	if len(in.Values) == 0 {
		builder.WriteString(" IS NOT NULL")
		return
	}

	if len(in.Values) == 1 {
		if _, isList := in.Values[0].([]any); !isList {
			builder.WriteString(" <> ")
			builder.AddVar(builder, in.Values[0])
			return
		}
	}

	builder.WriteString(" NOT IN (")
	builder.AddVar(builder, in.Values...)
	builder.WriteByte(')')
}
// Eq is an equality condition for where.
type Eq struct {
	Column any
	Value  any
}

// Build writes the condition. A value of one of the listed slice types becomes
// "IN (...)" ("IN (NULL)" when empty), a nil value becomes "IS NULL", and
// anything else becomes "= value".
func (eq Eq) Build(builder Builder) {
	builder.WriteQuoted(eq.Column)

	switch eq.Value.(type) {
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []any:
		items := reflect.ValueOf(eq.Value)
		count := items.Len()
		if count == 0 {
			builder.WriteString(" IN (NULL)")
			return
		}

		builder.WriteString(" IN (")
		for i := 0; i < count; i++ {
			if i > 0 {
				builder.WriteByte(',')
			}
			builder.AddVar(builder, items.Index(i).Interface())
		}
		builder.WriteByte(')')
	default:
		if eqNil(eq.Value) {
			builder.WriteString(" IS NULL")
			return
		}
		builder.WriteString(" = ")
		builder.AddVar(builder, eq.Value)
	}
}

// NegationBuild writes the condition as its Neq counterpart.
func (eq Eq) NegationBuild(builder Builder) {
	Neq(eq).Build(builder)
}
// Neq is an inequality condition for where.
type Neq Eq

// Build writes the condition. A value of one of the listed slice types becomes
// "NOT IN (...)", a nil value becomes "IS NOT NULL", and anything else becomes
// "<> value".
//
// An empty slice is written as "IS NOT NULL", the same form IN.NegationBuild
// uses for an empty value set, instead of the invalid SQL "NOT IN ()".
func (neq Neq) Build(builder Builder) {
	builder.WriteQuoted(neq.Column)

	switch neq.Value.(type) {
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []any:
		rv := reflect.ValueOf(neq.Value)
		if rv.Len() == 0 {
			builder.WriteString(" IS NOT NULL")
			return
		}

		builder.WriteString(" NOT IN (")
		for i := 0; i < rv.Len(); i++ {
			if i > 0 {
				builder.WriteByte(',')
			}
			builder.AddVar(builder, rv.Index(i).Interface())
		}
		builder.WriteByte(')')
	default:
		if eqNil(neq.Value) {
			builder.WriteString(" IS NOT NULL")
		} else {
			builder.WriteString(" <> ")
			builder.AddVar(builder, neq.Value)
		}
	}
}

// NegationBuild writes the condition as its Eq counterpart.
func (neq Neq) NegationBuild(builder Builder) {
	Eq(neq).Build(builder)
}
// Gt greater than for where
type Gt Eq
// Build writes the quoted column, " > " and the value.
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.WriteString(" > ")
builder.AddVar(builder, gt.Value)
}
// NegationBuild writes the condition as its Lte counterpart.
func (gt Gt) NegationBuild(builder Builder) {
Lte(gt).Build(builder)
}
// Gte greater than or equal to for where
type Gte Eq
// Build writes the quoted column, " >= " and the value.
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.WriteString(" >= ")
builder.AddVar(builder, gte.Value)
}
// NegationBuild writes the condition as its Lt counterpart.
func (gte Gte) NegationBuild(builder Builder) {
Lt(gte).Build(builder)
}
// Lt less than for where
type Lt Eq
// Build writes the quoted column, " < " and the value.
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.WriteString(" < ")
builder.AddVar(builder, lt.Value)
}
// NegationBuild writes the condition as its Gte counterpart.
func (lt Lt) NegationBuild(builder Builder) {
Gte(lt).Build(builder)
}
// Lte less than or equal to for where
type Lte Eq
// Build writes the quoted column, " <= " and the value.
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.WriteString(" <= ")
builder.AddVar(builder, lte.Value)
}
// NegationBuild writes the condition as its Gt counterpart.
func (lte Lte) NegationBuild(builder Builder) {
Gt(lte).Build(builder)
}
// Like tests whether a column matches a SQL LIKE pattern.
type Like Eq
// Build writes the quoted column, " LIKE " and the value.
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" LIKE ")
builder.AddVar(builder, like.Value)
}
// NegationBuild writes the quoted column, " NOT LIKE " and the value.
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
}
// eqNil reports whether value should be treated as SQL NULL: it is nil, a nil
// pointer, or a driver.Valuer whose Value method yields nil or a nil pointer.
func eqNil(value any) bool {
	valuer, isValuer := value.(driver.Valuer)
	if isValuer && !eqNilReflect(valuer) {
		// Value is only called on a non-nil pointer receiver; its error is
		// deliberately ignored and only the returned value is inspected.
		resolved, _ := valuer.Value()
		value = resolved
	}
	if value == nil {
		return true
	}
	return eqNilReflect(value)
}

// eqNilReflect reports whether value holds a nil pointer.
func eqNilReflect(value any) bool {
	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return false
	}
	return rv.IsNil()
}
================================================
FILE: gossh/gorm/clause/from.go
================================================
package clause
// From is the FROM clause: a list of tables followed by joins.
type From struct {
	Tables []Table
	Joins  []Join
}

// Name returns the clause name "FROM".
func (from From) Name() string {
	return "FROM"
}

// Build writes the comma-separated tables (or the current-table placeholder
// when Tables is empty), then every join preceded by a space.
func (from From) Build(builder Builder) {
	if len(from.Tables) == 0 {
		builder.WriteQuoted(currentTable)
	}
	for i, table := range from.Tables {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(table)
	}

	for _, join := range from.Joins {
		builder.WriteByte(' ')
		join.Build(builder)
	}
}

// MergeClause replaces the clause expression with from.
func (from From) MergeClause(clause *Clause) {
	clause.Expression = from
}
================================================
FILE: gossh/gorm/clause/group_by.go
================================================
package clause
// GroupBy is the GROUP BY clause with its optional HAVING conditions.
type GroupBy struct {
	Columns []Column
	Having  []Expression
}

// Name returns the clause name "GROUP BY".
func (groupBy GroupBy) Name() string {
	return "GROUP BY"
}

// Build writes the comma-separated columns, followed by " HAVING " and the
// conditions when Having is not empty.
func (groupBy GroupBy) Build(builder Builder) {
	for i, column := range groupBy.Columns {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(column)
	}

	if len(groupBy.Having) == 0 {
		return
	}
	builder.WriteString(" HAVING ")
	Where{Exprs: groupBy.Having}.Build(builder)
}

// MergeClause appends groupBy's columns and having conditions to those of an
// existing GroupBy expression (working on copies, so the existing slices are
// not modified) and stores the result. The clause name is cleared when the
// merged clause has no columns.
func (groupBy GroupBy) MergeClause(clause *Clause) {
	if prev, ok := clause.Expression.(GroupBy); ok {
		columns := make([]Column, len(prev.Columns), len(prev.Columns)+len(groupBy.Columns))
		copy(columns, prev.Columns)
		groupBy.Columns = append(columns, groupBy.Columns...)

		having := make([]Expression, len(prev.Having), len(prev.Having)+len(groupBy.Having))
		copy(having, prev.Having)
		groupBy.Having = append(having, groupBy.Having...)
	}
	clause.Expression = groupBy

	clause.Name = ""
	if len(groupBy.Columns) != 0 {
		clause.Name = groupBy.Name()
	}
}
================================================
FILE: gossh/gorm/clause/insert.go
================================================
package clause
// Insert is the INSERT clause: an optional modifier and the target table.
type Insert struct {
	Table    Table
	Modifier string
}

// Name returns the clause name "INSERT".
func (insert Insert) Name() string {
	return "INSERT"
}

// Build writes the modifier (when set), "INTO " and the target table, falling
// back to the current-table placeholder when Table has no name.
func (insert Insert) Build(builder Builder) {
	if insert.Modifier != "" {
		builder.WriteString(insert.Modifier)
		builder.WriteByte(' ')
	}

	builder.WriteString("INTO ")

	target := insert.Table
	if target.Name == "" {
		target = currentTable
	}
	builder.WriteQuoted(target)
}

// MergeClause stores insert as the clause expression, inheriting Modifier and
// Table from an existing Insert expression where insert leaves them empty.
func (insert Insert) MergeClause(clause *Clause) {
	if prev, ok := clause.Expression.(Insert); ok {
		if insert.Modifier == "" {
			insert.Modifier = prev.Modifier
		}
		if insert.Table.Name == "" {
			insert.Table = prev.Table
		}
	}
	clause.Expression = insert
}
================================================
FILE: gossh/gorm/clause/joins.go
================================================
package clause
// JoinType is the keyword written in front of JOIN.
type JoinType string

const (
	CrossJoin JoinType = "CROSS"
	InnerJoin JoinType = "INNER"
	LeftJoin  JoinType = "LEFT"
	RightJoin JoinType = "RIGHT"
)

// Join is a single join inside a FROM clause.
type Join struct {
	Type       JoinType
	Table      Table
	ON         Where
	Using      []string
	Expression Expression
}

// Build writes the join. A non-nil Expression is written on its own;
// otherwise the output is "[Type ]JOIN table" followed by " ON ..." when ON
// has conditions, or else " USING (...)" when Using is not empty.
func (join Join) Build(builder Builder) {
	if join.Expression != nil {
		join.Expression.Build(builder)
		return
	}

	if join.Type != "" {
		builder.WriteString(string(join.Type))
		builder.WriteByte(' ')
	}

	builder.WriteString("JOIN ")
	builder.WriteQuoted(join.Table)

	switch {
	case len(join.ON.Exprs) > 0:
		builder.WriteString(" ON ")
		join.ON.Build(builder)
	case len(join.Using) > 0:
		builder.WriteString(" USING (")
		for i, column := range join.Using {
			if i > 0 {
				builder.WriteByte(',')
			}
			builder.WriteQuoted(column)
		}
		builder.WriteByte(')')
	}
}
================================================
FILE: gossh/gorm/clause/limit.go
================================================
package clause
// Limit is the LIMIT/OFFSET clause. A nil or negative Limit writes no LIMIT;
// a non-positive Offset writes no OFFSET.
type Limit struct {
	Limit  *int
	Offset int
}

// Name returns the clause name "LIMIT".
func (limit Limit) Name() string {
	return "LIMIT"
}

// Build writes "LIMIT n" and/or "OFFSET m", separated by a space when both
// are present.
func (limit Limit) Build(builder Builder) {
	hasLimit := limit.Limit != nil && *limit.Limit >= 0
	if hasLimit {
		builder.WriteString("LIMIT ")
		builder.AddVar(builder, *limit.Limit)
	}

	if limit.Offset <= 0 {
		return
	}
	if hasLimit {
		builder.WriteByte(' ')
	}
	builder.WriteString("OFFSET ")
	builder.AddVar(builder, limit.Offset)
}

// MergeClause stores limit as the clause expression and clears the clause
// name, since Build writes the keywords itself. When the clause already holds
// a Limit, its Limit is inherited if limit's own is nil or zero, its positive
// Offset is inherited if limit's own is zero, and a negative Offset is reset
// to zero.
func (limit Limit) MergeClause(clause *Clause) {
	clause.Name = ""

	if prev, ok := clause.Expression.(Limit); ok {
		if prev.Limit != nil && (limit.Limit == nil || *limit.Limit == 0) {
			limit.Limit = prev.Limit
		}

		switch {
		case limit.Offset == 0 && prev.Offset > 0:
			limit.Offset = prev.Offset
		case limit.Offset < 0:
			limit.Offset = 0
		}
	}

	clause.Expression = limit
}
================================================
FILE: gossh/gorm/clause/locking.go
================================================
package clause
// Common values for Locking.Strength and Locking.Options.
const (
	LockingStrengthUpdate    = "UPDATE"
	LockingStrengthShare     = "SHARE"
	LockingOptionsSkipLocked = "SKIP LOCKED"
	LockingOptionsNoWait     = "NOWAIT"
)

// Locking is the row locking clause ("FOR <strength> [OF table] [options]").
type Locking struct {
	Strength string
	Table    Table
	Options  string
}

// Name returns the clause name "FOR".
func (locking Locking) Name() string {
	return "FOR"
}

// Build writes Strength, then " OF table" when Table has a name, then a space
// and Options when Options is set.
func (locking Locking) Build(builder Builder) {
	builder.WriteString(locking.Strength)

	if table := locking.Table; table.Name != "" {
		builder.WriteString(" OF ")
		builder.WriteQuoted(table)
	}

	if opts := locking.Options; opts != "" {
		builder.WriteByte(' ')
		builder.WriteString(opts)
	}
}

// MergeClause replaces the clause expression with locking.
func (locking Locking) MergeClause(clause *Clause) {
	clause.Expression = locking
}
================================================
FILE: gossh/gorm/clause/on_conflict.go
================================================
package clause
// OnConflict is the ON CONFLICT clause of an insert.
type OnConflict struct {
	Columns      []Column
	Where        Where
	TargetWhere  Where
	OnConstraint string
	DoNothing    bool
	DoUpdates    Set
	UpdateAll    bool
}

// Name returns the clause name "ON CONFLICT".
func (OnConflict) Name() string {
	return "ON CONFLICT"
}

// Build writes the conflict target (either "ON CONSTRAINT name" or the column
// list with its optional WHERE), then "DO NOTHING" or "DO UPDATE SET ...",
// then the optional trailing WHERE.
func (onConflict OnConflict) Build(builder Builder) {
	if onConflict.OnConstraint == "" {
		if len(onConflict.Columns) > 0 {
			builder.WriteByte('(')
			for i, column := range onConflict.Columns {
				if i > 0 {
					builder.WriteByte(',')
				}
				builder.WriteQuoted(column)
			}
			builder.WriteString(`) `)
		}

		if len(onConflict.TargetWhere.Exprs) > 0 {
			builder.WriteString(" WHERE ")
			onConflict.TargetWhere.Build(builder)
			builder.WriteByte(' ')
		}
	} else {
		builder.WriteString("ON CONSTRAINT ")
		builder.WriteString(onConflict.OnConstraint)
		builder.WriteByte(' ')
	}

	if !onConflict.DoNothing {
		builder.WriteString("DO UPDATE SET ")
		onConflict.DoUpdates.Build(builder)
	} else {
		builder.WriteString("DO NOTHING")
	}

	if len(onConflict.Where.Exprs) > 0 {
		builder.WriteString(" WHERE ")
		onConflict.Where.Build(builder)
		builder.WriteByte(' ')
	}
}

// MergeClause replaces the clause expression with onConflict.
func (onConflict OnConflict) MergeClause(clause *Clause) {
	clause.Expression = onConflict
}
================================================
FILE: gossh/gorm/clause/order_by.go
================================================
package clause
// OrderByColumn is one sort key. Desc selects descending order; Reorder
// marks the point from which earlier sort keys are dropped when merging.
type OrderByColumn struct {
	Column  Column
	Desc    bool
	Reorder bool
}

// OrderBy is the ORDER BY clause: either a list of sort keys or a single
// free-form Expression, which wins over Columns when set.
type OrderBy struct {
	Columns    []OrderByColumn
	Expression Expression
}

// Name returns the keyword this clause is registered under.
func (ob OrderBy) Name() string {
	return "ORDER BY"
}

// Build delegates to Expression when it is set; otherwise it writes the
// quoted columns separated by commas, appending " DESC" where requested.
func (ob OrderBy) Build(builder Builder) {
	if ob.Expression != nil {
		ob.Expression.Build(builder)
		return
	}
	for i := range ob.Columns {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(ob.Columns[i].Column)
		if ob.Columns[i].Desc {
			builder.WriteString(" DESC")
		}
	}
}

// MergeClause combines this OrderBy with one already stored in clause.
//
// When the clause holds an OrderBy and any of the new columns has Reorder
// set, only the new columns from the last such column onward are kept.
// Otherwise the previous columns are copied and the new ones appended.
// If the clause holds anything else, it is simply replaced.
func (ob OrderBy) MergeClause(clause *Clause) {
	prev, hadOrderBy := clause.Expression.(OrderBy)
	if !hadOrderBy {
		clause.Expression = ob
		return
	}

	lastReorder := -1
	for i, col := range ob.Columns {
		if col.Reorder {
			lastReorder = i
		}
	}

	if lastReorder >= 0 {
		ob.Columns = ob.Columns[lastReorder:]
	} else {
		// Copy first so the previous clause's backing array is never written to.
		merged := make([]OrderByColumn, len(prev.Columns))
		copy(merged, prev.Columns)
		ob.Columns = append(merged, ob.Columns...)
	}
	clause.Expression = ob
}
================================================
FILE: gossh/gorm/clause/returning.go
================================================
package clause
// Returning is the RETURNING clause; an empty column list means "*".
type Returning struct {
	Columns []Column
}

// Name returns the keyword this clause is registered under.
func (returning Returning) Name() string {
	return "RETURNING"
}

// Build writes the quoted columns separated by commas, or "*" when no
// columns were requested.
func (returning Returning) Build(builder Builder) {
	if len(returning.Columns) == 0 {
		builder.WriteByte('*')
		return
	}
	for idx, column := range returning.Columns {
		if idx > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(column)
	}
}

// MergeClause stores this Returning in clause. When the clause already holds
// a Returning, the result lists the previous columns followed by the new ones.
//
// The previous columns are copied into a fresh slice before appending, the
// same way OrderBy.MergeClause does. Appending directly to the previous slice
// could write into its spare capacity and silently change a clause that
// another statement still references.
func (returning Returning) MergeClause(clause *Clause) {
	if prev, ok := clause.Expression.(Returning); ok {
		merged := make([]Column, len(prev.Columns), len(prev.Columns)+len(returning.Columns))
		copy(merged, prev.Columns)
		returning.Columns = append(merged, returning.Columns...)
	}
	clause.Expression = returning
}
================================================
FILE: gossh/gorm/clause/select.go
================================================
package clause
// Select is the SELECT column list. Columns are written quoted; Expression,
// when set, replaces the whole clause at merge time.
type Select struct {
	Distinct   bool
	Columns    []Column
	Expression Expression
}

// Name returns the keyword this clause is registered under.
func (s Select) Name() string {
	return "SELECT"
}

// Build writes "*" when there are no columns; otherwise an optional
// "DISTINCT " prefix followed by the comma-separated quoted columns.
// Distinct has no effect without columns.
func (s Select) Build(builder Builder) {
	if len(s.Columns) == 0 {
		builder.WriteByte('*')
		return
	}
	if s.Distinct {
		builder.WriteString("DISTINCT ")
	}
	for i := range s.Columns {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(s.Columns[i])
	}
}

// MergeClause stores s in clause, or s.Expression when one is set. With
// Distinct and an Expr-typed expression, a copy of that Expr with
// "DISTINCT " prepended to its SQL is stored instead.
func (s Select) MergeClause(clause *Clause) {
	if s.Expression == nil {
		clause.Expression = s
		return
	}
	if raw, isExpr := s.Expression.(Expr); s.Distinct && isExpr {
		raw.SQL = "DISTINCT " + raw.SQL
		clause.Expression = raw
		return
	}
	clause.Expression = s.Expression
}
// CommaExpression is a list of expressions rendered one after another,
// separated by ", ".
type CommaExpression struct {
	Exprs []Expression
}

// Build writes each expression in order, inserting ", " between neighbours.
func (comma CommaExpression) Build(builder Builder) {
	for i := range comma.Exprs {
		if i != 0 {
			_, _ = builder.WriteString(", ")
		}
		comma.Exprs[i].Build(builder)
	}
}
================================================
FILE: gossh/gorm/clause/set.go
================================================
package clause
import "sort"
// Set is the SET clause: an ordered list of column assignments.
type Set []Assignment

// Assignment pairs a column with the value to assign to it.
type Assignment struct {
	Column Column
	Value  any
}

// Name returns the keyword this clause is registered under.
func (set Set) Name() string {
	return "SET"
}

// Build writes "<col>=<var>" pairs separated by commas. An empty Set
// writes the PrimaryKey column assigned to itself.
func (set Set) Build(builder Builder) {
	if len(set) == 0 {
		builder.WriteQuoted(Column{Name: PrimaryKey})
		builder.WriteByte('=')
		builder.WriteQuoted(Column{Name: PrimaryKey})
		return
	}
	for i := range set {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(set[i].Column)
		builder.WriteByte('=')
		builder.AddVar(builder, set[i].Value)
	}
}

// MergeClause stores a copy of the assignments in clause, so later changes
// to the receiver's slice do not affect the stored expression.
func (set Set) MergeClause(clause *Clause) {
	snapshot := make(Set, len(set))
	copy(snapshot, set)
	clause.Expression = snapshot
}
// Assignments converts a column-name → value map into a Set whose entries
// are ordered by column name, so the result is the same on every call
// despite Go's random map iteration order.
func Assignments(values map[string]any) Set {
	names := make([]string, 0, len(values))
	for name := range values {
		names = append(names, name)
	}
	sort.Strings(names)

	result := make(Set, 0, len(names))
	for _, name := range names {
		result = append(result, Assignment{
			Column: Column{Name: name},
			Value:  values[name],
		})
	}
	return result
}
// AssignmentColumns builds a Set that assigns each named column from the
// column of the same name in the "excluded" table, preserving input order.
func AssignmentColumns(values []string) Set {
	result := make(Set, len(values))
	for i := range values {
		name := values[i]
		result[i].Column = Column{Name: name}
		result[i].Value = Column{Table: "excluded", Name: name}
	}
	return result
}
================================================
FILE: gossh/gorm/clause/update.go
================================================
package clause
// Update is the UPDATE clause: an optional modifier followed by the
// target table.
type Update struct {
	Modifier string
	Table    Table
}

// Name returns the keyword this clause is registered under.
func (u Update) Name() string {
	return "UPDATE"
}

// Build writes "[<Modifier> ]<table>". When no table name is set, the
// package-level currentTable value is quoted instead.
func (u Update) Build(builder Builder) {
	if u.Modifier != "" {
		builder.WriteString(u.Modifier)
		builder.WriteByte(' ')
	}

	if u.Table.Name == "" {
		builder.WriteQuoted(currentTable)
		return
	}
	builder.WriteQuoted(u.Table)
}

// MergeClause stores u in clause. Fields left empty on u (Modifier, or a
// Table without a name) are inherited from an Update already in the clause.
func (u Update) MergeClause(clause *Clause) {
	prev, hadUpdate := clause.Expression.(Update)
	if hadUpdate && u.Modifier == "" {
		u.Modifier = prev.Modifier
	}
	if hadUpdate && u.Table.Name == "" {
		u.Table = prev.Table
	}
	clause.Expression = u
}
================================================
FILE: gossh/gorm/clause/values.go
================================================
package clause
// Values is the column list and row tuples of an INSERT.
type Values struct {
	Columns []Column
	Values  [][]any
}

// Name returns the keyword this clause is registered under.
func (Values) Name() string {
	return "VALUES"
}

// Build writes "(<cols>) VALUES (<row>),(<row>)…", binding each row's
// values through builder.AddVar. With no columns it writes "DEFAULT VALUES".
func (values Values) Build(builder Builder) {
	if len(values.Columns) == 0 {
		builder.WriteString("DEFAULT VALUES")
		return
	}

	builder.WriteByte('(')
	for i := range values.Columns {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(values.Columns[i])
	}
	builder.WriteByte(')')

	builder.WriteString(" VALUES ")
	for i, row := range values.Values {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteByte('(')
		builder.AddVar(builder, row...)
		builder.WriteByte(')')
	}
}

// MergeClause stores values in clause and blanks the clause's Name.
func (values Values) MergeClause(clause *Clause) {
	clause.Expression = values
	clause.Name = ""
}
================================================
FILE: gossh/gorm/clause/where.go
================================================
package clause
import (
"strings"
)
// Space-padded boolean connectives. They are written between joined
// expressions and are also the substrings searched for (in upper-cased SQL)
// when deciding whether a raw expression needs parentheses.
const (
// AndWithSpace is the " AND " joiner.
AndWithSpace = " AND "
// OrWithSpace is the " OR " joiner.
OrWithSpace = " OR "
)
// Where is the WHERE clause: a list of conditions joined with AND.
type Where struct {
	Exprs []Expression
}

// Name returns the keyword this clause is registered under.
func (w Where) Name() string {
	return "WHERE"
}

// Build renders the conditions joined by " AND ".
//
// A lone AndConditions is unwrapped so its members are rendered directly.
// Then the first condition that is not a single-element OrConditions is
// swapped into position 0, so the output does not start with a condition
// that would be joined with OR. The swap is done in place on the slice.
func (w Where) Build(builder Builder) {
	exprs := w.Exprs
	if len(exprs) == 1 {
		if and, isAnd := exprs[0].(AndConditions); isAnd {
			exprs = and.Exprs
		}
	}

	for i := range exprs {
		or, isOr := exprs[i].(OrConditions)
		if isOr && len(or.Exprs) <= 1 {
			// Still a single-OR condition; keep looking.
			continue
		}
		if i > 0 {
			exprs[0], exprs[i] = exprs[i], exprs[0]
		}
		break
	}

	buildExprs(exprs, builder, AndWithSpace)
}
// buildExprs renders exprs in order, separated by joinCond.
//
// A single-element OrConditions that is not first is always joined with
// " OR ", whatever joinCond says. When more than one expression is rendered,
// raw SQL (an Expr or NamedExpr, or an And/Or group holding exactly one Expr)
// whose upper-cased text contains " AND " or " OR " is wrapped in
// parentheses so its own connectives do not change the meaning.
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
	mentionsConnective := func(rawSQL string) bool {
		upper := strings.ToUpper(rawSQL)
		return strings.Contains(upper, AndWithSpace) || strings.Contains(upper, OrWithSpace)
	}
	multiple := len(exprs) > 1

	for idx, expr := range exprs {
		if idx > 0 {
			separator := joinCond
			if or, isOr := expr.(OrConditions); isOr && len(or.Exprs) == 1 {
				separator = OrWithSpace
			}
			builder.WriteString(separator)
		}

		needParens := false
		if multiple {
			switch v := expr.(type) {
			case OrConditions:
				if len(v.Exprs) == 1 {
					if inner, isExpr := v.Exprs[0].(Expr); isExpr {
						needParens = mentionsConnective(inner.SQL)
					}
				}
			case AndConditions:
				if len(v.Exprs) == 1 {
					if inner, isExpr := v.Exprs[0].(Expr); isExpr {
						needParens = mentionsConnective(inner.SQL)
					}
				}
			case Expr:
				needParens = mentionsConnective(v.SQL)
			case NamedExpr:
				needParens = mentionsConnective(v.SQL)
			}
		}

		if !needParens {
			expr.Build(builder)
			continue
		}
		builder.WriteByte('(')
		expr.Build(builder)
		builder.WriteByte(')')
	}
}
// MergeClause stores where in clause. If the clause already holds a Where,
// the stored result contains the previous conditions followed by the new
// ones, in a freshly allocated slice so neither input is modified.
func (where Where) MergeClause(clause *Clause) {
	prev, hadWhere := clause.Expression.(Where)
	if hadWhere {
		combined := make([]Expression, 0, len(prev.Exprs)+len(where.Exprs))
		combined = append(combined, prev.Exprs...)
		combined = append(combined, where.Exprs...)
		where.Exprs = combined
	}
	clause.Expression = where
}
// And groups exprs into an AndConditions. No arguments yields nil. A single
// argument is returned as-is unless it is an OrConditions, which is still
// wrapped.
func And(exprs ...Expression) Expression {
	switch len(exprs) {
	case 0:
		return nil
	case 1:
		if _, isOr := exprs[0].(OrConditions); !isOr {
			return exprs[0]
		}
	}
	return AndConditions{Exprs: exprs}
}
// AndConditions is a group of conditions joined with AND.
type AndConditions struct {
	Exprs []Expression
}

// Build renders the members joined by " AND ", enclosing them in
// parentheses only when there is more than one.
func (and AndConditions) Build(builder Builder) {
	grouped := len(and.Exprs) > 1
	if grouped {
		builder.WriteByte('(')
	}
	buildExprs(and.Exprs, builder, AndWithSpace)
	if grouped {
		builder.WriteByte(')')
	}
}
// Or groups exprs into an OrConditions, or returns nil when called with
// no arguments.
func Or(exprs ...Expression) Expression {
	if len(exprs) > 0 {
		return OrConditions{Exprs: exprs}
	}
	return nil
}
// OrConditions is a group of conditions joined with OR.
type OrConditions struct {
	Exprs []Expression
}

// Build renders the members joined by " OR ", enclosing them in
// parentheses only when there is more than one.
func (or OrConditions) Build(builder Builder) {
	grouped := len(or.Exprs) > 1
	if grouped {
		builder.WriteByte('(')
	}
	buildExprs(or.Exprs, builder, OrWithSpace)
	if grouped {
		builder.WriteByte(')')
	}
}
// Not wraps exprs in a NotConditions, or returns nil when called with no
// arguments. A lone AndConditions argument is flattened: its members become
// the members of the NotConditions.
func Not(exprs ...Expression) Expression {
	switch len(exprs) {
	case 0:
		return nil
	case 1:
		if and, isAnd := exprs[0].(AndConditions); isAnd {
			return NotConditions{Exprs: and.Exprs}
		}
	}
	return NotConditions{Exprs: exprs}
}
// NotConditions is a group of conditions that are negated and joined
// with AND.
type NotConditions struct {
	Exprs []Expression
}

// Build renders the negated group.
//
// If any member implements NegationExpressionBuilder, negation is applied
// per member: such members render themselves through NegationBuild and the
// rest are prefixed with "NOT ". Otherwise a single leading "NOT " covers
// the whole group. In both modes members are joined by " AND ", the group is
// parenthesised when it has more than one member, and a raw Expr whose
// upper-cased SQL contains " AND " or " OR " gets its own parentheses.
func (not NotConditions) Build(builder Builder) {
	// writeOperand emits c, parenthesised when it is a raw Expr that
	// itself contains a boolean connective.
	writeOperand := func(c Expression) {
		wrap := false
		if raw, isExpr := c.(Expr); isExpr {
			upper := strings.ToUpper(raw.SQL)
			wrap = strings.Contains(upper, AndWithSpace) || strings.Contains(upper, OrWithSpace)
		}
		if wrap {
			builder.WriteByte('(')
		}
		c.Build(builder)
		if wrap {
			builder.WriteByte(')')
		}
	}

	perMember := false
	for _, c := range not.Exprs {
		if _, ok := c.(NegationExpressionBuilder); ok {
			perMember = true
			break
		}
	}
	grouped := len(not.Exprs) > 1

	if !perMember {
		builder.WriteString("NOT ")
	}
	if grouped {
		builder.WriteByte('(')
	}
	for idx, c := range not.Exprs {
		if idx > 0 {
			builder.WriteString(AndWithSpace)
		}
		if perMember {
			if negator, ok := c.(NegationExpressionBuilder); ok {
				negator.NegationBuild(builder)
				continue
			}
			builder.WriteString("NOT ")
		}
		writeOperand(c)
	}
	if grouped {
		builder.WriteByte(')')
	}
}
================================================
FILE: gossh/gorm/clause/with.go
================================================
package clause
// With is an empty struct type; no fields or methods are declared for it in
// this file.
type With struct{}
================================================
FILE: gossh/gorm/driver/mysql/error_translator.go
================================================
package mysql
import (
"gossh/gorm"
"gossh/mysql"
)
// errCodes maps MySQL server error numbers to the gorm sentinel error that
// Translate returns for them. Numbers not listed here are passed through
// untranslated. Reference for the codes:
// https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html.
var errCodes = map[uint16]error{
// 1062: ER_DUP_ENTRY (duplicate entry for a key).
1062: gorm.ErrDuplicatedKey,
// 1451: ER_ROW_IS_REFERENCED_2 (cannot delete or update a parent row).
1451: gorm.ErrForeignKeyViolated,
// 1452: ER_NO_REFERENCED_ROW_2 (cannot add or update a child row).
1452: gorm.ErrForeignKeyViolated,
}
// Translate converts a MySQL driver error into the matching gorm sentinel
// error when its number is listed in errCodes. Any other error, including a
// *mysql.MySQLError with an unlisted number, is returned unchanged.
//
// Only an err that is itself a *mysql.MySQLError is recognised; a wrapped
// one is not unwrapped.
func (dialector Dialector) Translate(err error) error {
	mysqlErr, isMySQL := err.(*mysql.MySQLError)
	if !isMySQL {
		return err
	}
	if mapped, known := errCodes[mysqlErr.Number]; known {
		return mapped
	}
	return mysqlErr
}
================================================
FILE: gossh/gorm/driver/mysql/migrator.go
================================================
package mysql
import (
"database/sql"
"fmt"
"strconv"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/migrator"
"gossh/gorm/schema"
)
// indexSql selects, for one schema and table (the two placeholders, in that
// order), a row per indexed column with its table name, column name, index
// name and NON_UNIQUE flag from information_schema.STATISTICS. Rows are
// ordered by index name and then by the column's position in the index.
const indexSql = `
SELECT
TABLE_NAME,
COLUMN_NAME,
INDEX_NAME,
NON_UNIQUE
FROM
information_schema.STATISTICS
WHERE
TABLE_SCHEMA = ?
AND TABLE_NAME = ?
ORDER BY
INDEX_NAME,
SEQ_IN_INDEX`
// typeAliasMap lists, per database type name, the other type names treated
// as aliases of it. It is the lookup table behind Migrator.GetTypeAliases.
// "bool" and "tinyint" are aliases of each other.
var typeAliasMap = map[string][]string{
"bool": {"tinyint"},
"tinyint": {"bool"},
}
// Migrator is the MySQL migrator. It embeds the generic migrator.Migrator,
// whose methods it inherits and partly overrides in this file, together with
// the MySQL Dialector, whose Config flags select server-specific SQL.
type Migrator struct {
migrator.Migrator
Dialector
}
// FullDataTypeOf returns the generic migrator's full data type expression
// for field. When the field carries a COMMENT tag setting, a " COMMENT <v>"
// suffix is appended, with the value rendered through the dialector's
// Explain.
func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
	expr := m.Migrator.FullDataTypeOf(field)

	comment, hasComment := field.TagSettings["COMMENT"]
	if !hasComment {
		return expr
	}
	expr.SQL += " COMMENT " + m.Dialector.Explain("?", comment)
	return expr
}
// MigrateColumnUnique migrate column's UNIQUE constraint.
// In MySQL, ColumnType's Unique is affected by UniqueIndex, so we have to take care of the UniqueIndex.
//
// value is the model, field the schema field being migrated and columnType
// the column as currently reported by the database. Nothing is done when the
// column type cannot report uniqueness or when the field is a primary key.
// Returns the first error from any index/constraint operation.
func (m Migrator) MigrateColumnUnique(value any, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// Separate handles are used for read-only lookups (queryTx) and for
// schema-changing statements (execTx).
queryTx, execTx := m.GetQueryAndExecTx()
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique {
// Clean up redundant unique indexes
// NOTE(review): the error from GetIndexes is discarded; on failure
// the loop below simply sees no indexes.
indexes, _ := queryTx.Migrator().GetIndexes(value)
for _, index := range indexes {
// Only unique indexes are candidates for removal.
if uni, ok := index.Unique(); !ok || !uni {
continue
}
// Only single-column indexes on exactly this column.
if columns := index.Columns(); len(columns) != 1 || columns[0] != field.DBName {
continue
}
// Keep the expected constraint and the field's declared unique index.
if name := index.Name(); name == constraint || name == field.UniqueIndex {
continue
}
if err := execTx.Migrator().DropIndex(value, index.Name()); err != nil {
return err
}
}
hasConstraint := queryTx.Migrator().HasConstraint(value, constraint)
switch {
case field.Unique && !hasConstraint:
// NOTE(review): this inner check repeats the case condition and is
// always true here.
if field.Unique {
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
return err
}
}
// field isn't Unique but ColumnType's Unique is reported by UniqueConstraint.
case !field.Unique && hasConstraint:
if err := execTx.Migrator().DropConstraint(value, constraint); err != nil {
return err
}
if field.UniqueIndex != "" {
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
return err
}
}
}
// Make sure the declared unique index exists after the adjustments above.
if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) {
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
return err
}
}
} else {
// The column is not unique in the database yet: create whatever the
// field declares.
if field.Unique {
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
return err
}
}
if field.UniqueIndex != "" {
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
return err
}
}
}
return nil
})
}
// AddColumn adds the column backing the named field to value's table.
//
// It fails when the field cannot be found in the schema and does nothing
// for fields flagged IgnoreMigration. Primary-key fields, and fields whose
// data type mentions auto_increment, are also added as the PRIMARY KEY in
// the same ALTER TABLE statement.
func (m Migrator) AddColumn(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// avoid using the same name field
		f := stmt.Schema.LookUpField(name)
		if f == nil {
			return fmt.Errorf("failed to look up field with name: %s", name)
		}
		if f.IgnoreMigration {
			return nil
		}

		fieldType := m.FullDataTypeOf(f)
		columnName := clause.Column{Name: f.DBName}

		statement := "ALTER TABLE ? ADD ? ?"
		args := []any{m.CurrentTable(stmt), columnName, fieldType}

		isAutoIncrement := strings.Contains(strings.ToLower(fieldType.SQL), "auto_increment")
		if f.PrimaryKey || isAutoIncrement {
			statement += ", ADD PRIMARY KEY (?)"
			args = append(args, columnName)
		}
		return m.DB.Exec(statement, args...).Error
	})
}
// AlterColumn rewrites the column backing the named field to the field's
// current full data type using ALTER TABLE … MODIFY COLUMN.
//
// When the dialector has DontSupportRenameColumnUnique set, the first
// " UNIQUE " token is removed from the data type first. An error is returned
// when the statement has no schema or the field cannot be found.
func (m Migrator) AlterColumn(value any, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var target *schema.Field
		if stmt.Schema != nil {
			target = stmt.Schema.LookUpField(field)
		}
		if target == nil {
			return fmt.Errorf("failed to look up field with name: %s", field)
		}

		fullDataType := m.FullDataTypeOf(target)
		if m.Dialector.DontSupportRenameColumnUnique {
			fullDataType.SQL = strings.Replace(fullDataType.SQL, " UNIQUE ", " ", 1)
		}

		result := m.DB.Exec(
			"ALTER TABLE ? MODIFY COLUMN ? ?",
			clause.Table{Name: stmt.Table}, clause.Column{Name: target.DBName}, fullDataType,
		)
		return result.Error
	})
}
// TiDBVersion reports whether the server is TiDB and, if so, its version.
//
// The dialector's ServerVersion is expected to look like
// "5.7.25-TiDB-v6.5.0" or "5.7.25-TiDB-v6.4.0-serverless". If the second
// dash-separated part is not "TiDB", every result is zero and err is nil.
//
// When the server is TiDB but the version is not of the form
// major.minor.patch with integer parts, isTiDB stays false and err describes
// the part that could not be parsed. A version with fewer than three
// dot-separated parts is reported as an error instead of causing an
// out-of-range panic.
func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error) {
	// TiDB version string looks like:
	// "5.7.25-TiDB-v6.5.0" or "5.7.25-TiDB-v6.4.0-serverless"
	tidbVersionArray := strings.Split(m.Dialector.ServerVersion, "-")
	if len(tidbVersionArray) < 3 || tidbVersionArray[1] != "TiDB" {
		// It isn't TiDB
		return
	}

	rawVersion := strings.TrimPrefix(tidbVersionArray[2], "v")
	realVersionArray := strings.Split(rawVersion, ".")
	if len(realVersionArray) < 3 {
		// Guard the fixed indexes used below.
		err = fmt.Errorf("failed to parse the version of TiDB, the version is: %s", rawVersion)
		return
	}

	if major, err = strconv.Atoi(realVersionArray[0]); err != nil {
		err = fmt.Errorf("failed to parse the version of TiDB, the major version is: %s", realVersionArray[0])
		return
	}
	if minor, err = strconv.Atoi(realVersionArray[1]); err != nil {
		err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[1])
		return
	}
	if patch, err = strconv.Atoi(realVersionArray[2]); err != nil {
		err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[2])
		return
	}

	isTiDB = true
	return
}
// RenameColumn renames a column of value's table.
//
// Unless the dialector has DontSupportRenameColumn set, the generic
// migrator's implementation is used. Otherwise the rename is done with
// ALTER TABLE … CHANGE, which needs the full column definition: oldName and
// newName are resolved to DB column names through the schema where possible,
// and the definition comes from the field found under newName, falling back
// to the one found under oldName. If neither resolves to a field an error is
// returned.
func (m Migrator) RenameColumn(value any, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if !m.Dialector.DontSupportRenameColumn {
			return m.Migrator.RenameColumn(value, oldName, newName)
		}

		var definition *schema.Field
		if sch := stmt.Schema; sch != nil {
			if oldField := sch.LookUpField(oldName); oldField != nil {
				oldName, definition = oldField.DBName, oldField
			}
			// A match on the new name takes precedence for the definition.
			if newField := sch.LookUpField(newName); newField != nil {
				newName, definition = newField.DBName, newField
			}
		}
		if definition == nil {
			return fmt.Errorf("failed to look up field with name: %s", newName)
		}

		result := m.DB.Exec(
			"ALTER TABLE ? CHANGE ? ? ?",
			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
			clause.Column{Name: newName}, m.FullDataTypeOf(definition),
		)
		return result.Error
	})
}
// DropConstraint removes the named constraint from value's table.
//
// Unless the dialector has DontSupportDropConstraint set, the generic
// migrator's implementation is used. Otherwise the statement depends on what
// the name resolves to: DROP FOREIGN KEY for a *schema.Constraint, DROP CHECK
// for a *schema.CheckConstraint, and DROP INDEX when an index of that name
// exists. If none of these applies nothing is executed and nil is returned.
func (m Migrator) DropConstraint(value any, name string) error {
	if !m.Dialector.Config.DontSupportDropConstraint {
		return m.Migrator.DropConstraint(value, name)
	}

	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		target := clause.Table{Name: table}

		if constraint != nil {
			name = constraint.GetName()

			var dropSQL string
			switch constraint.(type) {
			case *schema.Constraint:
				dropSQL = "ALTER TABLE ? DROP FOREIGN KEY ?"
			case *schema.CheckConstraint:
				dropSQL = "ALTER TABLE ? DROP CHECK ?"
			}
			if dropSQL != "" {
				return m.DB.Exec(dropSQL, target, clause.Column{Name: name}).Error
			}
		}

		if !m.HasIndex(value, name) {
			return nil
		}
		return m.DB.Exec("ALTER TABLE ? DROP INDEX ?", target, clause.Column{Name: name}).Error
	})
}
// RenameIndex renames an index on value's table.
//
// Unless the dialector has DontSupportRenameIndex set, this is a single
// ALTER TABLE … RENAME INDEX. Otherwise the old index is dropped and a new
// one created: when the schema knows the index only under its old name, that
// definition (class, fields/options, type) is reused under the new name; in
// every other case CreateIndex is called with the new name.
func (m Migrator) RenameIndex(value any, oldName, newName string) error {
	if !m.Dialector.DontSupportRenameIndex {
		return m.RunWithValue(value, func(stmt *gorm.Statement) error {
			result := m.DB.Exec(
				"ALTER TABLE ? RENAME INDEX ? TO ?",
				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
			)
			return result.Error
		})
	}

	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if err := m.DropIndex(value, oldName); err != nil {
			return err
		}

		// Without a schema, or when the new name is already declared, let
		// CreateIndex resolve the definition by the new name.
		if stmt.Schema == nil || stmt.Schema.LookIndex(newName) != nil {
			return m.CreateIndex(value, newName)
		}
		idx := stmt.Schema.LookIndex(oldName)
		if idx == nil {
			return m.CreateIndex(value, newName)
		}

		opts := m.BuildIndexOptions(idx.Fields, stmt)
		args := []any{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}

		createIndexSQL := "CREATE "
		if idx.Class != "" {
			createIndexSQL += idx.Class + " "
		}
		createIndexSQL += "INDEX ? ON ??"
		if idx.Type != "" {
			createIndexSQL += " USING " + idx.Type
		}
		return m.DB.Exec(createIndexSQL, args...).Error
	})
}
// DropTable drops the tables of the given models on a single connection.
//
// The models are first passed through ReorderModels and then dropped from
// last to first, with foreign key checks switched off for the duration. The
// result of the statement that disables the checks is not inspected here.
//
// NOTE(review): when a drop fails the function returns before
// "SET FOREIGN_KEY_CHECKS = 1" runs on this connection; confirm whether
// that matters for connections that go back to a pool.
func (m Migrator) DropTable(values ...any) error {
	values = m.ReorderModels(values, false)

	return m.DB.Connection(func(tx *gorm.DB) error {
		dropOne := func(model any) error {
			return m.RunWithValue(model, func(stmt *gorm.Statement) error {
				return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
			})
		}

		tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
		for remaining := len(values); remaining > 0; remaining-- {
			if err := dropOne(values[remaining-1]); err != nil {
				return err
			}
		}
		return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error
	})
}
// ColumnTypes returns the column metadata of value's table.
//
// Two queries are made. A one-row probe of the table supplies the driver's
// own column types, and a query against information_schema.columns supplies
// the default, nullability, length, key, extra, comment and precision of
// each column. The two are matched by column name. datetime_precision is
// read only when DisableDatetimePrecision is off, and when present it
// overrides the decimal size.
//
// The probe result set is closed on every path, including when reading its
// column types fails. An error hit while iterating the information_schema
// rows is returned rather than dropped; the columns collected up to that
// point are still returned alongside it.
func (m Migrator) ColumnTypes(value any) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var (
			currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
			columnTypeSQL          = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
			rows, err              = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
		)
		if err != nil {
			return err
		}

		rawColumnTypes, err := rows.ColumnTypes()
		if err != nil {
			// Release the probe before bailing out; the ColumnTypes error is
			// the one worth reporting, so a Close error is ignored here.
			_ = rows.Close()
			return err
		}
		if err := rows.Close(); err != nil {
			return err
		}

		if !m.DisableDatetimePrecision {
			columnTypeSQL += ", datetime_precision "
		}
		columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"

		columns, rowErr := m.DB.Table(table).Raw(columnTypeSQL, currentDatabase, table).Rows()
		if rowErr != nil {
			return rowErr
		}
		defer columns.Close()

		for columns.Next() {
			var (
				column            migrator.ColumnType
				datetimePrecision sql.NullInt64
				extraValue        sql.NullString
				columnKey         sql.NullString
				values            = []any{
					&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
				}
			)
			// The extra scan target must match the extra selected column.
			if !m.DisableDatetimePrecision {
				values = append(values, &datetimePrecision)
			}

			if scanErr := columns.Scan(values...); scanErr != nil {
				return scanErr
			}

			column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
			column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
			switch columnKey.String {
			case "PRI":
				column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
			case "UNI":
				column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
			}

			if strings.Contains(extraValue.String, "auto_increment") {
				column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
			}

			// only trim paired single-quotes
			s := column.DefaultValueValue.String
			for (len(s) >= 3 && s[0] == '\'' && s[len(s)-1] == '\'' && s[len(s)-2] != '\\') ||
				(len(s) == 2 && s == "''") {
				s = s[1 : len(s)-1]
			}
			column.DefaultValueValue.String = s

			if m.Dialector.DontSupportNullAsDefaultValue {
				// rewrite mariadb default value like other version
				if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
					column.DefaultValueValue.Valid = false
					column.DefaultValueValue.String = ""
				}
			}

			if datetimePrecision.Valid {
				column.DecimalSizeValue = datetimePrecision
			}

			for _, c := range rawColumnTypes {
				if c.Name() == column.NameValue.String {
					column.SQLColumnType = c
					break
				}
			}

			columnTypes = append(columnTypes, column)
		}

		// Next returns false on error as well as at the end of the rows.
		return columns.Err()
	})
	return columnTypes, err
}
// CurrentDatabase returns a schema name from information_schema.SCHEMATA
// that starts with the name reported by the generic migrator. An exact match
// is preferred; otherwise the alphabetically first match is used. The
// outcome of the scan is not checked, so name is empty when nothing was read.
func (m Migrator) CurrentDatabase() (name string) {
	const query = "SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1"

	prefix := m.Migrator.CurrentDatabase()
	m.DB.Raw(query, prefix+"%", prefix).Scan(&name)
	return name
}
// GetTables lists the names of all tables that information_schema.tables
// reports for the current database.
func (m Migrator) GetTables() (tableList []string, err error) {
	const query = "SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?"

	result := m.DB.Raw(query, m.CurrentDatabase()).Scan(&tableList)
	return tableList, result.Error
}
// GetIndexes returns the indexes defined on value's table, one entry per
// index name in the order the names first appear in the indexSql result.
// An index named "PRIMARY" is flagged as the primary key, and an index whose
// NON_UNIQUE flag is 0 is flagged as unique. Column names are listed in
// result order.
func (m Migrator) GetIndexes(value any) ([]gorm.Index, error) {
	indexes := make([]gorm.Index, 0)

	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		rows := make([]*Index, 0)
		dbName, table := m.CurrentSchema(stmt, stmt.Table)
		if scanErr := m.DB.Table(table).Raw(indexSql, dbName, table).Scan(&rows).Error; scanErr != nil {
			return scanErr
		}

		grouped, order := groupByIndexName(rows)
		for _, name := range order {
			members := grouped[name]
			if len(members) == 0 {
				continue
			}

			head := members[0]
			entry := &migrator.Index{
				TableName: head.TableName,
				NameValue: head.IndexName,
				PrimaryKeyValue: sql.NullBool{
					Bool:  head.IndexName == "PRIMARY",
					Valid: true,
				},
				UniqueValue: sql.NullBool{
					Bool:  head.NonUnique == 0,
					Valid: true,
				},
			}
			for _, member := range members {
				entry.ColumnList = append(entry.ColumnList, member.ColumnName)
			}
			indexes = append(indexes, entry)
		}
		return nil
	})

	return indexes, err
}
// Index is one row of the indexSql result: a single column of a single
// index. An index spanning several columns is returned as several rows that
// share the same IndexName.
type Index struct {
// TableName is the table the index belongs to.
TableName string `gorm:"column:TABLE_NAME"`
// ColumnName is the indexed column this row describes.
ColumnName string `gorm:"column:COLUMN_NAME"`
// IndexName is the name of the index; GetIndexes treats "PRIMARY" as the primary key.
IndexName string `gorm:"column:INDEX_NAME"`
// NonUnique is 0 for a unique index (see GetIndexes).
NonUnique int32 `gorm:"column:NON_UNIQUE"`
}
// groupByIndexName buckets index rows by IndexName. It returns the buckets
// together with the distinct names in order of first appearance; rows keep
// their input order inside each bucket.
func groupByIndexName(indexList []*Index) (map[string][]*Index, []string) {
	buckets := make(map[string][]*Index, len(indexList))
	order := make([]string, 0, len(indexList))

	for _, row := range indexList {
		name := row.IndexName
		members, seen := buckets[name]
		if !seen {
			order = append(order, name)
		}
		buckets[name] = append(members, row)
	}
	return buckets, order
}
// CurrentSchema splits a table reference into schema and table name. A name
// of the form "schema.table" (exactly one dot) is split at the dot; any
// other name is returned as-is, paired with the current database name.
// stmt is not used.
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
	parts := strings.Split(table, `.`)
	if len(parts) == 2 {
		return parts[0], parts[1]
	}

	// m is a copy, so this only affects the CurrentDatabase call below.
	m.DB = m.DB.Table(table)
	return m.CurrentDatabase(), table
}
// GetTypeAliases returns the type names recorded in typeAliasMap as aliases
// of databaseTypeName, or nil when there are none.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	aliases := typeAliasMap[databaseTypeName]
	return aliases
}
// TableType reads the schema, name, type and comment of value's table from
// information_schema.tables. The scan error is returned when the row cannot
// be read, for example when the table does not exist.
func (m Migrator) TableType(value any) (tableType gorm.TableType, err error) {
	var table migrator.TableType

	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		const query = "SELECT table_schema, table_name, table_type, table_comment FROM information_schema.tables WHERE table_schema = ? AND table_name = ?"

		currentDatabase, tableName := m.CurrentSchema(stmt, stmt.Table)
		row := m.DB.Table(tableName).Raw(query, currentDatabase, tableName).Row()
		return row.Scan(&table.SchemaValue, &table.NameValue, &table.TypeValue, &table.CommentValue)
	})

	return table, err
}
================================================
FILE: gossh/gorm/driver/mysql/mysql.go
================================================
package mysql
import (
"context"
"database/sql"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"time"
"gossh/mysql"
"gossh/gorm"
"gossh/gorm/callbacks"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/migrator"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
const (
// AutoRandomTag is the literal "auto_random()" marker. A field carrying it
// is treated as an auto_random field for TiDB; the code that checks for it
// is outside this part of the file.
AutoRandomTag = "auto_random()" // Treated as an auto_random field for tidb
)
// Config holds the connection settings and server-capability flags of the
// MySQL dialector. Unless SkipInitializeWithVersion is set, Initialize reads
// the server version and switches on the DontSupport*/Disable* flags that
// apply to it.
type Config struct {
// DriverName is the database/sql driver name; Initialize defaults it to "mysql".
DriverName string
// ServerVersion is filled from SELECT VERSION() by Initialize.
ServerVersion string
// DSN is the data source name passed to sql.Open.
DSN string
// DSNConfig is the parsed form of DSN (see Open and New).
DSNConfig *mysql.Config
// Conn, when non-nil, is used as the connection pool instead of opening one.
Conn gorm.ConnPool
// SkipInitializeWithVersion skips the version query and the flag detection based on it.
SkipInitializeWithVersion bool
// DefaultStringSize is not read in this part of the file.
DefaultStringSize uint
// DefaultDatetimePrecision is the precision handed to NowFunc; nil means defaultDatetimePrecision.
DefaultDatetimePrecision *int
// DisableWithReturning keeps RETURNING out of the clause lists even when the server supports it.
DisableWithReturning bool
// DisableDatetimePrecision makes ColumnTypes leave datetime_precision out of its query.
DisableDatetimePrecision bool
// DontSupportRenameIndex makes RenameIndex drop and re-create the index.
DontSupportRenameIndex bool
// DontSupportRenameColumn makes RenameColumn use ALTER TABLE ... CHANGE.
DontSupportRenameColumn bool
// DontSupportForShareClause renders a SHARE lock as LOCK IN SHARE MODE.
DontSupportForShareClause bool
// DontSupportNullAsDefaultValue makes ColumnTypes treat a "NULL" default as no default.
DontSupportNullAsDefaultValue bool
// DontSupportRenameColumnUnique makes AlterColumn strip " UNIQUE " from the column definition.
DontSupportRenameColumnUnique bool
// As of MySQL 8.0.19, ALTER TABLE permits more general (and SQL standard) syntax
// for dropping and altering existing constraints of any type.
// see https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
// When set, DropConstraint uses DROP FOREIGN KEY / DROP CHECK / DROP INDEX instead.
DontSupportDropConstraint bool
}
// Dialector is the MySQL gorm dialector. It embeds *Config, so copies of a
// Dialector (its methods use value receivers) all share one Config and
// writes to Config fields are visible through every copy.
type Dialector struct {
*Config
}
// Clause name lists that Initialize hands to callbacks.Config, plus the
// default datetime precision.
var (
// CreateClauses create clauses
CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
// QueryClauses query clauses
QueryClauses = []string{}
// UpdateClauses update clauses
UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
// DeleteClauses delete clauses
DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
// defaultDatetimePrecision is used when Config.DefaultDatetimePrecision is nil.
defaultDatetimePrecision = 3
)
// Open returns a MySQL dialector for dsn. The DSN is also parsed into
// DSNConfig; a parse error is ignored here, leaving whatever ParseDSN
// returned.
func Open(dsn string) gorm.Dialector {
	parsed, _ := mysql.ParseDSN(dsn)
	cfg := &Config{
		DSN:       dsn,
		DSNConfig: parsed,
	}
	return &Dialector{Config: cfg}
}
// New returns a MySQL dialector backed by a copy of config. When only one of
// DSN and DSNConfig is provided, the other is derived from it: the DSN is
// formatted from DSNConfig, or DSNConfig is parsed from the DSN (a parse
// error is ignored). When both or neither are set, config is used unchanged.
func New(config Config) gorm.Dialector {
	hasDSN := config.DSN != ""
	hasParsed := config.DSNConfig != nil

	if !hasDSN && hasParsed {
		config.DSN = config.DSNConfig.FormatDSN()
	} else if hasDSN && !hasParsed {
		config.DSNConfig, _ = mysql.ParseDSN(config.DSN)
	}
	return &Dialector{Config: &config}
}
// Name returns the dialect name, "mysql".
func (dialector Dialector) Name() string {
	const dialectName = "mysql"
	return dialectName
}
// NowFunc returns a clock function whose results are rounded to n decimal
// digits of a second (n = 3 gives millisecond resolution). The rounding
// unit is recomputed on every call of the returned function.
func (dialector Dialector) NowFunc(n int) func() time.Time {
	return func() time.Time {
		stepsPerSecond := time.Duration(math.Pow10(n))
		unit := time.Second / stepsPerSecond
		return time.Now().Round(unit)
	}
}
// Apply installs a precision-aware NowFunc on config unless the caller has
// already supplied one. A nil DefaultDatetimePrecision is first replaced
// with the package default; because Config is shared through a pointer, that
// default is stored on the shared Config. Always returns nil.
func (dialector Dialector) Apply(config *gorm.Config) error {
	if config.NowFunc == nil {
		if dialector.DefaultDatetimePrecision == nil {
			dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
		}
		// Building the clock is delegated to NowFunc, which holds the
		// rounding logic.
		config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
	}
	return nil
}
// Initialize prepares db for use with MySQL: it sets up the connection pool,
// detects server capabilities from the version string, registers the default
// callbacks and installs the dialect's clause builders.
//
// The receiver is a value, but its fields are promoted from the embedded
// *Config, so the defaults and capability flags written here are stored on
// the shared Config. An error is returned when opening the database or
// querying the version fails.
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// Fill in defaults for settings the caller left empty.
if dialector.DriverName == "" {
dialector.DriverName = "mysql"
}
if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}
// Prefer a caller-supplied pool; otherwise open one from the DSN.
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
if err != nil {
return err
}
}
withReturning := false
if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil {
return err
}
// The branches are checked in order: MariaDB first, then the specific
// 5.6./5.7. prefixes, and only then the generic "5." prefix.
if strings.Contains(dialector.ServerVersion, "MariaDB") {
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportNullAsDefaultValue = true
// checkVersion is defined elsewhere; presumably it reports whether the
// server is at least version 10.5 — confirm against its definition.
withReturning = checkVersion(dialector.ServerVersion, "10.5")
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportDropConstraint = true
} else if strings.HasPrefix(dialector.ServerVersion, "5.7.") {
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportDropConstraint = true
} else if strings.HasPrefix(dialector.ServerVersion, "5.") {
// Any other 5.x server: datetime precision is disabled as well.
dialector.Config.DisableDatetimePrecision = true
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportDropConstraint = true
}
// Checked independently of the chain above.
if strings.Contains(dialector.ServerVersion, "TiDB") {
dialector.Config.DontSupportRenameColumnUnique = true
}
}
// register callbacks
callbackConfig := &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
}
// withReturning is only ever set in the MariaDB branch above.
if !dialector.Config.DisableWithReturning && withReturning {
if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
}
if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
}
if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
}
}
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
// Install the MySQL-specific clause builders, replacing any with the same key.
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
// Keys under which ClauseBuilders registers its MySQL-specific builders.
const (
// ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
ClauseOnConflict = "ON CONFLICT"
// ClauseValues for clause.ClauseBuilder VALUES key
ClauseValues = "VALUES"
// ClauseFor for clause.ClauseBuilder FOR key
ClauseFor = "FOR"
)
// ClauseBuilders returns the clause builders this dialector overrides, keyed
// by clause name. ON CONFLICT and VALUES are always present; FOR is added only
// when Config.DontSupportForShareClause is set.
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
clauseBuilders := map[string]clause.ClauseBuilder{
ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
onConflict, ok := c.Expression.(clause.OnConflict)
if !ok {
// Not an OnConflict expression: fall back to the clause's own Build.
c.Build(builder)
return
}
builder.WriteString("ON DUPLICATE KEY UPDATE ")
if len(onConflict.DoUpdates) == 0 {
// No explicit assignments. When the statement has a schema, assign one
// column to itself: the prioritized primary field if there is one,
// otherwise the first DB name.
if s := builder.(*gorm.Statement).Schema; s != nil {
var column clause.Column
onConflict.DoNothing = false
if s.PrioritizedPrimaryField != nil {
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
} else if len(s.DBNames) > 0 {
column = clause.Column{Name: s.DBNames[0]}
}
if column.Name != "" {
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
}
// Register the modified copy of the clause on the statement.
builder.(*gorm.Statement).AddClause(onConflict)
}
}
for idx, assignment := range onConflict.DoUpdates {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(assignment.Column)
builder.WriteByte('=')
// A column value whose Table is "excluded" is written as VALUES(col);
// every other value is added as a bind variable.
if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
column.Table = ""
builder.WriteString("VALUES(")
builder.WriteQuoted(column)
builder.WriteByte(')')
} else {
builder.AddVar(builder, assignment.Value)
}
}
},
ClauseValues: func(c clause.Clause, builder clause.Builder) {
// A Values expression without columns is written as an empty VALUES().
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
builder.WriteString("VALUES()")
return
}
c.Build(builder)
},
}
if dialector.Config.DontSupportForShareClause {
clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
// A locking clause with strength SHARE (case-insensitive) is written as
// LOCK IN SHARE MODE; anything else uses the clause's own Build.
if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
builder.WriteString("LOCK IN SHARE MODE")
return
}
c.Build(builder)
}
}
return clauseBuilders
}
// DefaultValueOf returns the expression used as a column's default value.
// The field argument is not consulted: the result is always the bare SQL
// keyword DEFAULT.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	defaultExpr := clause.Expr{SQL: "DEFAULT"}
	return defaultExpr
}
// Migrator builds this package's Migrator for db, wrapping a generic
// migrator.Migrator configured with db and this dialector.
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	base := migrator.Migrator{
		Config: migrator.Config{DB: db, Dialector: dialector},
	}
	return Migrator{Migrator: base, Dialector: dialector}
}
// BindVarTo writes the bind-variable placeholder to writer. The placeholder
// is always a single '?'; stmt and v are not consulted.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {
	const placeholder byte = '?'
	writer.WriteByte(placeholder)
}
// QuoteTo writes str to writer as a backtick-quoted identifier. The input is
// scanned byte by byte: '.' is treated as a separator between identifier
// parts, and backticks already present in str are counted so the scanner can
// decide when it has to emit a backtick itself.
// NOTE(review): the exact output for inputs mixing backticks and dots follows
// from the state machine below; verify against the dialector's tests before
// changing any branch.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
var (
// underQuoted: an opening backtick has been written for the current part.
// selfQuoted: the current part started while a backtick from str was pending.
underQuoted, selfQuoted bool
// Number of consecutive backticks read from str and not yet written out.
continuousBacktick int8
// Count of non-'.' bytes processed since the last reset at a '.'.
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '`':
continuousBacktick++
// Two backticks in a row are written out as a pair immediately.
if continuousBacktick == 2 {
writer.WriteString("``")
continuousBacktick = 0
}
case '.':
// Close the current part (and reset the per-part state) unless it was
// self-quoted with no backtick pending.
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteByte('`')
}
writer.WriteByte(v)
// Skip the shiftDelimiter increment at the bottom of the loop.
continue
default:
// Open a new part when none is open yet.
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteByte('`')
underQuoted = true
// One pending backtick is consumed by the opening backtick just written.
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
// Each remaining pending backtick is written as a doubled backtick.
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString("``")
}
writer.WriteByte(v)
}
shiftDelimiter++
}
// A backtick still pending at the end of a part that was not self-quoted is
// written as a doubled backtick before the closing one.
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
writer.WriteByte('`')
}
// localTimeInterface is implemented by values that can be converted to a
// time.Time in a given location; Explain uses it to find variables to convert.
type localTimeInterface interface {
In(loc *time.Location) time.Time
}
// Explain renders sql with vars substituted via logger.ExplainSQL, using a
// single quote as the escaper. When the DSN configuration carries a location,
// every variable implementing localTimeInterface is first replaced, in place
// in vars, by its value in that location.
func (dialector Dialector) Explain(sql string, vars ...any) string {
	if cfg := dialector.DSNConfig; cfg != nil && cfg.Loc != nil {
		for idx := range vars {
			localizable, ok := vars[idx].(localTimeInterface)
			if !ok {
				continue
			}
			func() {
				// A panic raised by In leaves the variable untouched.
				defer func() {
					recover()
				}()
				vars[idx] = localizable.In(dialector.DSNConfig.Loc)
			}()
		}
	}
	return logger.ExplainSQL(sql, nil, `'`, vars...)
}
// DataTypeOf returns the SQL column type for field, dispatching on
// field.DataType. Bool maps directly to "boolean"; the other data types are
// delegated to the matching getSchema*Type helper, with any unrecognised data
// type handled by getSchemaCustomType.
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	if field.DataType == schema.Bool {
		return "boolean"
	}
	var resolve func(*schema.Field) string
	switch field.DataType {
	case schema.Int, schema.Uint:
		resolve = dialector.getSchemaIntAndUnitType
	case schema.Float:
		resolve = dialector.getSchemaFloatType
	case schema.String:
		resolve = dialector.getSchemaStringType
	case schema.Time:
		resolve = dialector.getSchemaTimeType
	case schema.Bytes:
		resolve = dialector.getSchemaBytesType
	default:
		resolve = dialector.getSchemaCustomType
	}
	return resolve(field)
}
// getSchemaFloatType returns the SQL type for a float field: a decimal with
// the field's precision and scale when a precision is set, otherwise "float"
// for sizes up to 32 and "double" for anything larger.
func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
	switch {
	case field.Precision > 0:
		return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
	case field.Size <= 32:
		return "float"
	default:
		return "double"
	}
}
// getSchemaStringType returns the SQL type for a string field.
//
// A zero field size is replaced by the dialector's DefaultStringSize when that
// is positive; otherwise by 191 when the field is a primary key, has a default
// value, or carries an INDEX or UNIQUE tag. The resulting size then selects
// "mediumtext" (65536 through 2^24), "longtext" (above 2^24, or not positive)
// or varchar(size).
func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
	const mediumTextMax = 1 << 24

	size := field.Size
	if size == 0 {
		switch {
		case dialector.DefaultStringSize > 0:
			size = int(dialector.DefaultStringSize)
		case field.PrimaryKey,
			field.HasDefaultValue,
			field.TagSettings["INDEX"] != "",
			field.TagSettings["UNIQUE"] != "":
			// TEXT, GEOMETRY or JSON column can't have a default value
			size = 191 // utf8mb4
		}
	}

	switch {
	case size >= 65536 && size <= mediumTextMax:
		return "mediumtext"
	case size > mediumTextMax || size <= 0:
		return "longtext"
	}
	return fmt.Sprintf("varchar(%d)", size)
}
// getSchemaTimeType returns the SQL type for a time field: "datetime", with
// "(precision)" appended when the field has a positive precision and " NULL"
// appended unless the field is NOT NULL or a primary key.
//
// Side effect: unless datetime precision is disabled, a field with zero
// precision and no PRECISION tag has its Precision set to the dialector's
// DefaultDatetimePrecision.
func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
	needsDefaultPrecision := !dialector.DisableDatetimePrecision &&
		field.Precision == 0 &&
		field.TagSettings["PRECISION"] == ""
	if needsDefaultPrecision {
		field.Precision = *dialector.DefaultDatetimePrecision
	}

	sqlType := "datetime"
	if field.Precision > 0 {
		sqlType += fmt.Sprintf("(%d)", field.Precision)
	}
	if !field.NotNull && !field.PrimaryKey {
		sqlType += " NULL"
	}
	return sqlType
}
// getSchemaBytesType returns the SQL type for a bytes field: varbinary(size)
// for sizes from 1 to 65535, "mediumblob" for 65536 through 2^24, and
// "longblob" for everything else (including a size that is not positive).
func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
	switch size := field.Size; {
	case size > 0 && size < 65536:
		return fmt.Sprintf("varbinary(%d)", size)
	case size >= 65536 && size <= 1<<24:
		return "mediumblob"
	default:
		return "longblob"
	}
}
// autoRandomType reports whether field should be created as an auto_random
// column and, if so, returns the SQL type to use.
//
// field.DataType MUST be `schema.Int` or `schema.Uint`. The field qualifies
// when it is a primary key, has a default value, and that default value
// (trimmed and lower-cased) equals AutoRandomTag. For a qualifying field the
// default value is cleared, field.Size is ignored, and the type is always
// bigint (unsigned for schema.Uint) followed by auto_random. Otherwise the
// result is (false, "").
func autoRandomType(field *schema.Field) (bool, string) {
	if !field.PrimaryKey || !field.HasDefaultValue {
		return false, ""
	}
	normalizedDefault := strings.ToLower(strings.TrimSpace(field.DefaultValue))
	if normalizedDefault != AutoRandomTag {
		return false, ""
	}

	field.DefaultValue = ""
	parts := []string{"bigint"}
	if field.DataType == schema.Uint {
		parts = append(parts, " unsigned")
	}
	parts = append(parts, " auto_random")
	return true, strings.Join(parts, "")
}
// getSchemaIntAndUnitType returns the SQL type for an integer field. A field
// accepted by autoRandomType uses the type that helper returns. Otherwise the
// base type is chosen from field.Size (up to 8: tinyint, 16: smallint,
// 24: mediumint, 32: int, larger: bigint), followed by " unsigned" for
// schema.Uint and " AUTO_INCREMENT" for auto-increment fields.
func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
	if isAutoRandom, autoRandomSQL := autoRandomType(field); isAutoRandom {
		return autoRandomSQL
	}

	var sqlType string
	switch size := field.Size; {
	case size <= 8:
		sqlType = "tinyint"
	case size <= 16:
		sqlType = "smallint"
	case size <= 24:
		sqlType = "mediumint"
	case size <= 32:
		sqlType = "int"
	default:
		sqlType = "bigint"
	}

	if field.DataType == schema.Uint {
		sqlType += " unsigned"
	}
	if field.AutoIncrement {
		sqlType += " AUTO_INCREMENT"
	}
	return sqlType
}
// getSchemaCustomType returns field.DataType verbatim as the SQL type,
// appending " AUTO_INCREMENT" for auto-increment fields whose type does not
// already contain " auto_increment" (compared case-insensitively).
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
	sqlType := string(field.DataType)
	if !field.AutoIncrement {
		return sqlType
	}
	if strings.Contains(strings.ToLower(sqlType), " auto_increment") {
		return sqlType
	}
	return sqlType + " AUTO_INCREMENT"
}
// SavePoint executes "SAVEPOINT <name>" on tx and returns the resulting error.
// The name is concatenated into the statement as-is.
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
	result := tx.Exec("SAVEPOINT " + name)
	return result.Error
}
// RollbackTo executes "ROLLBACK TO SAVEPOINT <name>" on tx and returns the
// resulting error. The name is concatenated into the statement as-is.
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
	result := tx.Exec("ROLLBACK TO SAVEPOINT " + name)
	return result.Error
}
// checkVersion reports whether newVersion is newer than or equal to
// oldVersion.
//
// Both versions are split on "." and compared component by component. Only the
// leading digits of a component are used (so "8-MariaDB" counts as 8); a
// component without leading digits counts as 0. The first differing component
// decides the result.
//
// When oldVersion runs out of components first, newVersion is at least as new.
// When newVersion runs out first, it is treated as equal only if every
// remaining oldVersion component is zero, so "10.5" satisfies "10.5.0" but not
// "10.5.1".
func checkVersion(newVersion, oldVersion string) bool {
	if newVersion == oldVersion {
		return true
	}
	var (
		versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)
		newVersions          = strings.Split(newVersion, ".")
		oldVersions          = strings.Split(oldVersion, ".")
	)
	// numeric extracts the leading integer of a version component; the Atoi
	// error is deliberately ignored so non-numeric components count as 0.
	numeric := func(component string) int {
		n, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(component, "$1"))
		return n
	}
	for idx, nv := range newVersions {
		if len(oldVersions) <= idx {
			return true
		}
		nvi, ovi := numeric(nv), numeric(oldVersions[idx])
		if nvi == ovi {
			continue
		}
		return nvi > ovi
	}
	// Every component of newVersion matched. Any non-zero component left in
	// oldVersion means newVersion is older.
	for _, ov := range oldVersions[len(newVersions):] {
		if numeric(ov) != 0 {
			return false
		}
	}
	return true
}
================================================
FILE: gossh/gorm/driver/pgsql/error_translator.go
================================================
package pgsql
import (
"encoding/json"
"gossh/gorm"
)
// errCodes maps the error codes that Translate recognises to the gorm error
// returned in their place.
var errCodes = map[string]error{
"23505": gorm.ErrDuplicatedKey,
"23503": gorm.ErrForeignKeyViolated,
"42703": gorm.ErrInvalidField,
}
// ErrMessage is the shape Translate decodes a JSON-marshalled driver error
// into. Only Code is consulted by Translate.
type ErrMessage struct {
Code string
Severity string
Message string
}
// Translate converts a driver error into the matching native gorm error.
//
// gorm supports both the pgx and pg drivers, so checking for pgx PgError types
// alone is not enough. Instead the error is marshalled to JSON and decoded
// into an ErrMessage, and its Code is looked up in errCodes. The original
// error is returned unchanged when marshalling or decoding fails, or when the
// code is not in the table.
func (dialector Dialector) Translate(err error) error {
	encoded, marshalErr := json.Marshal(err)
	if marshalErr != nil {
		return err
	}

	var decoded ErrMessage
	if unmarshalErr := json.Unmarshal(encoded, &decoded); unmarshalErr != nil {
		return err
	}

	translated, known := errCodes[decoded.Code]
	if !known {
		return err
	}
	return translated
}
================================================
FILE: gossh/gorm/driver/pgsql/migrator.go
================================================
package pgsql
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/migrator"
"gossh/gorm/schema"
)
// indexSql lists, for the table bound to its single placeholder, one row per
// (index, column) pair: table name, index name, the index's uniqueness flag,
// its primary-key flag, and the column name.
// Note that the column aliased non_unique holds pg_index.indisunique.
//
// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql
// Here are some changes:
// - use `LEFT JOIN` instead of `CROSS JOIN`
// - exclude indexes used to support constraints (they are auto-generated)
const indexSql = `
SELECT
ct.relname AS table_name,
ci.relname AS index_name,
i.indisunique AS non_unique,
i.indisprimary AS primary,
a.attname AS column_name
FROM
pg_index i
LEFT JOIN pg_class ct ON ct.oid = i.indrelid
LEFT JOIN pg_class ci ON ci.oid = i.indexrelid
LEFT JOIN pg_attribute a ON a.attrelid = ct.oid
LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid
WHERE
a.attnum = ANY(i.indkey)
AND con.oid IS NULL
AND ct.relkind = 'r'
AND ct.relname = ?
`
// typeAliasMap maps a type name to the alternative names treated as the same
// type. Every pair is listed in both directions. It is read through
// Migrator.GetTypeAliases.
var typeAliasMap = map[string][]string{
"int2": {"smallint"},
"int4": {"integer"},
"int8": {"bigint"},
"smallint": {"int2"},
"integer": {"int4"},
"bigint": {"int8"},
"decimal": {"numeric"},
"numeric": {"decimal"},
"timestamptz": {"timestamp with time zone"},
"timestamp with time zone": {"timestamptz"},
"bool": {"boolean"},
"boolean": {"bool"},
}
// Migrator is this package's migrator. It embeds the generic
// migrator.Migrator and overrides the methods defined in this file.
type Migrator struct {
migrator.Migrator
}
// queryRaw builds a raw query for sql and values that is not subject to dry
// run: when the migrator's DB has DryRun set, the query is built on a fresh
// session with DryRun switched off, leaving m.DB itself untouched.
func (m Migrator) queryRaw(sql string, values ...any) (tx *gorm.DB) {
	if !m.DB.DryRun {
		return m.DB.Raw(sql, values...)
	}
	live := m.DB.Session(&gorm.Session{})
	live.DryRun = false
	return live.Raw(sql, values...)
}
// CurrentDatabase returns the result of SELECT CURRENT_DATABASE(). A query
// error is not reported; name is then left at whatever Scan stored in it.
func (m Migrator) CurrentDatabase() (name string) {
	query := m.queryRaw("SELECT CURRENT_DATABASE()")
	query.Scan(&name)
	return name
}
// BuildIndexOptions turns index field options into SQL expressions, one per
// option. Each expression is the option's Expression when set, otherwise its
// quoted DBName, followed by an optional " COLLATE <collation>" and an
// optional " <sort>".
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []any) {
	for i := range opts {
		opt := opts[i]
		fragment := stmt.Quote(opt.DBName)
		if opt.Expression != "" {
			fragment = opt.Expression
		}
		if collation := opt.Collate; collation != "" {
			fragment = fragment + " COLLATE " + collation
		}
		if order := opt.Sort; order != "" {
			fragment = fragment + " " + order
		}
		results = append(results, clause.Expr{SQL: fragment})
	}
	return results
}
// HasIndex reports whether pg_indexes has an index called name on value's
// table in the current schema. If the model's schema declares an index under
// that name, the declared index's Name is used for the lookup. Errors are not
// reported; they leave the result false.
func (m Migrator) HasIndex(value any, name string) bool {
	const query = "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?"

	var matches int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		indexName := name
		if stmt.Schema != nil {
			if declared := stmt.Schema.LookIndex(name); declared != nil {
				indexName = declared.Name
			}
		}
		schemaName, tableName := m.CurrentSchema(stmt, stmt.Table)
		return m.queryRaw(query, tableName, indexName, schemaName).Scan(&matches).Error
	})
	return matches > 0
}
// CreateIndex creates the index that value's schema declares under name.
//
// The statement has the form
// CREATE [class] INDEX [CONCURRENTLY] IF NOT EXISTS <name> ON <table>
// followed by " USING <type>(<columns>)" or " <columns>", and an optional
// " WHERE <condition>". CONCURRENTLY is added when the index option, trimmed
// and upper-cased, is exactly that word. An error is returned when the model
// has no schema or declares no such index.
func (m Migrator) CreateIndex(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema == nil {
			return fmt.Errorf("failed to create index with name %v", name)
		}
		idx := stmt.Schema.LookIndex(name)
		if idx == nil {
			return fmt.Errorf("failed to create index with name %v", name)
		}

		columns := m.BuildIndexOptions(idx.Fields, stmt)
		args := []any{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), columns}

		var statement strings.Builder
		statement.WriteString("CREATE ")
		if idx.Class != "" {
			statement.WriteString(idx.Class + " ")
		}
		statement.WriteString("INDEX ")
		if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
			statement.WriteString("CONCURRENTLY ")
		}
		statement.WriteString("IF NOT EXISTS ? ON ?")
		if idx.Type == "" {
			statement.WriteString(" ?")
		} else {
			statement.WriteString(" USING " + idx.Type + "(?)")
		}
		if idx.Where != "" {
			statement.WriteString(" WHERE " + idx.Where)
		}
		return m.DB.Exec(statement.String(), args...).Error
	})
}
// RenameIndex renames the index oldName to newName with ALTER INDEX ... RENAME
// TO, run in the context of value's statement.
func (m Migrator) RenameIndex(value any, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		from := clause.Column{Name: oldName}
		to := clause.Column{Name: newName}
		return m.DB.Exec("ALTER INDEX ? RENAME TO ?", from, to).Error
	})
}
// DropIndex drops the index called name. If value's schema declares an index
// under that name, the declared index's Name is dropped instead.
func (m Migrator) DropIndex(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		target := name
		if stmt.Schema != nil {
			if declared := stmt.Schema.LookIndex(name); declared != nil {
				target = declared.Name
			}
		}
		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: target}).Error
	})
}
// GetTables returns the names of the base tables in the current schema, as
// listed by information_schema.tables.
func (m Migrator) GetTables() (tableList []string, err error) {
	currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
	// Run the query as its own statement so tableList is filled before it is
	// returned. Returning the variable and the call that fills it in one
	// expression list depends on an evaluation order the Go spec leaves
	// unspecified.
	err = m.queryRaw(
		"SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?",
		currentSchema, "BASE TABLE",
	).Scan(&tableList).Error
	return tableList, err
}
// CreateTable creates the tables for values through the embedded migrator and
// then, for each model in dependency order, issues a COMMENT ON COLUMN
// statement for every schema field that has a non-empty Comment. The first
// error stops the process and is returned.
func (m Migrator) CreateTable(values ...any) (err error) {
	if err = m.Migrator.CreateTable(values...); err != nil {
		return err
	}

	// commentColumns writes the column comments for one model's statement.
	commentColumns := func(stmt *gorm.Statement) error {
		if stmt.Schema == nil {
			return nil
		}
		for _, dbName := range stmt.Schema.DBNames {
			column := stmt.Schema.FieldsByDBName[dbName]
			if column.Comment == "" {
				continue
			}
			execErr := m.DB.Exec(
				"COMMENT ON COLUMN ?.? IS ?",
				m.CurrentTable(stmt), clause.Column{Name: column.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", column.Comment)),
			).Error
			if execErr != nil {
				return execErr
			}
		}
		return nil
	}

	for _, model := range m.ReorderModels(values, false) {
		if err = m.RunWithValue(model, commentColumns); err != nil {
			return err
		}
	}
	return nil
}
// HasTable reports whether information_schema.tables lists value's table as a
// base table in the current schema. Errors are not reported; they leave the
// result false.
func (m Migrator) HasTable(value any) bool {
	const query = "SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?"

	var matches int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		schemaName, tableName := m.CurrentSchema(stmt, stmt.Table)
		return m.queryRaw(query, schemaName, tableName, "BASE TABLE").Scan(&matches).Error
	})
	return matches > 0
}
// DropTable drops the tables for values with DROP TABLE IF EXISTS ... CASCADE,
// walking the reordered models from last to first. It stops at, and returns,
// the first error.
func (m Migrator) DropTable(values ...any) error {
	ordered := m.ReorderModels(values, false)
	session := m.DB.Session(&gorm.Session{})
	dropOne := func(stmt *gorm.Statement) error {
		return session.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error
	}
	for remaining := len(ordered); remaining > 0; remaining-- {
		if err := m.RunWithValue(ordered[remaining-1], dropOne); err != nil {
			return err
		}
	}
	return nil
}
// AddColumn adds the column for field through the embedded migrator, resets
// prepared statements, and then sets the column comment when the schema field
// has a non-empty Comment.
func (m Migrator) AddColumn(value any, field string) error {
	if err := m.Migrator.AddColumn(value, field); err != nil {
		return err
	}
	m.resetPreparedStmts()

	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema == nil {
			return nil
		}
		column := stmt.Schema.LookUpField(field)
		if column == nil || column.Comment == "" {
			return nil
		}
		return m.DB.Exec(
			"COMMENT ON COLUMN ?.? IS ?",
			m.CurrentTable(stmt), clause.Column{Name: column.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", column.Comment)),
		).Error
	})
}
// HasColumn reports whether INFORMATION_SCHEMA.columns lists the column for
// field on value's table in the current schema. When the schema knows the
// field, its DBName is used; otherwise field is used as the column name.
// Errors are not reported; they leave the result false.
func (m Migrator) HasColumn(value any, field string) bool {
	const query = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?"

	var matches int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		columnName := field
		if stmt.Schema != nil {
			if known := stmt.Schema.LookUpField(field); known != nil {
				columnName = known.DBName
			}
		}
		schemaName, tableName := m.CurrentSchema(stmt, stmt.Table)
		return m.queryRaw(query, schemaName, tableName, columnName).Scan(&matches).Error
	})
	return matches > 0
}
// MigrateColumn migrates a non-primary column through the embedded migrator
// (primary key fields skip that step) and then synchronises the column
// comment: the description stored in pg_description is read, and when the
// field has a comment that differs from it once surrounding single and double
// quotes are trimmed, a COMMENT ON COLUMN statement is issued.
func (m Migrator) MigrateColumn(value any, field *schema.Field, columnType gorm.ColumnType) error {
	// skip primary field
	if !field.PrimaryKey {
		if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
			return err
		}
	}

	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		schemaName, tableName := m.CurrentSchema(stmt, stmt.Table)

		lookupSQL := "SELECT description FROM pg_catalog.pg_description " +
			"WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " +
			"AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " +
			"(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
		lookupArgs := []any{schemaName, tableName, field.DBName, stmt.Table, schemaName}

		var stored string
		m.queryRaw(lookupSQL, lookupArgs...).Scan(&stored)

		wanted := strings.Trim(strings.Trim(field.Comment, "'"), `"`)
		if field.Comment == "" || wanted == stored {
			return nil
		}
		return m.DB.Exec(
			"COMMENT ON COLUMN ?.? IS ?",
			m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
		).Error
	})
}
// AlterColumn alters value's `field` column to match the schema definition.
//
// The existing column is looked up through ColumnTypes and compared with the
// type the dialector derives for the field (type aliases count as equal). On a
// type mismatch the column is either re-typed or has its sequence created,
// updated or removed, depending on the auto-increment state of the field and
// of the existing column. Nullability and the default value are then brought
// in line with the field. Prepared statements are reset on success.
//
// An error is returned when the field is unknown to the schema, when the
// column types cannot be read, or when the table has no column for the field.
func (m Migrator) AlterColumn(value any, field string) error {
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(field); field != nil {
				columnTypes, err := m.DB.Migrator().ColumnTypes(value)
				if err != nil {
					return err
				}
				var fieldColumnType *migrator.ColumnType
				for _, columnType := range columnTypes {
					if columnType.Name() == field.DBName {
						fieldColumnType, _ = columnType.(*migrator.ColumnType)
					}
				}
				// Without this guard every accessor below would dereference a
				// nil column type.
				if fieldColumnType == nil {
					return fmt.Errorf("failed to look up column type for field: %s", field.DBName)
				}
				fileType := clause.Expr{SQL: m.DataTypeOf(field)}
				// check for typeName and SQL name
				isSameType := true
				if fieldColumnType.DatabaseTypeName() != fileType.SQL {
					isSameType = false
					// if different, also check for aliases
					aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
					for _, alias := range aliases {
						if strings.HasPrefix(fileType.SQL, alias) {
							isSameType = true
							break
						}
					}
				}
				// not same, migrate
				if !isSameType {
					filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
					if field.AutoIncrement && filedColumnAutoIncrement { // update
						serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
						if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
							if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
								return err
							}
						}
					} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
						serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
						if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
							return err
						}
					} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
						if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
							return err
						}
					} else {
						if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil {
							return err
						}
					}
				}
				// The column's nullability and the field's NOT NULL flag agree
				// exactly when they differ, so equality means a change is due.
				if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
					if field.NotNull {
						if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
							return err
						}
					} else {
						if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
							return err
						}
					}
				}
				if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
					if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
						if field.DefaultValueInterface != nil {
							defaultStmt := &gorm.Statement{Vars: []any{field.DefaultValueInterface}}
							m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
								return err
							}
						} else if field.DefaultValue != "(-)" {
							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
								return err
							}
						} else {
							// The statement has only two placeholders, so exactly
							// two arguments are passed.
							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
								return err
							}
						}
					}
				}
				return nil
			}
		}
		return fmt.Errorf("failed to look up field with name: %s", field)
	})
	if err != nil {
		return err
	}
	m.resetPreparedStmts()
	return nil
}
// modifyColumn changes the type of field's column to targetType with
// ALTER COLUMN ... TYPE ... USING <column>::<type>.
//
// When the target type is boolean, an existing default is dropped first, and
// columns whose current database type is int2, int8 or numeric are cast
// through int before the final cast.
func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error {
	column := clause.Column{Name: field.DBName}
	toBoolean := targetType.SQL == "boolean"

	alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?"
	if toBoolean {
		if current := existingColumn.DatabaseTypeName(); current == "int2" || current == "int8" || current == "numeric" {
			alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?"
		}
	}

	existingDefault, _ := existingColumn.DefaultValue()
	if toBoolean && existingDefault != "" {
		if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), column).Error; err != nil {
			return err
		}
	}

	return m.DB.Exec(alterSQL, m.CurrentTable(stmt), column, targetType, column, targetType).Error
}
// HasConstraint reports whether INFORMATION_SCHEMA.table_constraints lists a
// constraint called name on the relevant table in the current schema. The
// constraint and its table are first guessed from the statement; when a
// constraint is found, its own name is used for the lookup. Errors are not
// reported; they leave the result false.
func (m Migrator) HasConstraint(value any, name string) bool {
	const query = "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?"

	var matches int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraintName := name
		guessed, guessedTable := m.GuessConstraintInterfaceAndTable(stmt, name)
		if guessed != nil {
			constraintName = guessed.GetName()
		}
		schemaName, tableName := m.CurrentSchema(stmt, guessedTable)
		return m.queryRaw(query, schemaName, tableName, constraintName).Scan(&matches).Error
	})
	return matches > 0
}
// ColumnTypes describes the columns of value's table.
//
// It runs four steps, each on its own result set that is closed before the
// next query starts:
//  1. read the base column metadata from information_schema.columns;
//  2. attach the driver-level sql.ColumnType obtained from GetRows;
//  3. mark primary-key columns, and unique columns whose constraint name
//     occurs exactly once in the constraint listing;
//  4. record the formatted data type from pg_attribute, which also replaces
//     the data type of columns whose udt name starts with "_".
//
// On error the column types collected so far are returned with the error.
func (m Migrator) ColumnTypes(value any) (columnTypes []gorm.ColumnType, err error) {
	columnTypes = make([]gorm.ColumnType, 0)
	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var (
			currentDatabase      = m.DB.Migrator().CurrentDatabase()
			currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
			columns, err         = m.queryRaw(
				"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
				currentDatabase, currentSchema, table).Rows()
		)
		if err != nil {
			return err
		}
		for columns.Next() {
			var (
				column = &migrator.ColumnType{
					PrimaryKeyValue: sql.NullBool{Valid: true},
					UniqueValue:     sql.NullBool{Valid: true},
				}
				datetimePrecision sql.NullInt64
				radixValue        sql.NullInt64
				typeLenValue      sql.NullInt64
				identityIncrement sql.NullString
			)
			err = columns.Scan(
				&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue,
				&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement,
			)
			if err != nil {
				// Release the result set before bailing out.
				columns.Close()
				return err
			}
			// A positive type length takes precedence over the character length.
			if typeLenValue.Valid && typeLenValue.Int64 > 0 {
				column.LengthValue = typeLenValue
			}
			// A nextval(...seq'::regclass) default or an identity increment marks
			// the column as auto-increment; its default is then cleared.
			if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") &&
				strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") {
				column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
				column.DefaultValueValue = sql.NullString{}
			}
			if column.DefaultValueValue.Valid {
				column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String)
			}
			// A datetime precision replaces the numeric precision read above.
			if datetimePrecision.Valid {
				column.DecimalSizeValue = datetimePrecision
			}
			columnTypes = append(columnTypes, column)
		}
		columns.Close()
		// assign sql column type
		{
			rows, rowsErr := m.GetRows(currentSchema, table)
			if rowsErr != nil {
				return rowsErr
			}
			rawColumnTypes, err := rows.ColumnTypes()
			if err != nil {
				// Release the result set before bailing out.
				rows.Close()
				return err
			}
			for _, columnType := range columnTypes {
				for _, c := range rawColumnTypes {
					if c.Name() == columnType.Name() {
						columnType.(*migrator.ColumnType).SQLColumnType = c
						break
					}
				}
			}
			rows.Close()
		}
		// check primary, unique field
		{
			columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
			if err != nil {
				return err
			}
			// Count the rows returned per unique constraint name.
			uniqueContraints := map[string]int{}
			for columnTypeRows.Next() {
				var constraintName string
				columnTypeRows.Scan(&constraintName)
				uniqueContraints[constraintName]++
			}
			columnTypeRows.Close()
			columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
			if err != nil {
				return err
			}
			for columnTypeRows.Next() {
				var name, constraintName, columnType string
				columnTypeRows.Scan(&name, &constraintName, &columnType)
				for _, c := range columnTypes {
					mc := c.(*migrator.ColumnType)
					if mc.NameValue.String == name {
						switch columnType {
						case "PRIMARY KEY":
							mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
						case "UNIQUE":
							// Only a constraint counted exactly once marks its
							// column as unique.
							if uniqueContraints[constraintName] == 1 {
								mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
							}
						}
						break
					}
				}
			}
			columnTypeRows.Close()
		}
		// check column type
		{
			dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
AND b.relname = ?`, currentSchema, table).Rows()
			if err != nil {
				return err
			}
			for dataTypeRows.Next() {
				var name, dataType string
				dataTypeRows.Scan(&name, &dataType)
				for _, c := range columnTypes {
					mc := c.(*migrator.ColumnType)
					if mc.NameValue.String == name {
						mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
						// Handle array type: _text -> text[] , _int4 -> integer[]
						// Array dimensions and size limits are not supported because:
						// https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION
						if strings.HasPrefix(mc.DataTypeValue.String, "_") {
							mc.DataTypeValue = sql.NullString{String: dataType, Valid: true}
						}
						break
					}
				}
			}
			dataTypeRows.Close()
		}
		return err
	})
	return columnTypes, err
}
// GetRows opens a result set limited to one row from the given table, on a
// fresh session. table must hold a string. When currentSchema also holds a
// string, the table name is qualified as "<schema>.<table>".
//
// The scope function currently returns the DB unchanged: the line that
// switched pgx to the simple protocol is commented out, so its condition is
// evaluated without effect.
func (m Migrator) GetRows(currentSchema any, table any) (*sql.Rows, error) {
	qualified := table.(string)
	if schemaName, isString := currentSchema.(string); isString {
		qualified = schemaName + "." + qualified
	}

	simpleProtocolScope := func(d *gorm.DB) *gorm.DB {
		dialector, _ := m.Dialector.(Dialector)
		// use simple protocol
		if !m.DB.PrepareStmt && (dialector.Config != nil && (dialector.Config.DriverName == "" || dialector.Config.DriverName == "pgx")) {
			//huang d.Statement.Vars = append([]interface{}{pgx.QueryExecModeSimpleProtocol}, d.Statement.Vars...)
		}
		return d
	}

	session := m.DB.Session(&gorm.Session{})
	return session.Table(qualified).Limit(1).Scopes(simpleProtocolScope).Rows()
}
// CurrentSchema resolves the schema and table name to use for table.
//
//   - A table of the form "schema.table" (exactly one dot) yields its two parts.
//   - Otherwise, when the statement's table expression splits on `"."` into
//     exactly two pieces, the first piece (minus its leading double quote) is
//     the schema and table is returned unchanged.
//   - Otherwise the schema is the SQL expression CURRENT_SCHEMA().
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (any, any) {
	if parts := strings.Split(table, "."); len(parts) == 2 {
		return parts[0], parts[1]
	}

	if expr := stmt.TableExpr; expr != nil {
		pieces := strings.Split(expr.SQL, `"."`)
		if len(pieces) == 2 {
			return strings.TrimPrefix(pieces[0], `"`), table
		}
	}

	return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table
}
// CreateSequence creates the sequence "<table>_<column>_seq" with type
// serialDatabaseType if it does not exist, makes it the column's default via
// nextval, and marks it as owned by the column. It stops at, and returns, the
// first error.
func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
	serialDatabaseType string) (err error) {
	_, table := m.CurrentSchema(stmt, stmt.Table)
	tableName := table.(string)
	sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_")

	var (
		sequenceExpr = clause.Expr{SQL: sequenceName}
		tableExpr    = clause.Expr{SQL: tableName}
		columnExpr   = clause.Expr{SQL: field.DBName}
	)

	err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, sequenceExpr, clause.Expr{SQL: serialDatabaseType}).Error
	if err != nil {
		return err
	}
	err = tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')", tableExpr, columnExpr, sequenceExpr).Error
	if err != nil {
		return err
	}
	return tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", sequenceExpr, tableExpr, columnExpr).Error
}
// UpdateSequence changes the type of the sequence behind field's column, and
// then of the column itself, to serialDatabaseType. The sequence name is
// obtained from getColumnSequenceName.
func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
	serialDatabaseType string) (err error) {
	sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
	if err != nil {
		return err
	}

	newType := clause.Expr{SQL: serialDatabaseType}
	err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, newType).Error
	if err != nil {
		return err
	}
	return tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?",
		m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, newType).Error
}
// DeleteSequence detaches field's column from its sequence. It looks up the
// sequence name, changes the column to fileType, drops the column default, and
// finally drops the sequence if it exists. It stops at, and returns, the first
// error.
func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
	fileType clause.Expr) (err error) {
	sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
	if err != nil {
		return err
	}

	err = tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error
	if err != nil {
		return err
	}
	err = tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT",
		m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error
	if err != nil {
		return err
	}
	return tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error
}
// getColumnSequenceName derives the sequence name behind field's column from
// the column's default in information_schema.columns, by stripping a leading
// `nextval('` and a trailing `'::regclass)`.
//
// The default is queried again here because ColumnTypes resets
// DefaultValueValue for auto-increment columns.
func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) (
	sequenceName string, err error) {
	_, table := m.CurrentSchema(stmt, stmt.Table)

	var columnDefault string
	lookup := tx.Raw(
		`SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`,
		table, field.DBName).Scan(&columnDefault)
	if lookup.Error != nil {
		return "", lookup.Error
	}

	withoutPrefix := strings.TrimPrefix(columnDefault, `nextval('`)
	return strings.TrimSuffix(withoutPrefix, `'::regclass)`), nil
}
// GetIndexes returns the indexes of value's table as reported by indexSql.
// The rows are grouped by index name; each group becomes one migrator.Index
// whose table name, name and flags come from the group's first row and whose
// column list holds every row's column. The order of the returned indexes
// follows map iteration and is therefore not fixed.
func (m Migrator) GetIndexes(value any) ([]gorm.Index, error) {
	indexes := make([]gorm.Index, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		rows := make([]*Index, 0)
		if scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&rows).Error; scanErr != nil {
			return scanErr
		}
		for _, members := range groupByIndexName(rows) {
			first := members[0]
			entry := &migrator.Index{
				TableName:       first.TableName,
				NameValue:       first.IndexName,
				PrimaryKeyValue: sql.NullBool{Bool: first.Primary, Valid: true},
				UniqueValue:     sql.NullBool{Bool: first.NonUnique, Valid: true},
			}
			for i := range members {
				entry.ColumnList = append(entry.ColumnList, members[i].ColumnName)
			}
			indexes = append(indexes, entry)
		}
		return nil
	})
	return indexes, err
}
// Index table index info
type Index struct {
TableName string `gorm:"column:table_name"`
ColumnName string `gorm:"column:column_name"`
IndexName string `gorm:"column:index_name"`
NonUnique bool `gorm:"column:non_unique"`
Primary bool `gorm:"column:primary"`
}
// groupByIndexName buckets the given rows by their IndexName, preserving the
// input order of the rows inside each bucket.
func groupByIndexName(indexList []*Index) map[string][]*Index {
	grouped := make(map[string][]*Index, len(indexList))
	for i := range indexList {
		row := indexList[i]
		name := row.IndexName
		grouped[name] = append(grouped[name], row)
	}
	return grouped
}
// GetTypeAliases returns the aliases registered in typeAliasMap for
// databaseTypeName, or nil when the name has no entry.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	aliases := typeAliasMap[databaseTypeName]
	return aliases
}
// resetPreparedStmts resets the prepared-statement pool after a table change.
// It is a no-op unless PrepareStmt is enabled and the connection pool is a
// *gorm.PreparedStmtDB.
func (m Migrator) resetPreparedStmts() {
	if !m.DB.PrepareStmt {
		return
	}
	pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB)
	if !ok {
		return
	}
	pdb.Reset()
}
// DropColumn delegates to the embedded migrator and, when that succeeds,
// resets the prepared statements because the table definition has changed.
func (m Migrator) DropColumn(dst any, field string) error {
	err := m.Migrator.DropColumn(dst, field)
	if err == nil {
		m.resetPreparedStmts()
	}
	return err
}
// RenameColumn delegates to the embedded migrator and, when that succeeds,
// resets the prepared statements because the table definition has changed.
func (m Migrator) RenameColumn(dst any, oldName, field string) error {
	err := m.Migrator.RenameColumn(dst, oldName, field)
	if err == nil {
		m.resetPreparedStmts()
	}
	return err
}
// defaultValueCastPattern captures everything before the first "::" type cast
// of a column default expression. It is compiled once at package level so that
// parseDefaultValueValue does not recompile it on every call.
var defaultValueCastPattern = regexp.MustCompile(`^(.*?)(?:::.*)?$`)

// parseDefaultValueValue normalises a column default expression: everything
// from the first "::" onwards is dropped and surrounding single quotes are
// trimmed, e.g. `'abc'::character varying` becomes `abc`.
func parseDefaultValueValue(defaultValue string) string {
	value := defaultValueCastPattern.ReplaceAllString(defaultValue, "$1")
	return strings.Trim(value, "'")
}
================================================
FILE: gossh/gorm/driver/pgsql/postgres.go
================================================
package pgsql
import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"gossh/gorm"
"gossh/gorm/callbacks"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/migrator"
"gossh/gorm/schema"
)
// Dialector is the PostgreSQL dialect. It embeds *Config, so the configuration
// fields are accessible directly on the dialector value.
type Dialector struct {
*Config
}
// Config holds the settings of a PostgreSQL Dialector.
type Config struct {
// DriverName is the database/sql driver name handed to sql.Open when Conn is nil.
DriverName string
// DSN is the data source name handed to sql.Open when Conn is nil.
DSN string
// WithoutQuotingCheck makes QuoteTo write identifiers verbatim, without quoting.
WithoutQuotingCheck bool
// NOTE(review): PreferSimpleProtocol is not read anywhere in the visible part
// of this file — confirm whether it still has any effect.
PreferSimpleProtocol bool
// WithoutReturning stops Initialize from adding the RETURNING clause to the
// create, update and delete clause lists.
WithoutReturning bool
// Conn, when non-nil, is used as the connection pool instead of opening one.
Conn gorm.ConnPool
}
var (
// timeZoneMatcher captures a time_zone/TimeZone key and its value, terminated
// by end of input, '&' or a space.
// NOTE(review): no use of it is visible in this part of the file.
timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")
// defaultIdentifierLength is the IdentifierMaxLength that Apply installs on
// naming strategies which do not set a positive value themselves.
defaultIdentifierLength = 63 //maximum identifier length for postgres
)
// Open returns a PostgreSQL dialector for dsn that uses the "postgres"
// database/sql driver name.
func Open(dsn string) gorm.Dialector {
	cfg := &Config{
		DriverName: "postgres",
		DSN:        dsn,
	}
	return &Dialector{Config: cfg}
}
// New returns a PostgreSQL dialector backed by a copy of config.
func New(config Config) gorm.Dialector {
	d := new(Dialector)
	d.Config = &config
	return d
}
// Name reports the dialect name, which is always "postgres".
func (dialector Dialector) Name() string {
	const dialectName = "postgres"
	return dialectName
}
// Apply makes sure the naming strategy in config has an identifier length
// limit. A missing strategy is replaced by a schema.NamingStrategy using
// defaultIdentifierLength; a schema.NamingStrategy (value or pointer) with a
// non-positive IdentifierMaxLength has it set to defaultIdentifierLength.
// Any other strategy type is left untouched. It always returns nil.
func (dialector Dialector) Apply(config *gorm.Config) error {
	switch strategy := config.NamingStrategy.(type) {
	case nil:
		config.NamingStrategy = schema.NamingStrategy{
			IdentifierMaxLength: defaultIdentifierLength,
		}
	case *schema.NamingStrategy:
		// Pointer: the caller's strategy is updated in place.
		if strategy.IdentifierMaxLength <= 0 {
			strategy.IdentifierMaxLength = defaultIdentifierLength
		}
	case schema.NamingStrategy:
		// Value: the modified copy has to be stored back into the config.
		if strategy.IdentifierMaxLength <= 0 {
			strategy.IdentifierMaxLength = defaultIdentifierLength
			config.NamingStrategy = strategy
		}
	}
	return nil
}
// Initialize registers the default callbacks on db with the PostgreSQL clause
// lists (including RETURNING unless WithoutReturning is set) and installs the
// connection pool: the configured Conn when present, otherwise a pool opened
// with sql.Open using DriverName and DSN. The sql.Open error, if any, is
// returned.
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
	createClauses := []string{"INSERT", "VALUES", "ON CONFLICT"}
	updateClauses := []string{"UPDATE", "SET", "FROM", "WHERE"}
	deleteClauses := []string{"DELETE", "FROM", "WHERE"}
	if !dialector.WithoutReturning {
		createClauses = append(createClauses, "RETURNING")
		updateClauses = append(updateClauses, "RETURNING")
		deleteClauses = append(deleteClauses, "RETURNING")
	}

	// register callbacks
	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
		CreateClauses: createClauses,
		UpdateClauses: updateClauses,
		DeleteClauses: deleteClauses,
	})

	if dialector.Conn != nil {
		db.ConnPool = dialector.Conn
		return nil
	}
	db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
	return err
}
// Migrator returns the PostgreSQL migrator bound to db and this dialector,
// configured to create indexes after the table has been created.
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	base := migrator.Migrator{
		Config: migrator.Config{
			CreateIndexAfterCreateTable: true,
			DB:                          db,
			Dialector:                   dialector,
		},
	}
	return Migrator{base}
}
// DefaultValueOf returns the SQL expression DEFAULT for every field.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	expr := clause.Expr{SQL: "DEFAULT"}
	return expr
}
// BindVarTo writes a numbered placeholder "$N" to writer, where N is the
// current number of variables held by stmt.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {
	writer.WriteByte('$')
	// An earlier, now disabled, variant skipped a leading pgx.QueryExecMode
	// entry in stmt.Vars when numbering; the position is now simply the
	// variable count.
	position := len(stmt.Vars)
	writer.WriteString(strconv.Itoa(position))
}
// QuoteTo writes str to writer as a double-quoted identifier. A '.' separates
// identifier parts, each of which is quoted on its own. When
// WithoutQuotingCheck is set, str is written verbatim instead.
//
// The input is scanned byte by byte with a small state machine; the statement
// order below is significant.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
if dialector.WithoutQuotingCheck {
writer.WriteString(str)
return
}
var (
// underQuoted: an opening quote has been written for the current part.
// selfQuoted: the current part already began with a quote in the input.
underQuoted, selfQuoted bool
// continuousBacktick counts consecutive '"' bytes seen (the name is
// presumably inherited from a backtick-quoting dialect).
continuousBacktick int8
// shiftDelimiter counts bytes consumed since the last '.' separator.
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '"':
continuousBacktick++
// Two quotes in a row are emitted as an escaped quote pair.
if continuousBacktick == 2 {
writer.WriteString(`""`)
continuousBacktick = 0
}
case '.':
// Close the current part before the separator and reset the per-part state.
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteByte('"')
}
writer.WriteByte(v)
continue
default:
// First ordinary byte of a part: open the quote.
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteByte('"')
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
// Flush any pending quotes as escaped pairs before the byte itself.
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString(`""`)
}
writer.WriteByte(v)
}
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString(`""`)
}
// Close the final part.
writer.WriteByte('"')
}
// numericPlaceholder matches numbered bind placeholders such as $1 or $12.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)

// Explain renders sql with vars substituted for its numbered placeholders,
// using a single quote as the string escaper, via logger.ExplainSQL.
func (dialector Dialector) Explain(sql string, vars ...any) string {
	const escaper = `'`
	rendered := logger.ExplainSQL(sql, numericPlaceholder, escaper, vars...)
	return rendered
}
// DataTypeOf maps a schema field to the PostgreSQL column type name.
//
// Integers are sized by field.Size (unsigned fields count one extra bit) and
// become serial types when the field auto-increments. Floats become numeric
// with the declared precision/scale or plain decimal. Strings become
// varchar(n) or text, times timestamptz with optional precision, bytes bytea.
// Any other data type is resolved by getSchemaCustomType.
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	switch field.DataType {
	case schema.Bool:
		return "boolean"
	case schema.Int, schema.Uint:
		bits := field.Size
		if field.DataType == schema.Uint {
			bits++
		}
		if field.AutoIncrement {
			if bits <= 16 {
				return "smallserial"
			}
			if bits <= 32 {
				return "serial"
			}
			return "bigserial"
		}
		if bits <= 16 {
			return "smallint"
		}
		if bits <= 32 {
			return "integer"
		}
		return "bigint"
	case schema.Float:
		switch {
		case field.Precision <= 0:
			return "decimal"
		case field.Scale > 0:
			return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
		default:
			return fmt.Sprintf("numeric(%d)", field.Precision)
		}
	case schema.String:
		if field.Size <= 0 {
			return "text"
		}
		return fmt.Sprintf("varchar(%d)", field.Size)
	case schema.Time:
		if field.Precision <= 0 {
			return "timestamptz"
		}
		return fmt.Sprintf("timestamptz(%d)", field.Precision)
	case schema.Bytes:
		return "bytea"
	default:
		return dialector.getSchemaCustomType(field)
	}
}
// getSchemaCustomType returns the field's declared data type verbatim, unless
// the field auto-increments and the declared type does not already mention
// "serial" (case-insensitively). In that case a serial type is chosen from
// field.Size, counting one extra bit when the GORM data type is unsigned.
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
	declared := string(field.DataType)
	if !field.AutoIncrement || strings.Contains(strings.ToLower(declared), "serial") {
		return declared
	}

	bits := field.Size
	if field.GORMDataType == schema.Uint {
		bits++
	}
	if bits <= 16 {
		return "smallserial"
	}
	if bits <= 32 {
		return "serial"
	}
	return "bigserial"
}
// SavePoint issues "SAVEPOINT <name>" on tx and returns the error of that
// statement. Previously the result of Exec was discarded and nil was always
// returned, so a failed SAVEPOINT went unnoticed by the caller.
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
	return tx.Exec("SAVEPOINT " + name).Error
}
// RollbackTo issues "ROLLBACK TO SAVEPOINT <name>" on tx and returns the error
// of that statement. Previously the result of Exec was discarded and nil was
// always returned, so a failed rollback went unnoticed by the caller.
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
	return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}
// getSerialDatabaseType translates a serial pseudo-type name into the integer
// type it is based on. ok is false, and dbType empty, for any other input.
func getSerialDatabaseType(s string) (dbType string, ok bool) {
	if s == "smallserial" {
		return "smallint", true
	}
	if s == "serial" {
		return "integer", true
	}
	if s == "bigserial" {
		return "bigint", true
	}
	return "", false
}
================================================
FILE: gossh/gorm/errors.go
================================================
package gorm
import (
"errors"
"gossh/gorm/logger"
)
// Sentinel errors reported by this package.
var (
// ErrRecordNotFound is the record-not-found error; it aliases logger.ErrRecordNotFound.
ErrRecordNotFound = logger.ErrRecordNotFound
// ErrInvalidTransaction is reported when `Commit` or `Rollback` is attempted on an invalid transaction
ErrInvalidTransaction = errors.New("invalid transaction")
// ErrNotImplemented is reported for functionality that is not implemented
ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause is reported when WHERE conditions are required but missing
ErrMissingWhereClause = errors.New("WHERE conditions required")
// ErrUnsupportedRelation is reported for unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPrimaryKeyRequired is reported when a primary key is required
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired is reported when a model value is required
ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired is reported when the model needs accessible fields
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired is reported when a sub query is required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData is reported for unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver is reported when the driver lacks a required capability
ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered is the "registered" error (presumably for duplicate registrations — confirm at call sites)
ErrRegistered = errors.New("registered")
// ErrInvalidField is reported for an invalid field
ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice is reported when an empty slice is found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported is reported when an operation cannot run in dry run mode
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
// ErrInvalidDB is reported for an invalid db
ErrInvalidDB = errors.New("invalid db")
// ErrInvalidValue is reported when a value is not a pointer to a struct or slice
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength is reported when association values do not match in length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed is reported because preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
)
================================================
FILE: gossh/gorm/finisher_api.go
================================================
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// Create inserts value, writing the inserted row's primary key back into it.
// When CreateBatchSize is positive the insert is delegated to CreateInBatches
// with that size.
func (db *DB) Create(value any) (tx *DB) {
	if batchSize := db.CreateBatchSize; batchSize > 0 {
		return db.CreateInBatches(value, batchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	creator := tx.callbacks.Create()
	return creator.Execute(tx)
}
// CreateInBatches inserts value in batches of batchSize.
//
// When value is (a pointer to) a slice or array, it is cut into consecutive
// windows of at most batchSize elements and each window is inserted through
// the create callbacks; the summed RowsAffected is stored on the returned tx.
// Any other value is inserted with a single create call.
func (db *DB) CreateInBatches(value any, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var rowsAffected int64
tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
// callFc inserts every batch using the given tx and stops at the first error.
callFc := func(tx *DB) error {
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
ends = reflectLen
}
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}
// A wrapping transaction is only opened when default transactions are
// enabled and more than one batch is needed.
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
}
tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
}
return
}
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
//
// Slices and arrays are written through the create callbacks with an
// ON CONFLICT ... UpdateAll clause (unless the statement already carries an
// ON CONFLICT clause). A struct with any zero primary field is created.
// Everything else is updated; if that update affects no rows, is not a dry run
// and no explicit Select was given, the value is created with an upsert instead.
func (db *DB) Save(value any) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
// A zero primary key means the record cannot exist yet: insert it.
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
}
}
}
// All primary fields are set: continue with the update path below.
fallthrough
default:
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
// Nothing was updated: fall back to an upsert, with hooks skipped.
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
}
return
}
// First loads into dest the first record, ordered ascending by primary key,
// that matches conds. A not-found error is raised when nothing matches.
func (db *DB) First(dest any, conds ...any) (tx *DB) {
	byPrimaryKey := clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
	}
	tx = db.Limit(1).Order(byPrimaryKey)

	if len(conds) > 0 {
		exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...)
		if len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}

	stmt := tx.Statement
	stmt.RaiseErrorOnNotFound = true
	stmt.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}
// Take loads into dest one record matching conds, in whatever order the
// database returns it. A not-found error is raised when nothing matches.
func (db *DB) Take(dest any, conds ...any) (tx *DB) {
	tx = db.Limit(1)

	if len(conds) > 0 {
		exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...)
		if len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}

	stmt := tx.Statement
	stmt.RaiseErrorOnNotFound = true
	stmt.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}
// Last loads into dest the last record, i.e. the first one when ordered
// descending by primary key, that matches conds. A not-found error is raised
// when nothing matches.
func (db *DB) Last(dest any, conds ...any) (tx *DB) {
	byPrimaryKeyDesc := clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
		Desc:   true,
	}
	tx = db.Limit(1).Order(byPrimaryKeyDesc)

	if len(conds) > 0 {
		exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...)
		if len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}

	stmt := tx.Statement
	stmt.RaiseErrorOnNotFound = true
	stmt.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}
// Find loads into dest every record that matches conds.
func (db *DB) Find(dest any, conds ...any) (tx *DB) {
	tx = db.getInstance()

	if len(conds) > 0 {
		exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...)
		if len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}

	tx.Statement.Dest = dest
	query := tx.callbacks.Query()
	return query.Execute(tx)
}
// FindInBatches finds all records in batches of batchSize.
//
// Records are read ordered by primary key. After each non-empty batch, fc is
// called with a fresh session carrying that batch's RowsAffected and the
// 1-based batch number. The next batch is selected with
// "primary key > last primary key of the previous batch". Iteration stops on
// the first error, when a batch comes back shorter than batchSize, or when a
// user-specified limit has been reached. The returned DB carries the total
// RowsAffected.
func (db *DB) FindInBatches(dest any, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
// queryDB is the query for the next batch; the first batch uses tx as is.
queryDB = tx
rowsAffected int64
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
// Never fetch more per batch than the caller's overall limit.
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
// A short batch means the result set is exhausted.
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
// The upcoming batch is the last one under the limit: shrink it to the remainder.
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
}
// The primary key of the last fetched row is the cursor for the next batch.
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
tx.RowsAffected = rowsAffected
return tx
}
// assignInterfacesToValue copies field values described by values onto the
// statement's ReflectValue. Each value may be:
//   - a []clause.Expression: every Eq expression sets the field named by its
//     column; AndConditions are processed recursively, other expressions are ignored;
//   - a clause.Expression or a map: it is turned into conditions with
//     BuildCondition and processed as above;
//   - anything else: if it parses as a schema and is a struct, its readable,
//     non-zero fields are copied to the fields of the same name; if it does not
//     parse, the whole values list is treated as a condition (query plus
//     arguments) and processing stops.
//
// Errors returned while setting fields are added to db.
func (db *DB) assignInterfacesToValue(values ...any) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := db.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
db.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[any]any, map[string]any:
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
// Not a model: interpret values as a condition and its arguments.
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
}
}
}
}
// FirstOrInit finds the first matching record; if none is found it initializes
// dest from the given conds instead. Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
//	// assign an email if the record is not found
//	db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign email regardless of if record is found
//	db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest any, conds ...any) (tx *DB) {
	byPrimaryKey := clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
	}
	queryTx := db.Limit(1).Order(byPrimaryKey)

	tx = queryTx.Find(dest, conds...)
	if tx.RowsAffected == 0 {
		// Nothing found: seed dest from the WHERE conditions, then from attrs.
		if c, ok := tx.Statement.Clauses["WHERE"]; ok {
			if where, isWhere := c.Expression.(clause.Where); isWhere {
				tx.assignInterfacesToValue(where.Exprs)
			}
		}
		if attrs := tx.Statement.attrs; len(attrs) > 0 {
			tx.assignInterfacesToValue(attrs...)
		}
	}

	// assigns are applied whether or not a record was found.
	if assigns := tx.Statement.assigns; len(assigns) > 0 {
		tx.assignInterfacesToValue(assigns...)
	}
	return tx
}
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest any, conds ...any) (tx *DB) {
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
result := queryTx.Find(dest, conds...)
// A failed lookup is reported as is; nothing is created or updated.
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
// Not found: seed dest from the WHERE conditions before creating it.
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
result.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}
// then apply assigns on top
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
// Found: turn the assigns into a column -> value map and update the record.
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]any{}
// exprs may grow while iterating: nested AND conditions are flattened by
// appending their members to the end of the list.
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
}
}
}
return tx.Model(dest).Updates(assigns)
}
return tx
}
// Update sets a single column to value through the update callbacks.
// Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value any) (tx *DB) {
	changes := map[string]any{column: value}
	tx = db.getInstance()
	tx.Statement.Dest = changes
	updater := tx.callbacks.Update()
	return updater.Execute(tx)
}
// Updates writes the attributes in values, which must be a struct or map,
// through the update callbacks.
// Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = values
	updater := tx.callbacks.Update()
	return updater.Execute(tx)
}
// UpdateColumn sets a single column to value through the update callbacks
// with SkipHooks enabled on the statement.
func (db *DB) UpdateColumn(column string, value any) (tx *DB) {
	changes := map[string]any{column: value}
	tx = db.getInstance()
	stmt := tx.Statement
	stmt.Dest = changes
	stmt.SkipHooks = true
	return tx.callbacks.Update().Execute(tx)
}
// UpdateColumns writes the attributes in values through the update callbacks
// with SkipHooks enabled on the statement.
func (db *DB) UpdateColumns(values any) (tx *DB) {
	tx = db.getInstance()
	stmt := tx.Statement
	stmt.Dest = values
	stmt.SkipHooks = true
	return tx.callbacks.Update().Execute(tx)
}
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value any, conds ...any) (tx *DB) {
	tx = db.getInstance()

	if len(conds) > 0 {
		exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...)
		if len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}

	tx.Statement.Dest = value
	deleter := tx.callbacks.Delete()
	return deleter.Execute(tx)
}
// Count runs the current query as a count query and stores the result in count.
//
// The SELECT clause is temporarily replaced by a count expression and, when
// there is no GROUP BY, the ORDER BY clause is temporarily removed; both are
// restored by deferred functions before returning. If the query is grouped, or
// did not yield exactly one row, count is set to the number of rows returned.
func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance()
// Without an explicit model, borrow Dest as the model for this call only.
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
}
// Restore (or remove) the SELECT clause once counting is done.
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
// The caller selected something other than a count: default to count(*),
// or count the single selected column when it is a plain column reference.
expr := clause.Expr{SQL: "count(*)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
// Translate a struct field name into its column name when possible.
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []any{clause.Column{Name: dbName}}}
} else if dbName != "*" {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []any{clause.Column{Name: dbName}}}
}
}
}
tx.Statement.AddClause(clause.Select{Expression: expr})
}
// ORDER BY is dropped for ungrouped counts and put back afterwards.
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx = tx.callbacks.Query().Execute(tx)
// Grouped queries return one row per group, so the row count is the result.
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected
}
return
}
// Row executes the statement through the row callbacks and returns the
// resulting *sql.Row. If no row was produced, nil is returned; in dry run mode
// that situation is additionally logged as an error.
func (db *DB) Row() *sql.Row {
	tx := db.getInstance().Set("rows", false)
	tx = tx.callbacks.Row().Execute(tx)

	if row, ok := tx.Statement.Dest.(*sql.Row); ok {
		return row
	}
	if tx.DryRun {
		db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
	}
	return nil
}
// Rows executes the statement through the row callbacks and returns the
// resulting *sql.Rows together with the statement error. In dry run mode,
// when no rows were produced and no other error occurred, the error is
// ErrDryRunModeUnsupported.
func (db *DB) Rows() (*sql.Rows, error) {
	tx := db.getInstance().Set("rows", true)
	tx = tx.callbacks.Row().Execute(tx)

	rows, isRows := tx.Statement.Dest.(*sql.Rows)
	if !isRows && tx.DryRun {
		if tx.Error == nil {
			tx.Error = ErrDryRunModeUnsupported
		}
	}
	return rows, tx.Error
}
// Scan scans selected value to the struct dest
//
// The query runs on a copy of the config whose logger is a recorder, and the
// original logger's Trace is called once after scanning with the recorded SQL,
// the final RowsAffected and the final error.
func (db *DB) Scan(dest any) (tx *DB) {
// Work on a copy of the config so that the logger swap does not affect db.
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err == nil {
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
// No row at all: report zero rows plus any iteration error.
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
// Hand the original logger back to the returned tx.
tx.Logger = currentLogger
return
}
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
//	var ages []int64
//	db.Model(&users).Pluck("age", &ages)
//
// When a model is set and column names one of its fields, the field's database
// column name is used. A SELECT clause for the column is added only if the
// statement does not already have exactly one selected column.
func (db *DB) Pluck(column string, dest any) (tx *DB) {
	tx = db.getInstance()
	stmt := tx.Statement

	if stmt.Model != nil && stmt.Parse(stmt.Model) == nil {
		if f := stmt.Schema.LookUpField(column); f != nil {
			column = f.DBName
		}
	}

	if len(stmt.Selects) != 1 {
		parts := strings.FieldsFunc(column, utils.IsValidDBNameChar)
		// Anything that is not a single plain name is passed through as raw SQL.
		isRaw := len(parts) != 1
		stmt.AddClauseIfNotExists(clause.Select{
			Distinct: stmt.Distinct,
			Columns:  []clause.Column{{Name: column, Raw: isRaw}},
		})
	}

	stmt.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}
// ScanRows scans the current row of rows into dest and returns the resulting
// statement error. Parse errors other than schema.ErrUnsupportedDataType are
// recorded on the statement. Nil pointers encountered while dereferencing
// dest are replaced by newly allocated values.
func (db *DB) ScanRows(rows *sql.Rows, dest any) error {
	tx := db.getInstance()
	if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
		tx.AddError(err)
	}
	tx.Statement.Dest = dest

	// Walk down the pointer chain to the value that will receive the data.
	target := reflect.ValueOf(dest)
	for target.Kind() == reflect.Ptr {
		pointee := target.Elem()
		if !pointee.IsValid() {
			pointee = reflect.New(target.Type().Elem())
			target.Set(pointee)
		}
		target = pointee
	}
	tx.Statement.ReflectValue = target

	Scan(rows, tx, ScanInitialized)
	return tx.Error
}
// Connection runs fc on a single dedicated database connection and closes
// that connection, returning it to the pool, once fc has finished. An error
// already present on db, or one raised while obtaining the connection, is
// returned without calling fc.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
	if err = db.Error; err != nil {
		return err
	}

	tx := db.getInstance()
	sqlDB, err := tx.DB()
	if err != nil {
		return err
	}

	conn, err := sqlDB.Conn(tx.Statement.Context)
	if err != nil {
		return err
	}
	defer conn.Close()

	// Every statement executed through tx now uses this one connection.
	tx.Statement.ConnPool = conn
	return fc(tx)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
//
// When db is already inside a transaction, fc runs inside a savepoint (unless
// DisableNestedTransaction is set) which is rolled back to on error or panic.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
// panicked stays true until fc has returned normally, so the deferred
// functions can tell a panic apart from a regular return.
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
// The savepoint name is derived from fc's pointer value.
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
} else {
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
if err = fc(tx); err == nil {
panicked = false
return tx.Commit().Error
}
}
panicked = false
return
}
// Begin starts a transaction, using the first element of opts as the
// transaction options when given, and returns a DB bound to it. If the
// connection pool cannot begin transactions, ErrInvalidTransaction is recorded
// on the returned DB; errors from starting the transaction are recorded too.
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
	// clone statement
	tx := db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})

	var opt *sql.TxOptions
	if len(opts) > 0 {
		opt = opts[0]
	}

	var err error
	switch beginner := tx.Statement.ConnPool.(type) {
	case TxBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	case ConnPoolBeginner:
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
	default:
		err = ErrInvalidTransaction
	}

	if err != nil {
		tx.AddError(err)
	}
	return tx
}
// Commit commits the changes in a transaction. If the connection pool is not
// a usable transaction, ErrInvalidTransaction is recorded instead.
func (db *DB) Commit() *DB {
	committer, ok := db.Statement.ConnPool.(TxCommitter)
	if !ok || committer == nil || reflect.ValueOf(committer).IsNil() {
		db.AddError(ErrInvalidTransaction)
		return db
	}
	db.AddError(committer.Commit())
	return db
}
// Rollback rolls back the changes in a transaction. If the connection pool is
// not a transaction, ErrInvalidTransaction is recorded. A transaction
// interface holding a nil pointer is silently ignored.
func (db *DB) Rollback() *DB {
	committer, ok := db.Statement.ConnPool.(TxCommitter)
	if !ok || committer == nil {
		db.AddError(ErrInvalidTransaction)
		return db
	}
	if !reflect.ValueOf(committer).IsNil() {
		db.AddError(committer.Rollback())
	}
	return db
}
// SavePoint creates a savepoint called name through the dialector. If the
// dialector does not support savepoints, ErrUnsupportedDriver is recorded.
func (db *DB) SavePoint(name string) *DB {
	savePointer, supported := db.Dialector.(SavePointerDialectorInterface)
	if !supported {
		db.AddError(ErrUnsupportedDriver)
		return db
	}

	// Savepoint statements are not supported as prepared statements
	// (e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html),
	// so a prepared-statement transaction is swapped for its underlying
	// transaction while the savepoint is created.
	preparedStmtTx, isPreparedStmtTx := db.Statement.ConnPool.(*PreparedStmtTX)
	if isPreparedStmtTx {
		db.Statement.ConnPool = preparedStmtTx.Tx
	}

	db.AddError(savePointer.SavePoint(db, name))

	// restore prepared statement
	if isPreparedStmtTx {
		db.Statement.ConnPool = preparedStmtTx
	}
	return db
}
// RollbackTo rolls back to the savepoint called name through the dialector.
// If the dialector does not support savepoints, ErrUnsupportedDriver is
// recorded.
func (db *DB) RollbackTo(name string) *DB {
	savePointer, supported := db.Dialector.(SavePointerDialectorInterface)
	if !supported {
		db.AddError(ErrUnsupportedDriver)
		return db
	}

	// Rolling back to a savepoint is not supported as a prepared statement
	// (e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html),
	// so a prepared-statement transaction is swapped for its underlying
	// transaction while the rollback runs.
	preparedStmtTx, isPreparedStmtTx := db.Statement.ConnPool.(*PreparedStmtTX)
	if isPreparedStmtTx {
		db.Statement.ConnPool = preparedStmtTx.Tx
	}

	db.AddError(savePointer.RollbackTo(db, name))

	// restore prepared statement
	if isPreparedStmtTx {
		db.Statement.ConnPool = preparedStmtTx
	}
	return db
}
// Exec executes raw sql with the given values through the raw callbacks.
// SQL containing an "@" is built as a named expression, anything else as a
// plain positional expression.
func (db *DB) Exec(sql string, values ...any) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.SQL = strings.Builder{}

	if usesNamedArgs := strings.Contains(sql, "@"); usesNamedArgs {
		named := clause.NamedExpr{SQL: sql, Vars: values}
		named.Build(tx.Statement)
	} else {
		positional := clause.Expr{SQL: sql, Vars: values}
		positional.Build(tx.Statement)
	}

	return tx.callbacks.Raw().Execute(tx)
}
================================================
FILE: gossh/gorm/gorm.go
================================================
package gorm
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/schema"
)
// preparedStmtDBKey is the key under which the PreparedStmtDB is stored in
// Config.cacheStore.
const preparedStmtDBKey = "preparedStmt"
// Config is the GORM configuration. It is embedded by pointer in DB, so its
// fields are reachable directly on a *DB.
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger is the logger used for tracing statements and reporting errors
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generate sql without execute
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// DisableAutomaticPing presumably skips the automatic ping on initialization — confirm in Open
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating presumably disables foreign key constraint creation during migration — confirm in the migrator
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating presumably makes migration skip relationships — confirm in the migrator
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
// Plugins registered plugins
Plugins map[string]Plugin
// callbacks holds the registered callback processors
callbacks *callbacks
// cacheStore is an internal cache; parsed schemas are looked up through it
cacheStore *sync.Map
}
// Apply overwrites config with a copy of c. Nothing is copied when config and
// c are the same pointer. It always returns nil.
func (c *Config) Apply(config *Config) error {
	if config == c {
		return nil
	}
	*config = *c
	return nil
}
// AfterInitialize initializes every plugin in c.Plugins against db and returns
// the first error encountered. A nil db is a no-op.
func (c *Config) AfterInitialize(db *DB) error {
	if db == nil {
		return nil
	}
	for _, p := range c.Plugins {
		if err := p.Initialize(db); err != nil {
			return err
		}
	}
	return nil
}
// Option gorm option interface.
// Open calls Apply on the config under construction and defers
// AfterInitialize until it returns.
type Option interface {
// Apply modifies the config being built; a non-nil error aborts Open.
Apply(*Config) error
// AfterInitialize runs when Open returns; the db argument may be nil if Open failed early.
AfterInitialize(*DB) error
}
// DB GORM DB definition
type DB struct {
*Config
// Error accumulates errors recorded through AddError.
Error error
RowsAffected int64
// Statement is the statement currently being built or executed.
Statement *Statement
// clone tells getInstance how to derive a working copy:
// 0 reuses this DB, 1 creates a fresh Statement, any other positive value clones the current Statement.
clone int
}
// Session session config when create session with Session() method.
//
// Boolean options are only ever switched on by DB.Session: a false value
// leaves the parent's setting untouched.
type Session struct {
DryRun bool
// PrepareStmt routes the session through the shared prepared-statement cache.
PrepareStmt bool
// NewDB keeps clone at 1 so later calls start from a fresh Statement;
// when false, clone is set to 2 and the current Statement is cloned instead.
NewDB bool
// Initialized makes DB.Session return tx.getInstance() rather than tx itself.
Initialized bool
// SkipHooks sets Statement.SkipHooks on the session's statement.
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
// Context, when non-nil, replaces the statement's context.
Context context.Context
// Logger, when non-nil, replaces the session's logger.
Logger logger.Interface
// NowFunc, when non-nil, replaces the session's timestamp function.
NowFunc func() time.Time
// CreateBatchSize is applied only when greater than zero.
CreateBatchSize int
}
// Open initialize db session based on dialector.
//
// Options are applied to a fresh Config, with *Config options sorted ahead of
// the others. The dialector's own Apply hook runs next, if it has one. Unset
// NamingStrategy, Logger, NowFunc, Plugins, cacheStore and ClauseBuilders get
// defaults, and then Dialector.Initialize is called. Every applied option's
// AfterInitialize is deferred, so it runs as Open returns; a non-nil error
// from it replaces the returned err.
//
// An error from Dialector.Initialize or from the automatic ping is logged and
// returned; in those cases db has already been constructed and is returned too.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
// Move *Config options to the front so they are applied before other options.
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
for _, opt := range opts {
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
// Deferred on purpose: runs at function exit and sees the final value of the named result db.
defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
// Give the dialector a chance to adjust the config if it exposes Apply.
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
if config.Logger == nil {
config.Logger = logger.Default
}
if config.NowFunc == nil {
config.NowFunc = func() time.Time { return time.Now().Local() }
}
// An explicit dialector argument overrides any Dialector set through options.
if dialector != nil {
config.Dialector = dialector
}
if config.Plugins == nil {
config.Plugins = map[string]Plugin{}
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
db = &DB{Config: config, clone: 1}
db.callbacks = initializeCallbacks(db)
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
if err != nil {
// Best-effort cleanup: close the underlying *sql.DB if one can be resolved.
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
}
}
if config.PrepareStmt {
// Wrap the pool and remember the wrapper so sessions can share it.
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
// Ping only when nothing has failed so far and the pool supports it.
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
return
}
// Session create new db session.
//
// The returned DB owns a copy of the receiver's Config, so settings changed
// here do not affect db. The Statement is shared with db unless Context,
// PrepareStmt or SkipHooks is requested, in which case it is cloned first.
// txConfig and tx.Config refer to the same copy, so assignments through either
// name are equivalent.
func (db *DB) Session(config *Session) *DB {
var (
txConfig = *db.Config
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
Error: db.Error,
clone: 1,
}
)
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
}
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
}
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
// Clone before mutating the statement so the parent's statement stays untouched.
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
}
if config.Context != nil {
tx.Statement.Context = config.Context
}
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
// Reuse the cached PreparedStmtDB, or create and cache one on first use.
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
// Inside a transaction: keep the Tx and attach the shared statement cache.
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
// Otherwise build a PreparedStmtDB over the parent's pool that shares Mux and Stmts with the cached one.
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
if config.SkipHooks {
tx.Statement.SkipHooks = true
}
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
// clone == 2 makes getInstance clone the current Statement instead of starting a new one.
if !config.NewDB {
tx.clone = 2
}
if config.DryRun {
tx.Config.DryRun = true
}
if config.QueryFields {
tx.Config.QueryFields = true
}
if config.Logger != nil {
tx.Config.Logger = config.Logger
}
if config.NowFunc != nil {
tx.Config.NowFunc = config.NowFunc
}
if config.Initialized {
tx = tx.getInstance()
}
return tx
}
// WithContext returns a session of db whose statement context is ctx.
func (db *DB) WithContext(ctx context.Context) *DB {
	session := Session{Context: ctx}
	return db.Session(&session)
}
// Debug returns a session whose logger is db's logger switched to logger.Info.
func (db *DB) Debug() (tx *DB) {
	tx = db.getInstance()
	session := &Session{
		Logger: db.Logger.LogMode(logger.Info),
	}
	return tx.Session(session)
}
// Set stores value under key in the statement settings of a working instance
// of db and returns that instance.
func (db *DB) Set(key string, value any) *DB {
	inst := db.getInstance()
	stmt := inst.Statement
	stmt.Settings.Store(key, value)
	return inst
}
// Get looks key up in the current statement's settings. The boolean reports
// whether a value was found.
func (db *DB) Get(key string) (any, bool) {
	stmt := db.Statement
	return stmt.Settings.Load(key)
}
// InstanceSet stores value in the statement settings under a key prefixed with
// the statement's pointer address, so the entry is tied to that particular
// Statement. It returns the working instance the value was stored on.
func (db *DB) InstanceSet(key string, value any) *DB {
	inst := db.getInstance()
	scopedKey := fmt.Sprintf("%p", inst.Statement) + key
	inst.Statement.Settings.Store(scopedKey, value)
	return inst
}
// InstanceGet loads the value stored by InstanceSet for key on the current
// Statement. The lookup key is prefixed with the statement's pointer address.
func (db *DB) InstanceGet(key string) (any, bool) {
	scopedKey := fmt.Sprintf("%p", db.Statement) + key
	return db.Statement.Settings.Load(scopedKey)
}
// Callback returns the callback manager held by db's Config.
func (db *DB) Callback() *callbacks {
	return db.Config.callbacks
}
// AddError records err on db and returns the accumulated db.Error.
//
// A nil err changes nothing. With TranslateError enabled and a dialector that
// implements ErrorTranslator, err is translated first. If db already carries
// an error, the two are combined as "<previous>; <new>" with the new error
// wrapped.
func (db *DB) AddError(err error) error {
	if err == nil {
		return db.Error
	}

	if db.Config.TranslateError {
		if translator, ok := db.Dialector.(ErrorTranslator); ok {
			err = translator.Translate(err)
		}
	}

	if db.Error != nil {
		db.Error = fmt.Errorf("%v; %w", db.Error, err)
	} else {
		db.Error = err
	}
	return db.Error
}
// DB returns the `*sql.DB` behind this instance.
//
// The statement's pool is preferred over the config's pool when set. The pool
// is then resolved in order: a *sql.Tx (its unexported "db" field is read via
// reflection), a GetDBConnector, or a *sql.DB. ErrInvalidDB is returned when
// none of these yields a database.
func (db *DB) DB() (*sql.DB, error) {
	pool := db.ConnPool
	if stmt := db.Statement; stmt != nil && stmt.ConnPool != nil {
		pool = stmt.ConnPool
	}

	if sqlTx, isTx := pool.(*sql.Tx); isTx && sqlTx != nil {
		owner := reflect.ValueOf(sqlTx).Elem().FieldByName("db")
		return (*sql.DB)(owner.UnsafePointer()), nil
	}

	if connector, isConnector := pool.(GetDBConnector); isConnector && connector != nil {
		// A connector that yields neither a DB nor an error falls through.
		sqldb, err := connector.GetDBConn()
		if sqldb != nil || err != nil {
			return sqldb, err
		}
	}

	if sqldb, isDB := pool.(*sql.DB); isDB && sqldb != nil {
		return sqldb, nil
	}
	return nil, ErrInvalidDB
}
// getInstance returns the DB that chained calls should operate on.
//
// clone <= 0 returns db itself. clone == 1 returns a new DB with a fresh
// Statement that keeps only the pool, context and SkipHooks flag. Any other
// positive value returns a new DB with a clone of the current Statement.
func (db *DB) getInstance() *DB {
	if db.clone <= 0 {
		return db
	}

	inst := &DB{Config: db.Config, Error: db.Error}
	if db.clone != 1 {
		// with clone statement
		inst.Statement = db.Statement.clone()
		inst.Statement.DB = inst
		return inst
	}

	// clone with new statement
	prev := db.Statement
	inst.Statement = &Statement{
		DB:        inst,
		ConnPool:  prev.ConnPool,
		Context:   prev.Context,
		Clauses:   map[string]clause.Clause{},
		Vars:      make([]any, 0, 8),
		SkipHooks: prev.SkipHooks,
	}
	return inst
}
// Expr wraps a SQL fragment and its arguments in a clause.Expr so that it can
// be passed as a parameter.
func Expr(expr string, args ...any) clause.Expr {
	built := clause.Expr{SQL: expr, Vars: args}
	return built
}
// SetupJoinTable makes joinTable the join-table schema for the relation named
// field on model.
//
// Both model and joinTable are parsed. An error is returned if parsing fails,
// if field is not a relation with a join table, or if joinTable lacks a field
// for one of the relation's foreign keys. Matching join-table fields take the
// foreign key's data types (and its size when their own is zero). Relations of
// the previous join table that the new schema does not define are carried over.
func (db *DB) SetupJoinTable(model any, field string, joinTable any) error {
	stmt := db.getInstance().Statement

	if err := stmt.Parse(model); err != nil {
		return err
	}
	modelSchema := stmt.Schema

	if err := stmt.Parse(joinTable); err != nil {
		return err
	}
	joinSchema := stmt.Schema

	relation, found := modelSchema.Relationships.Relations[field]
	if !found || relation.JoinTable == nil {
		return fmt.Errorf("failed to find relation: %s", field)
	}

	// Point every reference at the matching field of the custom join table.
	for _, ref := range relation.References {
		fk := joinSchema.LookUpField(ref.ForeignKey.DBName)
		if fk == nil {
			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
		}
		fk.DataType = ref.ForeignKey.DataType
		fk.GORMDataType = ref.ForeignKey.GORMDataType
		if fk.Size == 0 {
			fk.Size = ref.ForeignKey.Size
		}
		ref.ForeignKey = fk
	}

	// Keep relations of the old join table that the new schema does not define.
	for name, rel := range relation.JoinTable.Relationships.Relations {
		if _, exists := joinSchema.Relationships.Relations[name]; exists {
			continue
		}
		rel.Schema = joinSchema
		joinSchema.Relationships.Relations[name] = rel
	}

	relation.JoinTable = joinSchema
	return nil
}
// Use initializes plugin against db and registers it under plugin.Name().
// It returns ErrRegistered if that name is already taken, or the error from
// plugin.Initialize; in both cases the plugin is not registered.
func (db *DB) Use(plugin Plugin) error {
	name := plugin.Name()
	if _, registered := db.Plugins[name]; registered {
		return ErrRegistered
	}

	err := plugin.Initialize(db)
	if err == nil {
		db.Plugins[name] = plugin
	}
	return err
}
// ToSQL returns the SQL that queryFn would produce, with variables inlined by
// the dialector's Explain. queryFn runs on a DryRun session with default
// transactions skipped.
//
//	db.ToSQL(func(tx *gorm.DB) *gorm.DB {
//		return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).
//			Limit(10).Offset(5).
//			Order("name ASC").
//			First(&User{})
//	})
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
	dryRun := db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})
	stmt := queryFn(dryRun).Statement
	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}
================================================
FILE: gossh/gorm/inflection/inflections.go
================================================
/*
Package inflection pluralizes and singularizes English nouns.
inflection.Plural("person") => "people"
inflection.Plural("Person") => "People"
inflection.Plural("PERSON") => "PEOPLE"
inflection.Singular("people") => "person"
inflection.Singular("People") => "Person"
inflection.Singular("PEOPLE") => "PERSON"
inflection.Plural("FancyPerson") => "FancyPeople"
inflection.Singular("FancyPeople") => "FancyPerson"
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
If you want to register more rules, follow:
inflection.AddUncountable("fish")
inflection.AddIrregular("person", "people")
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
*/
package inflection
import (
"regexp"
"strings"
)
// inflection is a compiled rule: a word matching regexp is rewritten with
// replace by Plural or Singular.
type inflection struct {
regexp *regexp.Regexp
replace string
}
// Regular is a regexp find replace inflection.
// find is a regular expression source; replace is its replacement template
// (e.g. "${1}s").
type Regular struct {
find string
replace string
}
// Irregular is a hard replace inflection,
// containing both singular and plural forms.
// compile matches each form at the end of a word in upper, title and
// as-written casing.
type Irregular struct {
singular string
plural string
}
// RegularSlice is a slice of Regular inflections
type RegularSlice []Regular
// IrregularSlice is a slice of Irregular inflections
type IrregularSlice []Irregular
// pluralInflections lists the regular pluralization rules.
// compile walks the list from the last entry to the first, so later entries
// take precedence over earlier ones.
var pluralInflections = RegularSlice{
{"([a-z])$", "${1}s"},
{"s$", "s"},
{"^(ax|test)is$", "${1}es"},
{"(octop|vir)us$", "${1}i"},
{"(octop|vir)i$", "${1}i"},
{"(alias|status)$", "${1}es"},
{"(bu)s$", "${1}ses"},
{"(buffal|tomat)o$", "${1}oes"},
{"([ti])um$", "${1}a"},
{"([ti])a$", "${1}a"},
{"sis$", "ses"},
{"(?:([^f])fe|([lr])f)$", "${1}${2}ves"},
{"(hive)$", "${1}s"},
{"([^aeiouy]|qu)y$", "${1}ies"},
{"(x|ch|ss|sh)$", "${1}es"},
{"(matr|vert|ind)(?:ix|ex)$", "${1}ices"},
{"^(m|l)ouse$", "${1}ice"},
{"^(m|l)ice$", "${1}ice"},
{"^(ox)$", "${1}en"},
{"^(oxen)$", "${1}"},
{"(quiz)$", "${1}zes"},
}
// singularInflections lists the regular singularization rules.
// compile walks the list from the last entry to the first, so later entries
// take precedence over earlier ones.
var singularInflections = RegularSlice{
{"s$", ""},
{"(ss)$", "${1}"},
{"(n)ews$", "${1}ews"},
{"([ti])a$", "${1}um"},
{"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"},
{"(^analy)(sis|ses)$", "${1}sis"},
{"([^f])ves$", "${1}fe"},
{"(hive)s$", "${1}"},
{"(tive)s$", "${1}"},
{"([lr])ves$", "${1}f"},
{"([^aeiouy]|qu)ies$", "${1}y"},
{"(s)eries$", "${1}eries"},
{"(m)ovies$", "${1}ovie"},
{"(c)ookies$", "${1}ookie"},
{"(x|ch|ss|sh)es$", "${1}"},
{"^(m|l)ice$", "${1}ouse"},
{"(bus)(es)?$", "${1}"},
{"(o)es$", "${1}"},
{"(shoe)s$", "${1}"},
{"(cris|test)(is|es)$", "${1}is"},
{"^(a)x[ie]s$", "${1}xis"},
{"(octop|vir)(us|i)$", "${1}us"},
{"(alias|status)(es)?$", "${1}"},
{"^(ox)en", "${1}"},
{"(vert|ind)ices$", "${1}ex"},
{"(matr)ices$", "${1}ix"},
{"(quiz)zes$", "${1}"},
{"(database)s$", "${1}"},
}
// irregularInflections lists singular/plural pairs that do not follow the
// regular rules. compile places them ahead of the regular rules.
var irregularInflections = IrregularSlice{
{"person", "people"},
{"man", "men"},
{"child", "children"},
{"sex", "sexes"},
{"move", "moves"},
{"mombie", "mombies"},
}
// uncountableInflections lists words whose singular and plural forms are the
// same. compile matches them as whole words, case-insensitively, before any
// other rule.
var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"}
// compiledPluralMaps holds the ordered rules used by Plural; rebuilt by compile.
var compiledPluralMaps []inflection
// compiledSingularMaps holds the ordered rules used by Singular; rebuilt by compile.
var compiledSingularMaps []inflection
// compile rebuilds compiledPluralMaps and compiledSingularMaps from the rule
// tables. Rules are appended in matching priority: uncountables, then
// irregulars, then the regular rules from last to first. Each irregular and
// regular rule is expanded into three casing variants. An invalid pattern
// panics via regexp.MustCompile.
func compile() {
	compiledPluralMaps, compiledSingularMaps = []inflection{}, []inflection{}

	// Uncountable words map to themselves in both directions.
	for _, word := range uncountableInflections {
		rule := inflection{
			regexp:  regexp.MustCompile("^(?i)(" + word + ")$"),
			replace: "${1}",
		}
		compiledPluralMaps = append(compiledPluralMaps, rule)
		compiledSingularMaps = append(compiledSingularMaps, rule)
	}

	// Irregular singular -> plural: upper case, title case, as written.
	for _, irr := range irregularInflections {
		compiledPluralMaps = append(compiledPluralMaps,
			inflection{regexp: regexp.MustCompile(strings.ToUpper(irr.singular) + "$"), replace: strings.ToUpper(irr.plural)},
			inflection{regexp: regexp.MustCompile(strings.Title(irr.singular) + "$"), replace: strings.Title(irr.plural)},
			inflection{regexp: regexp.MustCompile(irr.singular + "$"), replace: irr.plural},
		)
	}

	// Irregular plural -> singular: upper case, title case, as written.
	for _, irr := range irregularInflections {
		compiledSingularMaps = append(compiledSingularMaps,
			inflection{regexp: regexp.MustCompile(strings.ToUpper(irr.plural) + "$"), replace: strings.ToUpper(irr.singular)},
			inflection{regexp: regexp.MustCompile(strings.Title(irr.plural) + "$"), replace: strings.Title(irr.singular)},
			inflection{regexp: regexp.MustCompile(irr.plural + "$"), replace: irr.singular},
		)
	}

	// Regular plural rules, newest first: upper case, as written, case-insensitive.
	for idx := len(pluralInflections) - 1; idx >= 0; idx-- {
		rule := pluralInflections[idx]
		compiledPluralMaps = append(compiledPluralMaps,
			inflection{regexp: regexp.MustCompile(strings.ToUpper(rule.find)), replace: strings.ToUpper(rule.replace)},
			inflection{regexp: regexp.MustCompile(rule.find), replace: rule.replace},
			inflection{regexp: regexp.MustCompile("(?i)" + rule.find), replace: rule.replace},
		)
	}

	// Regular singular rules, newest first: upper case, as written, case-insensitive.
	for idx := len(singularInflections) - 1; idx >= 0; idx-- {
		rule := singularInflections[idx]
		compiledSingularMaps = append(compiledSingularMaps,
			inflection{regexp: regexp.MustCompile(strings.ToUpper(rule.find)), replace: strings.ToUpper(rule.replace)},
			inflection{regexp: regexp.MustCompile(rule.find), replace: rule.replace},
			inflection{regexp: regexp.MustCompile("(?i)" + rule.find), replace: rule.replace},
		)
	}
}
// init compiles the built-in rule tables so Plural and Singular work without
// any setup call.
func init() {
compile()
}
// AddPlural adds a plural inflection
func AddPlural(find, replace string) {
pluralInflections = append(pluralInflections, Regular{find, replace})
compile()
}
// AddSingular adds a singular inflection
func AddSingular(find, replace string) {
singularInflections = append(singularInflections, Regular{find, replace})
compile()
}
// AddIrregular adds an irregular inflection
func AddIrregular(singular, plural string) {
irregularInflections = append(irregularInflections, Irregular{singular, plural})
compile()
}
// AddUncountable adds one or more uncountable words and recompiles all rules.
// Each value is inserted into a regular expression as-is by compile.
func AddUncountable(values ...string) {
uncountableInflections = append(uncountableInflections, values...)
compile()
}
// GetPlural returns a copy of the current plural rules; changing the result
// does not affect the package state.
func GetPlural() RegularSlice {
	return append(make(RegularSlice, 0, len(pluralInflections)), pluralInflections...)
}
// GetSingular returns a copy of the current singular rules; changing the
// result does not affect the package state.
func GetSingular() RegularSlice {
	return append(make(RegularSlice, 0, len(singularInflections)), singularInflections...)
}
// GetIrregular returns a copy of the current irregular pairs; changing the
// result does not affect the package state.
func GetIrregular() IrregularSlice {
	return append(make(IrregularSlice, 0, len(irregularInflections)), irregularInflections...)
}
// GetUncountable returns a copy of the current uncountable words; changing the
// result does not affect the package state.
func GetUncountable() []string {
	return append(make([]string, 0, len(uncountableInflections)), uncountableInflections...)
}
// SetPlural replaces the plural rules with inflections and recompiles.
// The slice is stored as given, not copied.
func SetPlural(inflections RegularSlice) {
pluralInflections = inflections
compile()
}
// SetSingular replaces the singular rules with inflections and recompiles.
// The slice is stored as given, not copied.
func SetSingular(inflections RegularSlice) {
singularInflections = inflections
compile()
}
// SetIrregular replaces the irregular pairs with inflections and recompiles.
// The slice is stored as given, not copied.
func SetIrregular(inflections IrregularSlice) {
irregularInflections = inflections
compile()
}
// SetUncountable replaces the uncountable words with inflections and
// recompiles. The slice is stored as given, not copied.
func SetUncountable(inflections []string) {
uncountableInflections = inflections
compile()
}
// Plural returns the plural form of str. The first compiled rule whose pattern
// matches is applied; str is returned unchanged when no rule matches.
func Plural(str string) string {
	for i := range compiledPluralMaps {
		rule := compiledPluralMaps[i]
		if !rule.regexp.MatchString(str) {
			continue
		}
		return rule.regexp.ReplaceAllString(str, rule.replace)
	}
	return str
}
// Singular returns the singular form of str. The first compiled rule whose
// pattern matches is applied; str is returned unchanged when no rule matches.
func Singular(str string) string {
	for i := range compiledSingularMaps {
		rule := compiledSingularMaps[i]
		if !rule.regexp.MatchString(str) {
			continue
		}
		return rule.regexp.ReplaceAllString(str, rule.replace)
	}
	return str
}
================================================
FILE: gossh/gorm/interfaces.go
================================================
package gorm
import (
"context"
"database/sql"
"gossh/gorm/clause"
"gossh/gorm/schema"
)
// Dialector GORM database dialector.
// It is the per-database driver abstraction handed to Open.
type Dialector interface {
Name() string
// Initialize is called by Open with the DB under construction.
Initialize(*DB) error
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
DefaultValueOf(*schema.Field) clause.Expression
BindVarTo(writer clause.Writer, stmt *Statement, v any)
QuoteTo(clause.Writer, string)
// Explain renders sql with vars inlined; DB.ToSQL uses it to produce its result.
Explain(sql string, vars ...any) string
}
// Plugin GORM plugin interface.
// DB.Use registers a plugin under Name after Initialize succeeds.
type Plugin interface {
Name() string
Initialize(*DB) error
}
// ParamsFilter lets an implementation rewrite or drop the SQL text and
// parameters it is given, returning the pair to use instead.
// NOTE(review): presumably consulted before SQL is rendered for logging; confirm against callers.
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any)
}
// ConnPool db conns pool interface.
// It lists the context-aware prepare, exec and query methods of database/sql
// that GORM calls on a connection pool.
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
// SavePointerDialectorInterface save pointer interface.
// Dialectors implement it to support savepoints; DB.RollbackTo reports
// ErrUnsupportedDriver when the dialector does not.
type SavePointerDialectorInterface interface {
SavePoint(tx *DB, name string) error
RollbackTo(tx *DB, name string) error
}
// TxBeginner tx beginner.
// Implemented by pools that can start a transaction returning a *sql.Tx.
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// ConnPoolBeginner conn pool beginner.
// Like TxBeginner, but the started transaction is returned as a ConnPool.
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
// TxCommitter tx committer.
// It covers the two ways of ending a transaction.
type TxCommitter interface {
Commit() error
Rollback() error
}
// Tx sql.Tx interface.
// It combines ConnPool and TxCommitter with StmtContext. DB.Session
// type-switches on it to detect that the statement is inside a transaction.
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface.
// Implementations turn themselves into a clause.Expr for the given context and DB.
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}
// GetDBConnector SQL db connector.
// DB.DB uses it to obtain the *sql.DB from pools that wrap one.
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// Rows rows interface.
// It lists the result-set methods GORM relies on; the signatures match those
// of *sql.Rows.
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...any) error
Err() error
Close() error
}
// ErrorTranslator converts a driver error into another error.
// DB.AddError calls it when Config.TranslateError is set and the dialector
// implements it.
type ErrorTranslator interface {
Translate(err error) error
}
================================================
FILE: gossh/gorm/logger/logger.go
================================================
package logger
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"time"
"gossh/gorm/utils"
)
// ErrRecordNotFound record not found error.
// Trace skips errors matching it when Config.IgnoreRecordNotFoundError is set.
var ErrRecordNotFound = errors.New("record not found")
// Colors: ANSI escape sequences used by New to build the colorful log formats.
const (
Reset = "\033[0m"
Red = "\033[31m"
Green = "\033[32m"
Yellow = "\033[33m"
Blue = "\033[34m"
Magenta = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m"
YellowBold = "\033[33;1m"
)
// LogLevel log level.
// Levels are ordered Silent < Error < Warn < Info; a message is printed when
// the logger's level is at least the message's level.
type LogLevel int
const (
// Silent silent log level (value 1, so the zero LogLevel is below Silent)
Silent LogLevel = iota + 1
// Error error log level
Error
// Warn warn log level
Warn
// Info info log level
Info
)
// Writer log writer interface.
// The package's own defaults pass a *log.Logger from the standard library.
type Writer interface {
Printf(string, ...any)
}
// Config logger config
type Config struct {
// SlowThreshold: Trace reports statements that took longer than this; zero disables the check.
SlowThreshold time.Duration
// Colorful selects the ANSI-colored format strings in New.
Colorful bool
// IgnoreRecordNotFoundError stops Trace from reporting ErrRecordNotFound as an error.
IgnoreRecordNotFoundError bool
// ParameterizedQueries makes ParamsFilter drop the parameters.
ParameterizedQueries bool
// LogLevel is the minimum level at which messages are printed.
LogLevel LogLevel
}
// Interface logger interface
type Interface interface {
// LogMode returns a logger that uses the given level.
LogMode(LogLevel) Interface
Info(context.Context, string, ...any)
Warn(context.Context, string, ...any)
Error(context.Context, string, ...any)
// Trace reports one executed statement: begin is when it started, fc yields
// the SQL text and affected rows on demand, err is its outcome.
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
var (
// Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger: writes to os.Stdout in color at Warn level and
// reports statements slower than 200ms.
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
)
// New builds a logger that prints through writer using config. The message
// formats are plain text, or ANSI-colored when config.Colorful is set.
func New(writer Writer, config Config) Interface {
	l := &logger{Writer: writer, Config: config}

	if config.Colorful {
		l.infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
		l.warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
		l.errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
		l.traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
		l.traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
		l.traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
		return l
	}

	l.infoStr = "%s\n[info] "
	l.warnStr = "%s\n[warn] "
	l.errStr = "%s\n[error] "
	l.traceStr = "%s\n[%.3fms] [rows:%v] %s"
	l.traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
	l.traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
	return l
}
// logger is the Interface implementation returned by New. It embeds the
// Writer it prints to and its Config, and keeps the format strings chosen at
// construction time.
type logger struct {
Writer
Config
// Prefix formats for Info, Warn and Error; each takes the caller location.
infoStr, warnStr, errStr string
// Full-line formats used by Trace for normal, failed and slow statements.
traceStr, traceErrStr, traceWarnStr string
}
// LogMode returns a copy of l that logs at level; l itself is left unchanged.
func (l *logger) LogMode(level LogLevel) Interface {
	clone := *l
	clone.LogLevel = level
	return &clone
}
// Info prints msg, formatted with data and prefixed by the caller location,
// when the level is Info or higher.
func (l *logger) Info(ctx context.Context, msg string, data ...any) {
	if l.LogLevel < Info {
		return
	}
	args := append([]any{utils.FileWithLineNum()}, data...)
	l.Printf(l.infoStr+msg, args...)
}
// Warn prints msg, formatted with data and prefixed by the caller location,
// when the level is Warn or higher.
func (l *logger) Warn(ctx context.Context, msg string, data ...any) {
	if l.LogLevel < Warn {
		return
	}
	args := append([]any{utils.FileWithLineNum()}, data...)
	l.Printf(l.warnStr+msg, args...)
}
// Error prints msg, formatted with data and prefixed by the caller location,
// when the level is Error or higher.
func (l *logger) Error(ctx context.Context, msg string, data ...any) {
	if l.LogLevel < Error {
		return
	}
	args := append([]any{utils.FileWithLineNum()}, data...)
	l.Printf(l.errStr+msg, args...)
}
// Trace prints one line for an executed statement. Nothing is printed at
// Silent level or below. Otherwise the first matching case wins:
//
//   - err is set and the level is at least Error (ErrRecordNotFound is skipped
//     when IgnoreRecordNotFoundError is on): error format;
//   - the statement ran longer than a non-zero SlowThreshold and the level is
//     at least Warn: slow-SQL format;
//   - the level is exactly Info: plain trace format.
//
// fc is only called when a line is printed. A row count of -1 is shown as "-".
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= Silent {
		return
	}

	elapsed := time.Since(begin)
	millis := float64(elapsed.Nanoseconds()) / 1e6
	// rowsArg substitutes "-" for the -1 row count.
	rowsArg := func(rows int64) any {
		if rows == -1 {
			return "-"
		}
		return rows
	}

	switch {
	case err != nil && l.LogLevel >= Error && !(errors.Is(err, ErrRecordNotFound) && l.IgnoreRecordNotFoundError):
		sql, rows := fc()
		l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, millis, rowsArg(rows), sql)
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
		sql, rows := fc()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
		l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, millis, rowsArg(rows), sql)
	case l.LogLevel == Info:
		sql, rows := fc()
		l.Printf(l.traceStr, utils.FileWithLineNum(), millis, rowsArg(rows), sql)
	}
}
// ParamsFilter returns sql unchanged, together with params, or with nil params
// when Config.ParameterizedQueries is enabled.
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) {
	if !l.Config.ParameterizedQueries {
		return sql, params
	}
	return sql, nil
}
// traceRecorder is a logger that, instead of printing traces, stores the most
// recent one in its fields. All other logging is delegated to the embedded
// Interface.
type traceRecorder struct {
Interface
// BeginAt is the start time passed to the last Trace call.
BeginAt time.Time
// SQL and RowsAffected are the values produced by the last Trace call's fc.
SQL string
RowsAffected int64
// Err is the error passed to the last Trace call.
Err error
}
// New returns a fresh recorder that shares l's underlying logger and starts
// its clock at the current time.
func (l *traceRecorder) New() *traceRecorder {
	fresh := traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
	return &fresh
}
// Trace implements the logger interface by recording the statement rather than
// printing it. fc is always called.
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	l.BeginAt = begin
	sql, rows := fc()
	l.SQL = sql
	l.RowsAffected = rows
	l.Err = err
}
================================================
FILE: gossh/gorm/logger/sql.go
================================================
package logger
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode"
"gossh/gorm/utils"
)
// Formatting constants used by ExplainSQL.
const (
// tmFmtWithMS is the layout for non-zero times (up to millisecond precision).
tmFmtWithMS = "2006-01-02 15:04:05.999"
// tmFmtZero is the literal emitted for a zero time.
tmFmtZero = "0000-00-00 00:00:00"
// nullStr is emitted for nil values.
nullStr = "NULL"
)
// isPrintable reports whether every rune of s is printable according to
// unicode.IsPrint. The empty string is printable.
func isPrintable(s string) bool {
	printable := true
	for _, r := range s {
		if printable = unicode.IsPrint(r); !printable {
			break
		}
	}
	return printable
}
// A list of Go types that should be converted to SQL primitives.
// ExplainSQL converts a value of an otherwise unhandled type to the first of
// these it is convertible to (time.Time, bool, []byte, in that order).
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// numericPlaceholderRe matches the intermediate "$<digits>$" markers that
// ExplainSQL produces from numbered placeholders before substituting values.
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
// isNumeric reports whether k is a signed integer, unsigned integer or
// floating-point kind. Uintptr and complex kinds are not treated as numeric.
func isNumeric(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
//
// Each value in avars is rendered to text first: strings, printable byte
// slices and times are wrapped in escaper (with embedded escaper doubled),
// numbers and booleans are written bare, and nil values become NULL. A byte
// slice that is not printable text is rendered as the placeholder <binary>
// so that it cannot be confused with an empty string in the log.
//
// The rendered values are then substituted into sql:
//   - numericPlaceholder == nil: every '?' byte is replaced in order for as
//     long as values remain. This is a plain byte scan, so a '?' inside a SQL
//     string literal is replaced too.
//   - otherwise: numericPlaceholder must capture the placeholder's number in
//     group 1 (e.g. `\$(\d+)`); placeholder n is replaced by the n-th value
//     (1-based), and placeholders without a matching value are left as
//     "$n$".
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...any) string {
	var (
		convertParams func(any, int)
		vars          = make([]string, len(avars))
	)

	// convertParams renders v into vars[idx]; it recurses for values that
	// unwrap to something else (driver.Valuer, pointers, convertible types).
	convertParams = func(v any, idx int) {
		switch v := v.(type) {
		case bool:
			vars[idx] = strconv.FormatBool(v)
		case time.Time:
			if v.IsZero() {
				vars[idx] = escaper + tmFmtZero + escaper
			} else {
				vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
			}
		case *time.Time:
			if v != nil {
				if v.IsZero() {
					vars[idx] = escaper + tmFmtZero + escaper
				} else {
					vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
				}
			} else {
				vars[idx] = nullStr
			}
		case driver.Valuer:
			// Guard against typed nil pointers before calling Value.
			reflectValue := reflect.ValueOf(v)
			if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
				r, _ := v.Value()
				convertParams(r, idx)
			} else {
				vars[idx] = nullStr
			}
		case fmt.Stringer:
			// Named scalar types with a String method keep their scalar rendering;
			// only string-kinded and other values use the String output.
			reflectValue := reflect.ValueOf(v)
			switch reflectValue.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
			case reflect.Float32, reflect.Float64:
				vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
			case reflect.Bool:
				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
			case reflect.String:
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
			default:
				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
					vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
				} else {
					vars[idx] = nullStr
				}
			}
		case []byte:
			if s := string(v); isPrintable(s) {
				vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
			} else {
				// Raw bytes are not dumped into the log; show a visible marker
				// instead of an empty value.
				vars[idx] = escaper + "<binary>" + escaper
			}
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
			vars[idx] = utils.ToString(v)
		case float32:
			vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
		case float64:
			vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
		case string:
			vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
		default:
			rv := reflect.ValueOf(v)
			if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
				vars[idx] = nullStr
			} else if valuer, ok := v.(driver.Valuer); ok {
				v, _ = valuer.Value()
				convertParams(v, idx)
			} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
				// Dereference and render the pointed-to value.
				convertParams(reflect.Indirect(rv).Interface(), idx)
			} else if isNumeric(rv.Kind()) {
				if rv.CanInt() || rv.CanUint() {
					vars[idx] = fmt.Sprintf("%d", rv.Interface())
				} else {
					vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
				}
			} else {
				// Named types based on time.Time, bool or []byte are converted
				// and rendered like the underlying type.
				for _, t := range convertibleTypes {
					if rv.Type().ConvertibleTo(t) {
						convertParams(rv.Convert(t).Interface(), idx)
						return
					}
				}
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
			}
		}
	}

	for idx, v := range avars {
		convertParams(v, idx)
	}

	if numericPlaceholder == nil {
		// Positional '?' placeholders: substitute in order while values remain.
		var idx int
		var newSQL strings.Builder

		for _, v := range []byte(sql) {
			if v == '?' {
				if len(vars) > idx {
					newSQL.WriteString(vars[idx])
					idx++
					continue
				}
			}
			newSQL.WriteByte(v)
		}

		sql = newSQL.String()
	} else {
		// Rewrite each numbered placeholder to "$n$" first so that, for
		// example, $1 cannot be mistaken for the prefix of $10.
		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")

		sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
			num := v[1 : len(v)-1]
			n, _ := strconv.Atoi(num)

			// position var start from 1 ($1, $2)
			n -= 1
			if n >= 0 && n <= len(vars)-1 {
				return vars[n]
			}
			return v
		})
	}

	return sql
}
================================================
FILE: gossh/gorm/migrator/column_type.go
================================================
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface.
//
// The *Value fields are overrides. Name, DatabaseTypeName, Length, DecimalSize,
// Nullable and ScanType return the override when it is set and otherwise fall
// back to SQLColumnType. ColumnType, PrimaryKey, AutoIncrement and Unique have
// no fallback and report ok=false when their field is not valid.
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
// DecimalSizeValue is the precision; DecimalSize pairs it with ScaleValue.
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column: NameValue when it is valid,
// otherwise the name reported by SQLColumnType.
func (ct ColumnType) Name() string {
	if !ct.NameValue.Valid {
		return ct.SQLColumnType.Name()
	}
	return ct.NameValue.String
}
// DatabaseTypeName returns the database system name of the column type:
// DataTypeValue when it is valid, otherwise the name reported by
// SQLColumnType. If an empty string is returned, the driver type name is not
// supported. Consult your driver documentation for a list of driver data
// types. Length specifiers are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
	if !ct.DataTypeValue.Valid {
		return ct.SQLColumnType.DatabaseTypeName()
	}
	return ct.DataTypeValue.String
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}
================================================
FILE: gossh/gorm/migrator/index.go
================================================
package migrator
import "database/sql"
// Index is a plain value holder describing one database index; it implements
// the gorm.Index interface.
type Index struct {
	TableName       string
	NameValue       string
	ColumnList      []string
	PrimaryKeyValue sql.NullBool
	UniqueValue     sql.NullBool
	OptionValue     string
}

// Table reports which table the index belongs to.
func (i Index) Table() string { return i.TableName }

// Name reports the index name.
func (i Index) Name() string { return i.NameValue }

// Columns reports the columns covered by the index.
func (i Index) Columns() []string { return i.ColumnList }

// PrimaryKey reports whether the index is the primary key; ok is false when
// that information was not provided.
func (i Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
	pk := i.PrimaryKeyValue
	return pk.Bool, pk.Valid
}

// Unique reports whether the index is unique; ok is false when that
// information was not provided.
func (i Index) Unique() (unique bool, ok bool) {
	u := i.UniqueValue
	return u.Bool, u.Valid
}

// Option reports the optional attribute string of the index.
func (i Index) Option() string { return i.OptionValue }
================================================
FILE: gossh/gorm/migrator/migrator.go
================================================
package migrator
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/schema"
)
// regFullDataType matches zero or more non-digit characters (\D*), followed by a
// captured sequence of digits (\d+), followed by at most one non-digit character (\D?).
// For example, values that match this regular expression are:
// - "123"
// - "abc456"
// - "%$#@789"
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO: consider creating constants for the raw SQL queries.
// Compile-time assertion that *Migrator implements gorm.Migrator.
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator is the generic migrator; all of its settings and dependencies come
// from the embedded Config.
type Migrator struct {
Config
}
// Config holds the settings and dependencies a Migrator works with.
type Config struct {
// CreateIndexAfterCreateTable makes CreateTable create indexes through separate
// CreateIndex calls once the table exists, instead of inlining them in CREATE TABLE.
CreateIndexAfterCreateTable bool
// DB is the database handle used to run the migration statements.
DB *gorm.DB
// Dialector is embedded; the migrator uses it for data type mapping and SQL explaining.
gorm.Dialector
}
// printSQLLogger wraps a logger.Interface and additionally prints every traced
// SQL statement to standard output.
type printSQLLogger struct {
	logger.Interface
}

// Trace prints the statement returned by fc, terminated with a semicolon, and
// then forwards the call unchanged to the wrapped logger.
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	statement, _ := fc()
	fmt.Println(statement + ";")
	l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface is implemented by field types that choose their own
// database data type. DataTypeOf consults it first and falls back to the
// dialector when the returned string is empty.
type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string
}
// RunWithValue prepares a statement for value and hands it to fc.
// The statement starts from the table and table expression of the migrator's
// current statement (if any). A string value is taken as the table name;
// anything else is parsed as a model, and a parse error is returned without
// calling fc.
func (m Migrator) RunWithValue(value any, fc func(*gorm.Statement) error) error {
	stmt := &gorm.Statement{DB: m.DB}
	if current := m.DB.Statement; current != nil {
		stmt.Table = current.Table
		stmt.TableExpr = current.TableExpr
	}

	switch v := value.(type) {
	case string:
		stmt.Table = v
	default:
		if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
			return err
		}
	}

	return fc(stmt)
}
// DataTypeOf returns the database data type for field.
// A field type implementing GormDataTypeInterface gets the first say; if it is
// absent or answers with an empty string, the dialector decides.
func (m Migrator) DataTypeOf(field *schema.Field) string {
	custom, ok := reflect.New(field.IndirectFieldType).Interface().(GormDataTypeInterface)
	if ok {
		if dataType := custom.GormDBDataType(m.DB, field); dataType != "" {
			return dataType
		}
	}
	return m.Dialector.DataTypeOf(field)
}
// FullDataTypeOf returns the complete column definition for field: its data
// type, followed by NOT NULL when required and a DEFAULT clause when the field
// has a default value. A textual default of "(-)" produces no DEFAULT clause.
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
	expr.SQL = m.DataTypeOf(field)
	if field.NotNull {
		expr.SQL += " NOT NULL"
	}
	if !field.HasDefaultValue {
		return expr
	}

	switch {
	case field.DefaultValueInterface != nil:
		// Let the dialector bind and explain the value so it is rendered as a literal.
		defaultStmt := &gorm.Statement{Vars: []any{field.DefaultValueInterface}}
		m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
		expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
	case field.DefaultValue != "" && field.DefaultValue != "(-)":
		expr.SQL += " DEFAULT " + field.DefaultValue
	}
	return expr
}
// GetQueryAndExecTx returns the sessions used for inspecting the schema
// (queryTx) and for changing it (execTx). Normally both are the same session.
// When the DB is in dry-run mode, the query session has DryRun switched off and
// the exec session is a separate one whose logger prints every traced statement.
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
	queryTx = m.DB.Session(&gorm.Session{})
	if !m.DB.DryRun {
		return queryTx, queryTx
	}

	queryTx.DryRun = false
	execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
	return queryTx, execTx
}
// AutoMigrate brings the database schema in line with the given values.
// The values are processed in the order returned by ReorderModels. A value whose
// table does not exist gets the table created; for an existing table, missing
// columns are added, existing columns are passed to MigrateColumn, and missing
// foreign key constraints, check constraints and indexes are created.
// The first error encountered is returned and stops the migration.
func (m Migrator) AutoMigrate(values ...any) error {
for _, value := range m.ReorderModels(values, true) {
// queryTx inspects the schema, execTx applies changes (see GetQueryAndExecTx).
queryTx, execTx := m.GetQueryAndExecTx()
if !queryTx.Migrator().HasTable(value) {
// table is missing: create it
if err := execTx.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
}
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames {
// look for the database column that matches this schema column name
var foundColumn gorm.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == dbName {
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, let MigrateColumn decide whether the column must be altered
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err
}
}
}
// foreign key constraints are skipped when either flag disables them
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
// only constraints owned by this schema and not yet present are created
if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
}
// create check constraints that do not exist yet
for _, chk := range parseCheckConstraints {
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
}
}
}
// create indexes that do not exist yet
for _, idx := range parseIndexes {
if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
return nil
}); err != nil {
return err
}
}
}
return nil
}
// GetTables lists the table names that information_schema reports for the
// current database.
func (m Migrator) GetTables() (tableList []string, err error) {
	const query = "SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?"
	result := m.DB.Raw(query, m.CurrentDatabase()).Scan(&tableList)
	return tableList, result.Error
}
// CreateTable creates a table for each of the given values, in the order
// returned by ReorderModels. The CREATE TABLE statement is assembled as a SQL
// string with `?` placeholders plus a parallel list of values; it contains the
// columns, an optional PRIMARY KEY clause, indexes (unless they are created
// afterwards), foreign key constraints, unique constraints and check
// constraints. The first error is returned.
func (m Migrator) CreateTable(values ...any) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
createTableSQL = "CREATE TABLE ? ("
values = []any{m.CurrentTable(stmt)}
hasPrimaryKeyInDataType bool
)
// one "<column> <full data type>," entry per migratable field
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
// remember whether any data type already declares PRIMARY KEY itself
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
}
// add a table-level PRIMARY KEY clause only when no column declared one
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := make([]any, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable {
// deferred: runs when this closure returns, and only if the table was
// created without error; the result is stored in the named return value
defer func(value any, name string) {
if err == nil {
err = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
} else {
// inline index definition inside CREATE TABLE
if idx.Class != "" {
createTableSQL += idx.Class + " "
}
createTableSQL += "INDEX ? ?"
if idx.Comment != "" {
createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createTableSQL += " " + idx.Option
}
createTableSQL += ","
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
}
// foreign key constraints are skipped when either flag disables them
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil {
// only constraints that belong to this table's schema
if constraint.Schema == stmt.Schema {
sql, vars := constraint.Build()
createTableSQL += sql + ","
values = append(values, vars...)
}
}
}
}
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
// drop the trailing comma left by the last entry and close the definition
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
// append the "gorm:table_options" setting verbatim when it is present
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
createTableSQL += fmt.Sprint(tableOption)
}
err = tx.Exec(createTableSQL, values...).Error
return err
}); err != nil {
return err
}
}
return nil
}
// DropTable drops the tables of the given values with DROP TABLE IF EXISTS.
// The values are ordered with ReorderModels and then processed back to front;
// the first error aborts the operation.
func (m Migrator) DropTable(values ...any) error {
	ordered := m.ReorderModels(values, false)
	for i := len(ordered) - 1; i >= 0; i-- {
		tx := m.DB.Session(&gorm.Session{})
		drop := func(stmt *gorm.Statement) error {
			return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
		}
		if err := m.RunWithValue(ordered[i], drop); err != nil {
			return err
		}
	}
	return nil
}
// HasTable reports whether the table for value (a model or a table name)
// exists as a base table in the current database. Any error while resolving or
// querying is discarded, in which case the count stays zero and false is returned.
func (m Migrator) HasTable(value any) bool {
	const query = "SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?"

	var count int64
	_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		database := m.DB.Migrator().CurrentDatabase()
		row := m.DB.Raw(query, database, stmt.Table, "BASE TABLE").Row()
		return row.Scan(&count)
	})
	return count > 0
}
// RenameTable renames a table. Each of oldName and newName is either a table
// name string or a model value that is parsed to find its table; a parse error
// is returned as is.
func (m Migrator) RenameTable(oldName, newName any) error {
	// resolve turns a name-or-model into the table value used in the statement.
	resolve := func(name any) (any, error) {
		if s, ok := name.(string); ok {
			return clause.Table{Name: s}, nil
		}
		stmt := &gorm.Statement{DB: m.DB}
		if err := stmt.Parse(name); err != nil {
			return nil, err
		}
		return m.CurrentTable(stmt), nil
	}

	oldTable, err := resolve(oldName)
	if err != nil {
		return err
	}
	newTable, err := resolve(newName)
	if err != nil {
		return err
	}
	return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
}
// AddColumn adds the column for field `name` to value's table.
// It fails when the schema has no such field and does nothing for fields that
// are excluded from migration.
func (m Migrator) AddColumn(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// use a distinct local name so the looked-up field does not shadow `name`
		f := stmt.Schema.LookUpField(name)
		if f == nil {
			return fmt.Errorf("failed to look up field with name: %s", name)
		}
		if f.IgnoreMigration {
			return nil
		}
		return m.DB.Exec(
			"ALTER TABLE ? ADD ? ?",
			m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
		).Error
	})
}
// DropColumn drops column `name` from value's table. When `name` matches a
// schema field, that field's database column name is used; otherwise `name` is
// used verbatim.
func (m Migrator) DropColumn(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		column := name
		if f := stmt.Schema.LookUpField(name); f != nil {
			column = f.DBName
		}
		return m.DB.Exec(
			"ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: column},
		).Error
	})
}
// AlterColumn changes the type of value's `field` column to the full data type
// derived from the schema. It fails when the schema has no such field.
func (m Migrator) AlterColumn(value any, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		f := stmt.Schema.LookUpField(field)
		if f == nil {
			return fmt.Errorf("failed to look up field with name: %s", field)
		}
		fullType := m.FullDataTypeOf(f)
		return m.DB.Exec(
			"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
			m.CurrentTable(stmt), clause.Column{Name: f.DBName}, fullType,
		).Error
	})
}
// HasColumn reports whether value's table has the column for `field`, which
// may be a schema field name or a raw column name. Errors are discarded, in
// which case the count stays zero and false is returned.
func (m Migrator) HasColumn(value any, field string) bool {
	const query = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?"

	var count int64
	_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		database := m.DB.Migrator().CurrentDatabase()
		column := field
		if f := stmt.Schema.LookUpField(field); f != nil {
			column = f.DBName
		}
		return m.DB.Raw(query, database, stmt.Table, column).Row().Scan(&count)
	})
	return count > 0
}
// RenameColumn renames a column of value's table. Both names may be schema
// field names, in which case their database column names are used; names that
// match no field are used verbatim.
func (m Migrator) RenameColumn(value any, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// columnFor maps a field name to its column name, leaving unknown names alone.
		columnFor := func(n string) string {
			if f := stmt.Schema.LookUpField(n); f != nil {
				return f.DBName
			}
			return n
		}
		from := columnFor(oldName)
		to := columnFor(newName)
		return m.DB.Exec(
			"ALTER TABLE ? RENAME COLUMN ? TO ?",
			m.CurrentTable(stmt), clause.Column{Name: from}, clause.Column{Name: to},
		).Error
	})
}
// MigrateColumn compares the schema definition of field with the existing
// database column described by columnType and alters the column when they
// differ. It looks at the data type (including dialect type aliases), size,
// precision, nullability, default value and comment. Fields excluded from
// migration are skipped. Afterwards the column's unique constraint is
// reconciled through MigrateColumnUnique.
func (m Migrator) MigrateColumn(value any, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// both type strings are lower-cased so the comparisons below ignore case
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
)
if !field.PrimaryKey {
// check type: the schema type must start with the database type name
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases reported by the dialect for the database type
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
if !isSameType {
alterColumn = true
}
}
}
if !isSameType {
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// regFullDataType is compiled once at package level because this runs once per column
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
}
}
// check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
// NOTE(review): this regexp is compiled on every call that reaches this branch.
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
}
}
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
// check default value
if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// default value -> null
alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
// compare ignoring case and a trailing "()"
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
// NOTE(review): this assignment (and the one in the default case) is
// unconditional, so it can reset an alterColumn that an earlier check set
// to true. Confirm whether that is intended.
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
}
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
}
return nil
}
// MigrateColumnUnique reconciles the column's UNIQUE constraint with the
// field definition: the constraint is dropped when the database column is
// unique but the field is not, and created in the opposite case. Nothing
// happens for primary keys or when the column type cannot tell whether the
// column is unique.
func (m Migrator) MigrateColumnUnique(value any, field *schema.Field, columnType gorm.ColumnType) error {
	unique, known := columnType.Unique()
	if field.PrimaryKey || !known {
		return nil // skip primary key
	}

	// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// Only boolean values are received on the `Unique` tag,
		// so the unique constraint name is fixed by the naming strategy.
		constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
		switch {
		case unique && !field.Unique:
			return m.DB.Migrator().DropConstraint(value, constraint)
		case !unique && field.Unique:
			return m.DB.Migrator().CreateConstraint(value, constraint)
		default:
			return nil
		}
	})
}
// ColumnTypes returns the column types of value's table, obtained by running
// a LIMIT 1 query against it and wrapping each driver column type in a
// ColumnType. The returned error is the first failure among resolving the
// value, running the query, reading the column types and closing the rows.
func (m Migrator) ColumnTypes(value any) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
		rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
		if err != nil {
			return err
		}

		defer func() {
			// Always close the rows, but never let a successful Close hide an
			// error that is already being returned.
			if closeErr := rows.Close(); err == nil {
				err = closeErr
			}
		}()

		var rawColumnTypes []*sql.ColumnType
		rawColumnTypes, err = rows.ColumnTypes()
		if err != nil {
			return err
		}

		for _, c := range rawColumnTypes {
			columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
		}

		return nil
	})

	return columnTypes, execErr
}
// CreateView creates a view named name from the Query in gorm.ViewOption.
// Query in gorm.ViewOption is a [subquery]; gorm.ErrSubQueryRequired is
// returned when it is nil.
//
//	// CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
//	q := DB.Model(&User{}).Where("age > ?", 20)
//	DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
//
//	// CREATE OR REPLACE VIEW `user_view` AS SELECT * FROM `users` WITH CHECK OPTION
//	q := DB.Model(&User{})
//	DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
//
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
	if option.Query == nil {
		return gorm.ErrSubQueryRequired
	}

	var ddl strings.Builder
	ddl.WriteString("CREATE ")
	if option.Replace {
		ddl.WriteString("OR REPLACE ")
	}
	ddl.WriteString("VIEW ")
	m.QuoteTo(&ddl, name)
	ddl.WriteString(" AS ")
	// The subquery is written through the current statement, which also records its variables.
	m.DB.Statement.AddVar(&ddl, option.Query)

	if check := option.CheckOption; check != "" {
		ddl.WriteString(" ")
		ddl.WriteString(check)
	}
	return m.DB.Exec(m.Explain(ddl.String(), m.DB.Statement.Vars...)).Error
}
// DropView drops the view called name with DROP VIEW IF EXISTS.
func (m Migrator) DropView(name string) error {
	view := clause.Table{Name: name}
	return m.DB.Exec("DROP VIEW IF EXISTS ?", view).Error
}
// GuessConstraintAndTable guesses the constraint called name and the table it
// belongs to. At most one of the two constraint results is non-nil, depending
// on the concrete type found by GuessConstraintInterfaceAndTable.
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
	found, table := m.GuessConstraintInterfaceAndTable(stmt, name)
	if fk, ok := found.(*schema.Constraint); ok {
		return fk, nil, table
	}
	if chk, ok := found.(*schema.CheckConstraint); ok {
		return nil, chk, table
	}
	return nil, nil, table
}
// GuessConstraintInterfaceAndTable guesses the constraint identified by name
// and the table it belongs to. name is tried first as a constraint name (check,
// unique, then relationship constraints) and then as a field name whose
// constraints are searched in the same order. When nothing matches, a nil
// constraint is returned together with the schema's table (or the statement's
// table if there is no schema).
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
	if stmt.Schema == nil {
		return nil, stmt.Table
	}

	// Direct hit on a constraint name.
	checks := stmt.Schema.ParseCheckConstraints()
	if chk, found := checks[name]; found {
		return &chk, stmt.Table
	}
	uniques := stmt.Schema.ParseUniqueConstraints()
	if uni, found := uniques[name]; found {
		return &uni, stmt.Table
	}

	// tableOf picks the table that carries a relationship's constraint.
	tableOf := func(rel *schema.Relationship) string {
		if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
			return rel.FieldSchema.Table
		}
		if rel.Type == schema.Many2Many {
			return rel.JoinTable.Table
		}
		return stmt.Table
	}

	relations := stmt.Schema.Relationships.Relations
	for _, rel := range relations {
		if c := rel.ParseConstraint(); c != nil && c.Name == name {
			return c, tableOf(rel)
		}
	}

	// Fall back to treating name as a field name.
	field := stmt.Schema.LookUpField(name)
	if field == nil {
		return nil, stmt.Schema.Table
	}
	for _, chk := range checks {
		if chk.Field == field {
			match := chk
			return &match, stmt.Table
		}
	}
	for _, uni := range uniques {
		if uni.Field == field {
			match := uni
			return &match, stmt.Table
		}
	}
	for _, rel := range relations {
		if c := rel.ParseConstraint(); c != nil && rel.Field == field {
			return c, tableOf(rel)
		}
	}

	return nil, stmt.Schema.Table
}
// CreateConstraint adds the constraint identified by name to its table.
// When no matching constraint can be found, nothing is executed and nil is
// returned. The statement's table expression, if set, replaces the guessed table.
func (m Migrator) CreateConstraint(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint == nil {
			return nil
		}

		var target any = clause.Table{Name: table}
		if stmt.TableExpr != nil {
			target = stmt.TableExpr
		}

		addSQL, values := constraint.Build()
		args := append([]any{target}, values...)
		return m.DB.Exec("ALTER TABLE ? ADD "+addSQL, args...).Error
	})
}
// DropConstraint drops the constraint identified by name. When a matching
// constraint is found in the schema its own name is used; otherwise name is
// used verbatim.
func (m Migrator) DropConstraint(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraintName := name
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			constraintName = constraint.GetName()
		}
		return m.DB.Exec(
			"ALTER TABLE ? DROP CONSTRAINT ?",
			clause.Table{Name: table}, clause.Column{Name: constraintName},
		).Error
	})
}
// HasConstraint reports whether the constraint identified by name exists
// according to INFORMATION_SCHEMA.table_constraints. Errors are discarded, in
// which case the count stays zero and false is returned.
func (m Migrator) HasConstraint(value any, name string) bool {
	const query = "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?"

	var count int64
	_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		database := m.DB.Migrator().CurrentDatabase()
		constraintName := name
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			constraintName = constraint.GetName()
		}
		return m.DB.Raw(query, database, table, constraintName).Row().Scan(&count)
	})
	return count > 0
}
// BuildIndexOptions renders each index option as a SQL expression: the
// option's expression if it has one, otherwise the quoted column name with an
// optional "(length)" suffix, followed by optional COLLATE and sort parts.
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []any) {
	for _, opt := range opts {
		part := stmt.Quote(opt.DBName)
		switch {
		case opt.Expression != "":
			part = opt.Expression
		case opt.Length > 0:
			part += fmt.Sprintf("(%d)", opt.Length)
		}

		if collate := opt.Collate; collate != "" {
			part += " COLLATE " + collate
		}
		if sort := opt.Sort; sort != "" {
			part += " " + sort
		}

		results = append(results, clause.Expr{SQL: part})
	}
	return results
}
// BuildIndexOptionsInterface is the interface a migrator is type-asserted to
// (in CreateTable and CreateIndex) in order to render index options.
type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []any
}
// CreateIndex creates the index `name` (an index name or a field name the
// schema can resolve to an index) on value's table. It fails when the schema
// has no such index.
func (m Migrator) CreateIndex(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		idx := stmt.Schema.LookIndex(name)
		if idx == nil {
			return fmt.Errorf("failed to create index with name %s", name)
		}

		opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
		args := []any{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}

		var ddl strings.Builder
		ddl.WriteString("CREATE ")
		if idx.Class != "" {
			ddl.WriteString(idx.Class + " ")
		}
		ddl.WriteString("INDEX ? ON ??")
		if idx.Type != "" {
			ddl.WriteString(" USING " + idx.Type)
		}
		if idx.Comment != "" {
			ddl.WriteString(fmt.Sprintf(" COMMENT '%s'", idx.Comment))
		}
		if idx.Option != "" {
			ddl.WriteString(" " + idx.Option)
		}

		return m.DB.Exec(ddl.String(), args...).Error
	})
}
// DropIndex drops index `name` from value's table. When the schema resolves
// `name` to an index, that index's name is used; otherwise `name` is used verbatim.
func (m Migrator) DropIndex(value any, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		indexName := name
		if idx := stmt.Schema.LookIndex(name); idx != nil {
			indexName = idx.Name
		}
		return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: indexName}, m.CurrentTable(stmt)).Error
	})
}
// HasIndex reports whether index `name` exists on value's table according to
// information_schema.statistics. Errors are discarded, in which case the count
// stays zero and false is returned.
func (m Migrator) HasIndex(value any, name string) bool {
	const query = "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?"

	var count int64
	_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		database := m.DB.Migrator().CurrentDatabase()
		indexName := name
		if idx := stmt.Schema.LookIndex(name); idx != nil {
			indexName = idx.Name
		}
		return m.DB.Raw(query, database, stmt.Table, indexName).Row().Scan(&count)
	})
	return count > 0
}
// RenameIndex renames an index of value's table from oldName to newName.
// Both names are used verbatim.
func (m Migrator) RenameIndex(value any, oldName, newName string) error {
	rename := func(stmt *gorm.Statement) error {
		from := clause.Column{Name: oldName}
		to := clause.Column{Name: newName}
		return m.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", m.CurrentTable(stmt), from, to).Error
	}
	return m.RunWithValue(value, rename)
}
// CurrentDatabase returns the name reported by SELECT DATABASE(). A scan
// error is discarded, in which case name keeps whatever was scanned so far
// (the empty string if nothing was).
func (m Migrator) CurrentDatabase() (name string) {
	row := m.DB.Raw("SELECT DATABASE()").Row()
	_ = row.Scan(&name)
	return name
}
// ReorderModels reorders the given models according to their constraint
// dependencies. String values (table names) are put first in their original
// order, followed by the parsed models. When autoAdd is true, models that the
// given ones depend on are placed before their dependents and are added to the
// result even if they were not passed in.
func (m Migrator) ReorderModels(values []any, autoAdd bool) (results []any) {
// Dependency pairs a parsed statement with the schemas it depends on.
type Dependency struct {
*gorm.Statement
Depends []*schema.Schema
}
var (
// modelNames keeps table names in parse order; orderedModelNames is the final order
modelNames, orderedModelNames []string
// orderedModelNamesMap marks tables already visited by insertIntoOrderedList
orderedModelNamesMap = map[string]bool{}
// parsedSchemas marks schemas already handled by parseDependence
parsedSchemas = map[*schema.Schema]bool{}
// valuesMap maps a table name to its parsed dependency information
valuesMap = map[string]Dependency{}
insertIntoOrderedList func(name string)
parseDependence func(value any, addToList bool)
)
// parseDependence parses value, records the schemas it depends on and, when
// addToList is set, appends its table to modelNames.
parseDependence = func(value any, addToList bool) {
dep := Dependency{
Statement: &gorm.Statement{DB: m.DB, Dest: value},
}
beDependedOn := map[*schema.Schema]bool{}
// support for special table name
// NOTE(review): a parse failure is only logged; if dep.Schema is left nil in
// that case, the dereferences of dep.Schema below would panic. Confirm.
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
// each schema is processed only once
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
return
}
parsedSchemas[dep.Statement.Schema] = true
if !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range dep.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
// a constraint owned by this schema that references another schema is a dependency
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true
}
if rel.JoinTable != nil {
// append join value
// deferred until parseDependence returns, i.e. after this value has been
// stored in valuesMap; the join table model itself is parsed as well
defer func(rel *schema.Relationship, joinValue any) {
if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema)
} else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd)
}
parseDependence(joinValue, autoAdd)
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
}
}
valuesMap[dep.Schema.Table] = dep
if addToList {
modelNames = append(modelNames, dep.Schema.Table)
}
}
// insertIntoOrderedList appends name to the ordered list; with autoAdd its
// dependencies are inserted first, parsing them on demand when unknown.
insertIntoOrderedList = func(name string) {
if _, ok := orderedModelNamesMap[name]; ok {
return // avoid loop
}
orderedModelNamesMap[name] = true
if autoAdd {
dep := valuesMap[name]
for _, d := range dep.Depends {
if _, ok := valuesMap[d.Table]; ok {
insertIntoOrderedList(d.Table)
} else {
parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
insertIntoOrderedList(d.Table)
}
}
}
orderedModelNames = append(orderedModelNames, name)
}
// table-name strings go straight into the result; models are parsed
for _, value := range values {
if v, ok := value.(string); ok {
results = append(results, v)
} else {
parseDependence(value, true)
}
}
for _, name := range modelNames {
insertIntoOrderedList(name)
}
// translate the ordered table names back into the original model values
for _, name := range orderedModelNames {
results = append(results, valuesMap[name].Statement.Dest)
}
return
}
// CurrentTable returns the table to use in statements for stmt: the
// statement's table expression when one is set, otherwise a clause.Table built
// from the statement's table name.
func (m Migrator) CurrentTable(stmt *gorm.Statement) any {
	if expr := stmt.TableExpr; expr != nil {
		return *expr
	}
	return clause.Table{Name: stmt.Table}
}
// GetIndexes is not implemented by the generic migrator: it always returns a
// nil slice together with a "not support" error.
func (m Migrator) GetIndexes(dst any) ([]gorm.Index, error) {
	var indexes []gorm.Index
	return indexes, errors.New("not support")
}
// GetTypeAliases returns the aliases of the given database type name. The
// generic migrator knows none and always returns nil.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	var aliases []string
	return aliases
}
// TableType is not implemented by the generic migrator: it always returns a
// nil gorm.TableType together with a "not support" error.
func (m Migrator) TableType(dst any) (gorm.TableType, error) {
	var tableType gorm.TableType
	return tableType, errors.New("not support")
}
================================================
FILE: gossh/gorm/migrator/table_type.go
================================================
package migrator
import (
"database/sql"
)
// TableType is a plain value holder describing one database table; it
// implements the TableType interface.
type TableType struct {
	SchemaValue  string
	NameValue    string
	TypeValue    string
	CommentValue sql.NullString
}

// Schema reports the schema the table lives in.
func (tt TableType) Schema() string { return tt.SchemaValue }

// Name reports the table name.
func (tt TableType) Name() string { return tt.NameValue }

// Type reports the table type.
func (tt TableType) Type() string { return tt.TypeValue }

// Comment reports the table comment; ok is false when no comment was provided.
func (tt TableType) Comment() (comment string, ok bool) {
	c := tt.CommentValue
	return c.String, c.Valid
}
================================================
FILE: gossh/gorm/migrator.go
================================================
package gorm
import (
"reflect"
"gossh/gorm/clause"
"gossh/gorm/schema"
)
// Migrator returns the dialect's migrator bound to a fresh session of db.
// Pending scopes on the statement are executed first so that they apply to the
// migrator as well.
func (db *DB) Migrator() Migrator {
	scoped := db.getInstance()

	// apply scopes to migrator
	for {
		if len(scoped.Statement.scopes) == 0 {
			break
		}
		scoped = scoped.executeScopes()
	}

	return scoped.Dialector.Migrator(scoped.Session(&Session{}))
}
// AutoMigrate runs the auto migration of db's migrator for the given models.
func (db *DB) AutoMigrate(dst ...any) error {
	migrator := db.Migrator()
	return migrator.AutoMigrate(dst...)
}
// ViewOption holds the options used when creating a view.
type ViewOption struct {
Replace bool // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE`
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB // required subquery.
}
// ColumnType describes a database column. Methods with a second `ok` result
// report through it whether the information is available.
type ColumnType interface {
Name() string
DatabaseTypeName() string // type name without length, e.g. varchar
ColumnType() (columnType string, ok bool) // full type, e.g. varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
// Index describes a database index as returned by Migrator.GetIndexes.
type Index interface {
	Table() string
	Name() string
	Columns() []string
	PrimaryKey() (isPrimaryKey bool, ok bool)
	Unique() (unique bool, ok bool)
	Option() string
}
// TableType describes a database table as returned by Migrator.TableType.
type TableType interface {
	Schema() string
	Name() string
	Type() string
	Comment() (comment string, ok bool)
}
// Migrator is the schema-migration interface each dialect implements.
// It is obtained through (*DB).Migrator and groups operations on the
// database, tables, columns, views, constraints and indexes.
type Migrator interface {
	// AutoMigrate
	AutoMigrate(dst ...any) error
	// Database
	CurrentDatabase() string
	FullDataTypeOf(*schema.Field) clause.Expr
	GetTypeAliases(databaseTypeName string) []string
	// Tables
	CreateTable(dst ...any) error
	DropTable(dst ...any) error
	HasTable(dst any) bool
	RenameTable(oldName, newName any) error
	GetTables() (tableList []string, err error)
	TableType(dst any) (TableType, error)
	// Columns
	AddColumn(dst any, field string) error
	DropColumn(dst any, field string) error
	AlterColumn(dst any, field string) error
	MigrateColumn(dst any, field *schema.Field, columnType ColumnType) error
	// MigrateColumnUnique migrates a column's UNIQUE constraint; it is part of MigrateColumn.
	MigrateColumnUnique(dst any, field *schema.Field, columnType ColumnType) error
	HasColumn(dst any, field string) bool
	RenameColumn(dst any, oldName, field string) error
	ColumnTypes(dst any) ([]ColumnType, error)
	// Views
	CreateView(name string, option ViewOption) error
	DropView(name string) error
	// Constraints
	CreateConstraint(dst any, name string) error
	DropConstraint(dst any, name string) error
	HasConstraint(dst any, name string) bool
	// Indexes
	CreateIndex(dst any, name string) error
	DropIndex(dst any, name string) error
	HasIndex(dst any, name string) bool
	RenameIndex(dst any, oldName, newName string) error
	GetIndexes(dst any) ([]Index, error)
}
================================================
FILE: gossh/gorm/model.go
================================================
package gorm
import "time"
// Model is a basic Go struct that includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt.
// It may be embedded into your model, or you may build your own model without it:
//
// type User struct {
// gorm.Model
// }
type Model struct {
	ID uint `gorm:"primarykey"`
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt DeletedAt `gorm:"index"`
}
================================================
FILE: gossh/gorm/now/main.go
================================================
// Package now is a time toolkit for golang.
//
// More details in the upstream README: https://github.com/jinzhu/now
//
// import "gossh/gorm/now"
//
// now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon
// now.BeginningOfDay() // 2013-11-18 00:00:00 Mon
// now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
package now
import "time"
// WeekStartDay is the day a week starts on; the default is Sunday.
// It is used by With when DefaultConfig is nil.
var WeekStartDay = time.Sunday
// TimeFormats lists the layouts tried, in order, when parsing a string as time.
// It is used by With when DefaultConfig is nil.
var TimeFormats = []string{
	"2006", "2006-1", "2006-1-2", "2006-1-2 15", "2006-1-2 15:4", "2006-1-2 15:4:5", "1-2",
	"15:4:5", "15:4", "15",
	"15:4:5 Jan 2, 2006 MST", "2006-01-02 15:04:05.999999999 -0700 MST", "2006-01-02T15:04:05Z0700", "2006-01-02T15:04:05Z07",
	"2006.1.2", "2006.1.2 15:04:05", "2006.01.02", "2006.01.02 15:04:05", "2006.01.02 15:04:05.999999999",
	"1/2/2006", "1/2/2006 15:4:5", "2006/01/02", "20060102", "2006/01/02 15:04:05",
	time.ANSIC, time.UnixDate, time.RubyDate, time.RFC822, time.RFC822Z, time.RFC850,
	time.RFC1123, time.RFC1123Z, time.RFC3339, time.RFC3339Nano,
	time.Kitchen, time.Stamp, time.StampMilli, time.StampMicro, time.StampNano,
}
// Config is the configuration for the now package.
type Config struct {
	// WeekStartDay is the first day of the week used by week calculations.
	WeekStartDay time.Weekday
	// TimeLocation, when non-nil, is the location Config.Parse and
	// Config.MustParse convert the current time into before parsing.
	TimeLocation *time.Location
	// TimeFormats lists the layouts tried, in order, when parsing strings.
	TimeFormats []string
}
// DefaultConfig is the configuration used by the package-level With when it
// is non-nil; when nil, With builds a Config from WeekStartDay and TimeFormats.
var DefaultConfig *Config
// With wraps t in a Now that uses this configuration.
func (config *Config) With(t time.Time) *Now {
	wrapped := Now{Time: t, Config: config}
	return &wrapped
}
// Parse parses strs relative to the current time, converted to
// config.TimeLocation when one is set.
func (config *Config) Parse(strs ...string) (time.Time, error) {
	current := time.Now()
	if config.TimeLocation != nil {
		current = current.In(config.TimeLocation)
	}
	return config.With(current).Parse(strs...)
}
// MustParse is like Parse but panics when the strings cannot be parsed.
func (config *Config) MustParse(strs ...string) time.Time {
	current := time.Now()
	if config.TimeLocation != nil {
		current = current.In(config.TimeLocation)
	}
	return config.With(current).MustParse(strs...)
}
// Now pairs a time.Time with the Config used to interpret it; both are
// embedded, so their fields and methods are promoted onto Now.
type Now struct {
	time.Time
	*Config
}
// With wraps t in a Now. It uses DefaultConfig when that is set, and
// otherwise a Config built from the package-level WeekStartDay and TimeFormats.
func With(t time.Time) *Now {
	if cfg := DefaultConfig; cfg != nil {
		return &Now{Time: t, Config: cfg}
	}
	fallback := &Config{
		WeekStartDay: WeekStartDay,
		TimeFormats:  TimeFormats,
	}
	return &Now{Time: t, Config: fallback}
}
// New wraps t in a Now; it is an alias for With.
func New(t time.Time) *Now {
	wrapped := With(t)
	return wrapped
}
// BeginningOfMinute returns the beginning of the current minute.
func BeginningOfMinute() time.Time {
	current := With(time.Now())
	return current.BeginningOfMinute()
}
// BeginningOfHour returns the beginning of the current hour.
func BeginningOfHour() time.Time {
	current := With(time.Now())
	return current.BeginningOfHour()
}
// BeginningOfDay returns the beginning of the current day.
func BeginningOfDay() time.Time {
	current := With(time.Now())
	return current.BeginningOfDay()
}
// BeginningOfWeek returns the beginning of the current week.
func BeginningOfWeek() time.Time {
	current := With(time.Now())
	return current.BeginningOfWeek()
}
// BeginningOfMonth returns the beginning of the current month.
func BeginningOfMonth() time.Time {
	current := With(time.Now())
	return current.BeginningOfMonth()
}
// BeginningOfQuarter returns the beginning of the current quarter.
func BeginningOfQuarter() time.Time {
	current := With(time.Now())
	return current.BeginningOfQuarter()
}
// BeginningOfYear returns the beginning of the current year.
func BeginningOfYear() time.Time {
	current := With(time.Now())
	return current.BeginningOfYear()
}
// EndOfMinute returns the end of the current minute.
func EndOfMinute() time.Time {
	current := With(time.Now())
	return current.EndOfMinute()
}
// EndOfHour returns the end of the current hour.
func EndOfHour() time.Time {
	current := With(time.Now())
	return current.EndOfHour()
}
// EndOfDay returns the end of the current day.
func EndOfDay() time.Time {
	current := With(time.Now())
	return current.EndOfDay()
}
// EndOfWeek returns the end of the current week.
func EndOfWeek() time.Time {
	current := With(time.Now())
	return current.EndOfWeek()
}
// EndOfMonth returns the end of the current month.
func EndOfMonth() time.Time {
	current := With(time.Now())
	return current.EndOfMonth()
}
// EndOfQuarter returns the end of the current quarter.
func EndOfQuarter() time.Time {
	current := With(time.Now())
	return current.EndOfQuarter()
}
// EndOfYear returns the end of the current year.
func EndOfYear() time.Time {
	current := With(time.Now())
	return current.EndOfYear()
}
// Monday returns the Monday of the week containing the current time, or
// containing the time parsed from strs when given. It panics if strs
// cannot be parsed.
func Monday(strs ...string) time.Time {
	current := With(time.Now())
	return current.Monday(strs...)
}
// Sunday returns the Sunday that ends the Monday-based week containing the
// current time, or containing the time parsed from strs when given. It
// panics if strs cannot be parsed.
func Sunday(strs ...string) time.Time {
	current := With(time.Now())
	return current.Sunday(strs...)
}
// EndOfSunday returns the end of the day of the current week's Sunday.
func EndOfSunday() time.Time {
	current := With(time.Now())
	return current.EndOfSunday()
}
// Quarter returns the quarter of the year (1-4) the current time falls in.
func Quarter() uint {
	current := With(time.Now())
	return current.Quarter()
}
// Parse parses strs into a time, relative to the current time.
func Parse(strs ...string) (time.Time, error) {
	current := With(time.Now())
	return current.Parse(strs...)
}
// ParseInLocation parses strs into a time, relative to the current time
// expressed in loc.
func ParseInLocation(loc *time.Location, strs ...string) (time.Time, error) {
	local := time.Now().In(loc)
	return With(local).Parse(strs...)
}
// MustParse is like Parse but panics when strs cannot be parsed.
func MustParse(strs ...string) time.Time {
	current := With(time.Now())
	return current.MustParse(strs...)
}
// MustParseInLocation is like ParseInLocation but panics when strs cannot
// be parsed.
func MustParseInLocation(loc *time.Location, strs ...string) time.Time {
	local := time.Now().In(loc)
	return With(local).MustParse(strs...)
}
// Between reports whether the current time lies strictly between the times
// parsed from time1 and time2. It panics if either string cannot be parsed.
func Between(time1, time2 string) bool {
	current := With(time.Now())
	return current.Between(time1, time2)
}
================================================
FILE: gossh/gorm/now/now.go
================================================
package now
import (
"errors"
"regexp"
"time"
)
// BeginningOfMinute returns the time truncated to the minute.
func (now *Now) BeginningOfMinute() time.Time {
	truncated := now.Time.Truncate(time.Minute)
	return truncated
}
// BeginningOfHour returns the start of the hour (minute, second and
// nanosecond zeroed) in the time's own location.
func (now *Now) BeginningOfHour() time.Time {
	current := now.Time
	year, month, day := current.Date()
	return time.Date(year, month, day, current.Hour(), 0, 0, 0, current.Location())
}
// BeginningOfDay returns midnight of the same calendar day in the time's
// own location.
func (now *Now) BeginningOfDay() time.Time {
	current := now.Time
	year, month, day := current.Date()
	return time.Date(year, month, day, 0, 0, 0, 0, current.Location())
}
// BeginningOfWeek returns midnight of the first day of the week, where
// the first day is the configured WeekStartDay.
func (now *Now) BeginningOfWeek() time.Time {
	dayStart := now.BeginningOfDay()

	// Number of days to step back to reach the configured week start.
	daysBack := int(dayStart.Weekday())
	if now.WeekStartDay != time.Sunday {
		daysBack -= int(now.WeekStartDay)
		if daysBack < 0 {
			daysBack += 7
		}
	}
	return dayStart.AddDate(0, 0, -daysBack)
}
// BeginningOfMonth returns midnight of the first day of the month.
func (now *Now) BeginningOfMonth() time.Time {
	year, month, _ := now.Time.Date()
	loc := now.Time.Location()
	return time.Date(year, month, 1, 0, 0, 0, 0, loc)
}
// BeginningOfQuarter returns midnight of the first day of the quarter.
func (now *Now) BeginningOfQuarter() time.Time {
	monthStart := now.BeginningOfMonth()
	// Months elapsed since the quarter began (0, 1 or 2).
	monthsIntoQuarter := (int(monthStart.Month()) - 1) % 3
	return monthStart.AddDate(0, -monthsIntoQuarter, 0)
}
// BeginningOfHalf returns midnight of the first day of the half year.
func (now *Now) BeginningOfHalf() time.Time {
	monthStart := now.BeginningOfMonth()
	// Months elapsed since the half year began (0 through 5).
	monthsIntoHalf := (int(monthStart.Month()) - 1) % 6
	return monthStart.AddDate(0, -monthsIntoHalf, 0)
}
// BeginningOfYear returns midnight of January 1st of the year.
func (now *Now) BeginningOfYear() time.Time {
	year, _, _ := now.Time.Date()
	loc := now.Time.Location()
	return time.Date(year, time.January, 1, 0, 0, 0, 0, loc)
}
// EndOfMinute returns the last nanosecond of the minute.
func (now *Now) EndOfMinute() time.Time {
	minuteStart := now.BeginningOfMinute()
	return minuteStart.Add(time.Minute - time.Nanosecond)
}
// EndOfHour returns the last nanosecond of the hour.
func (now *Now) EndOfHour() time.Time {
	hourStart := now.BeginningOfHour()
	return hourStart.Add(time.Hour - time.Nanosecond)
}
// EndOfDay returns 23:59:59.999999999 of the same calendar day.
func (now *Now) EndOfDay() time.Time {
	year, month, day := now.Time.Date()
	lastNanosecond := int(time.Second - time.Nanosecond)
	return time.Date(year, month, day, 23, 59, 59, lastNanosecond, now.Time.Location())
}
// EndOfWeek returns the last nanosecond of the week.
func (now *Now) EndOfWeek() time.Time {
	nextWeekStart := now.BeginningOfWeek().AddDate(0, 0, 7)
	return nextWeekStart.Add(-time.Nanosecond)
}
// EndOfMonth returns the last nanosecond of the month.
func (now *Now) EndOfMonth() time.Time {
	nextMonthStart := now.BeginningOfMonth().AddDate(0, 1, 0)
	return nextMonthStart.Add(-time.Nanosecond)
}
// EndOfQuarter returns the last nanosecond of the quarter.
func (now *Now) EndOfQuarter() time.Time {
	nextQuarterStart := now.BeginningOfQuarter().AddDate(0, 3, 0)
	return nextQuarterStart.Add(-time.Nanosecond)
}
// EndOfHalf returns the last nanosecond of the half year.
func (now *Now) EndOfHalf() time.Time {
	nextHalfStart := now.BeginningOfHalf().AddDate(0, 6, 0)
	return nextHalfStart.Add(-time.Nanosecond)
}
// EndOfYear returns the last nanosecond of the year.
func (now *Now) EndOfYear() time.Time {
	nextYearStart := now.BeginningOfYear().AddDate(1, 0, 0)
	return nextYearStart.Add(-time.Nanosecond)
}
// Monday monday
/*
func (now *Now) Monday() time.Time {
t := now.BeginningOfDay()
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7
}
return t.AddDate(0, 0, -weekday+1)
}
*/
// Monday returns the Monday of the week containing the base time, treating
// Sunday as the last day of the week. The base time is the beginning of the
// receiver's day, or the time parsed from strs when any are given (parsed
// as-is, not truncated to the day). It panics if strs cannot be parsed.
func (now *Now) Monday(strs ...string) time.Time {
	var base time.Time
	if len(strs) == 0 {
		base = now.BeginningOfDay()
	} else {
		parsed, err := now.Parse(strs...)
		if err != nil {
			panic(err)
		}
		base = parsed
	}

	// Monday -> 0, Tuesday -> 1, ..., Sunday -> 6.
	daysSinceMonday := (int(base.Weekday()) + 6) % 7
	return base.AddDate(0, 0, -daysSinceMonday)
}
// Sunday returns the Sunday that ends the Monday-based week containing the
// base time. The base time is the beginning of the receiver's day, or the
// time parsed from strs when any are given (parsed as-is, not truncated to
// the day). It panics if strs cannot be parsed.
func (now *Now) Sunday(strs ...string) time.Time {
	var base time.Time
	if len(strs) == 0 {
		base = now.BeginningOfDay()
	} else {
		parsed, err := now.Parse(strs...)
		if err != nil {
			panic(err)
		}
		base = parsed
	}

	// Sunday -> 0, Monday -> 6, ..., Saturday -> 1.
	daysUntilSunday := (7 - int(base.Weekday())) % 7
	return base.AddDate(0, 0, daysUntilSunday)
}
// EndOfSunday returns the end of the day of the Sunday computed by Sunday().
func (now *Now) EndOfSunday() time.Time {
	sunday := New(now.Sunday())
	return sunday.EndOfDay()
}
// Quarter returns the quarter of the year (1-4) the time falls in.
func (now *Now) Quarter() uint {
	zeroBasedMonth := uint(now.Time.Month()) - 1
	return zeroBasedMonth/3 + 1
}
// parseWithFormat tries each configured layout in order and returns the
// first successful parse of str in location. When no layout matches it
// returns an error naming the string.
func (now *Now) parseWithFormat(str string, location *time.Location) (t time.Time, err error) {
	for _, layout := range now.TimeFormats {
		if t, err = time.ParseInLocation(layout, str, location); err == nil {
			return t, nil
		}
	}
	return t, errors.New("Can't parse string as time: " + str)
}
// hasTimeRegexp reports whether a string contains a time-of-day component.
var hasTimeRegexp = regexp.MustCompile(`(\s+|^\s*|T)\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))(\s*$|[Z+-])`) // match 15:04:05, 15:04:05.000, 15:04:05.000000 15, 2017-01-01 15:04, 2021-07-20T00:59:10Z, 2021-07-20T00:59:10+08:00, 2021-07-20T00:00:10-07:00 etc
// onlyTimeRegexp reports whether a string consists solely of a time-of-day component.
var onlyTimeRegexp = regexp.MustCompile(`^\s*\d{1,2}((:\d{1,2})*|((:\d{1,2}){2}\.(\d{3}|\d{6}|\d{9})))\s*$`) // match 15:04:05, 15, 15:04:05.000, 15:04:05.000000, etc
// Parse parses one or more strings into a time, filling components that a
// string leaves out from the receiver's time (and, for later strings, from
// the result of the previous string).
//
// Strings that match none of the configured TimeFormats are skipped. The
// returned err is the outcome of the last string only, so an earlier failure
// is not reported when a later string parses. With no strings at all, the
// zero time and a nil error are returned.
func (now *Now) Parse(strs ...string) (t time.Time, err error) {
	var (
		setCurrentTime bool
		parseTime []int
		currentLocation = now.Location()
		onlyTimeInStr = true
		// component list: [nanosecond, second, minute, hour, day, month, year]
		currentTime = formatTimeToList(now.Time)
	)
	for _, str := range strs {
		hasTimeInStr := hasTimeRegexp.MatchString(str) // match 15:04:05, 15
		// stays true only while every string seen so far is time-only
		onlyTimeInStr = hasTimeInStr && onlyTimeInStr && onlyTimeRegexp.MatchString(str)
		if t, err = now.parseWithFormat(str, currentLocation); err == nil {
			location := t.Location()
			parseTime = formatTimeToList(t)
			for i, v := range parseTime {
				// Don't reset hour, minute, second if current time str including time
				// (indexes 0-3 are nanosecond, second, minute, hour)
				if hasTimeInStr && i <= 3 {
					continue
				}
				// If value is zero, replace it with current time
				// (only once a non-zero smaller component has been seen)
				if v == 0 {
					if setCurrentTime {
						parseTime[i] = currentTime[i]
					}
				} else {
					setCurrentTime = true
				}
				// if current time only includes time, should change day, month to current time
				// (indexes 4 and 5 are day and month)
				if onlyTimeInStr {
					if i == 4 || i == 5 {
						parseTime[i] = currentTime[i]
						continue
					}
				}
			}
			t = time.Date(parseTime[6], time.Month(parseTime[5]), parseTime[4], parseTime[3], parseTime[2], parseTime[1], parseTime[0], location)
			// the result becomes the reference for the next string
			currentTime = formatTimeToList(t)
		}
	}
	return
}
// MustParse is like Parse but panics with the parse error on failure.
func (now *Now) MustParse(strs ...string) (t time.Time) {
	parsed, err := now.Parse(strs...)
	if err != nil {
		panic(err)
	}
	return parsed
}
// Between reports whether the time lies strictly after the time parsed from
// begin and strictly before the time parsed from end. It panics if either
// string cannot be parsed.
func (now *Now) Between(begin, end string) bool {
	lower := now.MustParse(begin)
	upper := now.MustParse(end)
	if !now.Time.After(lower) {
		return false
	}
	return now.Time.Before(upper)
}
================================================
FILE: gossh/gorm/now/time.go
================================================
package now
import "time"
// formatTimeToList splits t into its components, ordered from smallest to
// largest: nanosecond, second, minute, hour, day, month, year.
func formatTimeToList(t time.Time) []int {
	year, month, day := t.Date()
	hour, minute, second := t.Clock()

	parts := make([]int, 0, 7)
	parts = append(parts, t.Nanosecond(), second, minute, hour, day, int(month), year)
	return parts
}
================================================
FILE: gossh/gorm/prepare_stmt.go
================================================
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
)
// Stmt is a cache entry of PreparedStmtDB: a prepared statement plus the
// bookkeeping needed while it is being prepared.
type Stmt struct {
	// Stmt is nil until preparation has succeeded.
	*sql.Stmt
	// Transaction records whether the statement was prepared inside a transaction.
	Transaction bool
	// prepared is closed once preparation has finished, successfully or not.
	prepared chan struct{}
	// prepareErr holds the preparation error, if any; read it only after prepared is closed.
	prepareErr error
}
// PreparedStmtDB wraps a ConnPool and caches prepared statements by query text.
type PreparedStmtDB struct {
	// Stmts maps a query string to its cache entry; guarded by Mux.
	Stmts map[string]*Stmt
	// PreparedSQL lists the queries that were prepared successfully; guarded by Mux.
	PreparedSQL []string
	Mux *sync.RWMutex
	ConnPool
}
// NewPreparedStmtDB returns a PreparedStmtDB over connPool with an empty
// statement cache.
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
	db := new(PreparedStmtDB)
	db.ConnPool = connPool
	db.Stmts = make(map[string]*Stmt)
	db.Mux = &sync.RWMutex{}
	db.PreparedSQL = make([]string, 0, 100)
	return db
}
// GetDBConn returns the underlying *sql.DB: the wrapped pool itself when it
// is one, otherwise whatever the pool's own GetDBConn reports. It returns
// ErrInvalidDB when the pool offers neither.
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
	switch pool := db.ConnPool.(type) {
	case *sql.DB:
		return pool, nil
	case GetDBConnector:
		return pool.GetDBConn()
	default:
		return nil, ErrInvalidDB
	}
}
// Close removes every statement listed in PreparedSQL from the cache and
// closes it asynchronously.
//
// A cache entry may still be mid-preparation (its embedded *sql.Stmt is nil
// until prepare succeeds), for example when a query that is already listed in
// PreparedSQL was evicted after a bad connection and is being prepared again.
// Closing such an entry directly would dereference a nil *sql.Stmt inside a
// goroutine and crash the process, so each closer first waits for the
// preparation to finish and skips entries that never got a statement.
func (db *PreparedStmtDB) Close() {
	db.Mux.Lock()
	defer db.Mux.Unlock()

	for _, query := range db.PreparedSQL {
		stmt, ok := db.Stmts[query]
		if !ok {
			continue
		}
		delete(db.Stmts, query)
		go func(s *Stmt) {
			// Entries created by prepare always carry the channel; the nil
			// check only protects against entries inserted by other code.
			if s.prepared != nil {
				<-s.prepared
			}
			if s.Stmt != nil {
				// Best-effort cleanup: there is nobody to report the error to.
				_ = s.Close()
			}
		}(stmt)
	}
}
// Reset closes every cached statement asynchronously and replaces the cache
// and the prepared-query list with fresh, empty ones.
//
// The cache can contain entries whose preparation is still in flight; their
// embedded *sql.Stmt is nil until prepare succeeds (and stays nil when it
// fails). Closing such an entry directly would dereference a nil *sql.Stmt
// inside a goroutine and crash the process, so each closer first waits for
// the preparation to finish and skips entries that never got a statement.
func (sdb *PreparedStmtDB) Reset() {
	sdb.Mux.Lock()
	defer sdb.Mux.Unlock()

	for _, stmt := range sdb.Stmts {
		go func(s *Stmt) {
			// Entries created by prepare always carry the channel; the nil
			// check only protects against entries inserted by other code.
			if s.prepared != nil {
				<-s.prepared
			}
			if s.Stmt != nil {
				// Best-effort cleanup: there is nobody to report the error to.
				_ = s.Close()
			}
		}(stmt)
	}
	sdb.PreparedSQL = make([]string, 0, 100)
	sdb.Stmts = make(map[string]*Stmt)
}
// prepare returns the prepared statement for query, preparing it on conn and
// caching it in db.Stmts when no usable entry exists.
//
// A cached entry is reused unless it was created inside a transaction and the
// current caller is not in one. A caller that finds an entry blocks on its
// prepared channel and then observes the preparer's outcome, either the
// statement or prepareErr. On preparation failure the entry is removed from
// the cache again and the error is returned.
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
	// first look the query up under the read lock
	db.Mux.RLock()
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
		db.Mux.RUnlock()
		// wait for other goroutines prepared
		<-stmt.prepared
		if stmt.prepareErr != nil {
			return Stmt{}, stmt.prepareErr
		}
		return *stmt, nil
	}
	db.Mux.RUnlock()
	db.Mux.Lock()
	// double check
	// (another goroutine may have added the entry between RUnlock and Lock)
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
		db.Mux.Unlock()
		// wait for other goroutines prepared
		<-stmt.prepared
		if stmt.prepareErr != nil {
			return Stmt{}, stmt.prepareErr
		}
		return *stmt, nil
	}
	// cache preparing stmt first
	// (so concurrent callers wait on this entry instead of preparing again)
	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
	db.Stmts[query] = &cacheStmt
	db.Mux.Unlock()
	// prepare completed
	// (the deferred close releases every waiter when this function returns)
	defer close(cacheStmt.prepared)
	// Reason why cannot lock conn.PrepareContext
	// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
	// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
	// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
	// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
	stmt, err := conn.PrepareContext(ctx, query)
	if err != nil {
		// record the error for waiters, then drop the failed entry
		cacheStmt.prepareErr = err
		db.Mux.Lock()
		delete(db.Stmts, query)
		db.Mux.Unlock()
		return Stmt{}, err
	}
	db.Mux.Lock()
	cacheStmt.Stmt = stmt
	db.PreparedSQL = append(db.PreparedSQL, query)
	db.Mux.Unlock()
	return cacheStmt, nil
}
// BeginTx starts a transaction on the wrapped pool and returns it wrapped in
// a PreparedStmtTX that shares this statement cache.
//
// When the pool is a TxBeginner, the wrapper is returned together with
// whatever error BeginTx reported, so the ConnPool is non-nil even on
// failure. When it is a ConnPoolBeginner, the started pool must itself be a
// Tx. In every other case ErrInvalidTransaction is returned.
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
	switch beginner := db.ConnPool.(type) {
	case TxBeginner:
		tx, err := beginner.BeginTx(ctx, opt)
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
	case ConnPoolBeginner:
		connPool, err := beginner.BeginTx(ctx, opt)
		if err != nil {
			return nil, err
		}
		tx, isTx := connPool.(Tx)
		if !isTx {
			return nil, ErrInvalidTransaction
		}
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
	default:
		return nil, ErrInvalidTransaction
	}
}
// ExecContext executes query through its cached prepared statement. When
// execution fails with driver.ErrBadConn, the statement is closed
// asynchronously and evicted from the cache.
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err != nil {
		return result, err
	}

	result, err = stmt.ExecContext(ctx, args...)
	if !errors.Is(err, driver.ErrBadConn) {
		return result, err
	}

	db.Mux.Lock()
	defer db.Mux.Unlock()
	go stmt.Close()
	delete(db.Stmts, query)
	return result, err
}
// QueryContext runs query through its cached prepared statement. When the
// query fails with driver.ErrBadConn, the statement is closed asynchronously
// and evicted from the cache.
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err != nil {
		return rows, err
	}

	rows, err = stmt.QueryContext(ctx, args...)
	if !errors.Is(err, driver.ErrBadConn) {
		return rows, err
	}

	db.Mux.Lock()
	defer db.Mux.Unlock()
	go stmt.Close()
	delete(db.Stmts, query)
	return rows, err
}
// QueryRowContext runs query through its cached prepared statement and
// returns the resulting row.
//
// NOTE(review): when preparation fails, the error is discarded and an empty
// *sql.Row is returned; callers cannot see the cause — confirm this is
// acceptable for every caller.
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err != nil {
		return &sql.Row{}
	}
	return stmt.QueryRowContext(ctx, args...)
}
// PreparedStmtTX is a transaction that runs its queries through the prepared
// statement cache of the PreparedStmtDB it was started from.
type PreparedStmtTX struct {
	Tx
	PreparedStmtDB *PreparedStmtDB
}
// GetDBConn returns the *sql.DB of the PreparedStmtDB this transaction
// belongs to.
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
	owner := db.PreparedStmtDB
	return owner.GetDBConn()
}
// Commit commits the wrapped transaction. It returns ErrInvalidTransaction
// when no transaction is held, including the case where the Tx interface
// wraps a nil value.
func (tx *PreparedStmtTX) Commit() error {
	if tx.Tx == nil || reflect.ValueOf(tx.Tx).IsNil() {
		return ErrInvalidTransaction
	}
	return tx.Tx.Commit()
}
// Rollback rolls back the wrapped transaction. It returns
// ErrInvalidTransaction when no transaction is held, including the case
// where the Tx interface wraps a nil value.
func (tx *PreparedStmtTX) Rollback() error {
	if tx.Tx == nil || reflect.ValueOf(tx.Tx).IsNil() {
		return ErrInvalidTransaction
	}
	return tx.Tx.Rollback()
}
// ExecContext executes query inside the transaction using the cached
// prepared statement, bound to the transaction via StmtContext. When
// execution fails with driver.ErrBadConn, the cached statement is closed
// asynchronously and evicted from the cache.
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err != nil {
		return result, err
	}

	result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
	if !errors.Is(err, driver.ErrBadConn) {
		return result, err
	}

	tx.PreparedStmtDB.Mux.Lock()
	defer tx.PreparedStmtDB.Mux.Unlock()
	go stmt.Close()
	delete(tx.PreparedStmtDB.Stmts, query)
	return result, err
}
// QueryContext runs query inside the transaction using the cached prepared
// statement, bound to the transaction via StmtContext. When the query fails
// with driver.ErrBadConn, the cached statement is closed asynchronously and
// evicted from the cache.
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err != nil {
		return rows, err
	}

	rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
	if !errors.Is(err, driver.ErrBadConn) {
		return rows, err
	}

	tx.PreparedStmtDB.Mux.Lock()
	defer tx.PreparedStmtDB.Mux.Unlock()
	go stmt.Close()
	delete(tx.PreparedStmtDB.Stmts, query)
	return rows, err
}
// QueryRowContext runs query inside the transaction using the cached
// prepared statement and returns the resulting row.
//
// NOTE(review): when preparation fails, the error is discarded and an empty
// *sql.Row is returned; callers cannot see the cause — confirm this is
// acceptable for every caller.
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err != nil {
		return &sql.Row{}
	}
	return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
}
================================================
FILE: gossh/gorm/scan.go
================================================
package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"time"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// prepareValues fills values with one scan destination per column.
//
// With a statement schema, a column that maps to a field gets a pointer to
// a pointer of the field's type. Without a schema, the driver-reported scan
// types are used in the same way. Whenever no type is known, the
// destination is a *any.
func prepareValues(values []any, db *DB, columnTypes []*sql.ColumnType, columns []string) {
	switch sch := db.Statement.Schema; {
	case sch != nil:
		for idx, name := range columns {
			field := sch.LookUpField(name)
			if field == nil {
				values[idx] = new(any)
				continue
			}
			values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
		}
	case len(columnTypes) > 0:
		for idx, columnType := range columnTypes {
			scanType := columnType.ScanType()
			if scanType == nil {
				values[idx] = new(any)
				continue
			}
			values[idx] = reflect.New(reflect.PtrTo(scanType)).Interface()
		}
	default:
		for idx := range columns {
			values[idx] = new(any)
		}
	}
}
// scanIntoMap copies scanned values into mapValue, keyed by column name.
// Each value is dereferenced up to two pointer levels; an invalid result
// (for example a nil inner pointer) is stored as nil. driver.Valuer values
// are replaced by what Value() returns, with its error ignored, and
// sql.RawBytes values are converted to string.
func scanIntoMap(mapValue map[string]any, values []any, columns []string) {
	for idx, column := range columns {
		resolved := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx])))
		if !resolved.IsValid() {
			mapValue[column] = nil
			continue
		}

		switch v := resolved.Interface().(type) {
		case driver.Valuer:
			mapValue[column], _ = v.Value()
		case sql.RawBytes:
			mapValue[column] = string(v)
		default:
			mapValue[column] = v
		}
	}
}
// scanIntoStruct scans the current row of rows into reflectValue.
//
// values holds one scan destination per column. Columns with a matched field
// get a destination from that field's NewValuePool, which is returned to the
// pool after the value has been set. A single column with no matched field
// is scanned directly into reflectValue. joinFields, when non-empty, holds
// for a column the chain of relation fields leading to the final field, the
// last element being that field. Errors are accumulated on db via AddError,
// and db.RowsAffected is incremented once per call.
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []any, fields []*schema.Field, joinFields [][]*schema.Field) {
	for idx, field := range fields {
		if field != nil {
			values[idx] = field.NewValuePool.Get()
		} else if len(fields) == 1 {
			if reflectValue.CanAddr() {
				values[idx] = reflectValue.Addr().Interface()
			} else {
				values[idx] = reflectValue.Interface()
			}
		}
	}
	db.RowsAffected++
	db.AddError(rows.Scan(values...))
	// relation paths already initialized for this row
	joinedNestedSchemaMap := make(map[string]any)
	for idx, field := range fields {
		if field == nil {
			continue
		}
		if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
			db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
		} else { // joinFields count is larger than 2 when using join
			var isNilPtrValue bool
			var relValue reflect.Value
			// does not contain raw dbname
			nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
			// current reflect value
			currentReflectValue := reflectValue
			fullRels := make([]string, 0, len(nestedJoinSchemas))
			// walk down the relation chain, allocating pointer relations on first use
			for _, joinSchema := range nestedJoinSchemas {
				fullRels = append(fullRels, joinSchema.Name)
				relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
				if relValue.Kind() == reflect.Ptr {
					fullRelsName := utils.JoinNestedRelationNames(fullRels)
					// same nested structure
					if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
						// a nil scanned value leaves the relation pointer untouched
						if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
							isNilPtrValue = true
							break
						}
						relValue.Set(reflect.New(relValue.Type().Elem()))
						joinedNestedSchemaMap[fullRelsName] = nil
					}
				}
				currentReflectValue = relValue
			}
			if !isNilPtrValue { // ignore if value is nil
				f := joinFields[idx][len(joinFields[idx])-1]
				db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
			}
		}
		// release data to pool
		field.NewValuePool.Put(values[idx])
	}
}
// ScanMode is a bit set of flags controlling how Scan processes rows.
type ScanMode uint8
// scan modes
const (
	// ScanInitialized makes Scan treat the first row as already advanced to,
	// so it does not call rows.Next() before scanning it.
	ScanInitialized ScanMode = 1 << 0 // 1
	// ScanUpdate makes Scan write rows into the existing elements of a
	// non-empty slice or array destination instead of rebuilding it.
	ScanUpdate ScanMode = 1 << 1 // 2
	// ScanOnConflictDoNothing, in update mode, makes Scan skip destination
	// elements for which a field's ValueOf reports ok == false.
	ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan reads rows into db.Statement.Dest, choosing a strategy from the
// destination's type: maps, slices of maps, basic scalar and sql.Null*
// pointers, or (by default) structs and slices/arrays of structs resolved
// through the statement schema.
//
// db.RowsAffected is reset and then counts the rows scanned. Errors are
// accumulated on db via AddError. ErrRecordNotFound is added when no row
// was scanned, db.Statement.RaiseErrorOnNotFound is set and no other error
// occurred. mode is a bit set of ScanMode flags.
func Scan(rows Rows, db *DB, mode ScanMode) {
	var (
		columns, _ = rows.Columns()
		values = make([]any, len(columns))
		initialized = mode&ScanInitialized != 0
		update = mode&ScanUpdate != 0
		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
	)
	db.RowsAffected = 0
	switch dest := db.Statement.Dest.(type) {
	case map[string]any, *map[string]any:
		// single row into a map
		if initialized || rows.Next() {
			columnTypes, _ := rows.ColumnTypes()
			prepareValues(values, db, columnTypes, columns)
			db.RowsAffected++
			db.AddError(rows.Scan(values...))
			mapValue, ok := dest.(map[string]any)
			if !ok {
				if v, ok := dest.(*map[string]any); ok {
					// allocate the map for a pointer to a nil map
					if *v == nil {
						*v = map[string]any{}
					}
					mapValue = *v
				}
			}
			scanIntoMap(mapValue, values, columns)
		}
	case *[]map[string]any:
		// every row becomes a new map appended to the slice
		columnTypes, _ := rows.ColumnTypes()
		for initialized || rows.Next() {
			prepareValues(values, db, columnTypes, columns)
			initialized = false
			db.RowsAffected++
			db.AddError(rows.Scan(values...))
			mapValue := map[string]any{}
			scanIntoMap(mapValue, values, columns)
			*dest = append(*dest, mapValue)
		}
	case *int, *int8, *int16, *int32, *int64,
		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
		*float32, *float64,
		*bool, *string, *time.Time,
		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
		*sql.NullBool, *sql.NullString, *sql.NullTime:
		// scalar destinations are scanned directly; later rows overwrite earlier ones
		for initialized || rows.Next() {
			initialized = false
			db.RowsAffected++
			db.AddError(rows.Scan(dest))
		}
	default:
		var (
			fields = make([]*schema.Field, len(columns))
			joinFields [][]*schema.Field
			sch = db.Statement.Schema
			reflectValue = db.Statement.ReflectValue
		)
		if reflectValue.Kind() == reflect.Interface {
			reflectValue = reflectValue.Elem()
		}
		// resolve the element type: unwrap slice/array, then one pointer level
		reflectValueType := reflectValue.Type()
		switch reflectValueType.Kind() {
		case reflect.Array, reflect.Slice:
			reflectValueType = reflectValueType.Elem()
		}
		isPtr := reflectValueType.Kind() == reflect.Ptr
		if isPtr {
			reflectValueType = reflectValueType.Elem()
		}
		if sch != nil {
			// the destination is a different struct than the model: parse its own schema
			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
			}
			if len(columns) == 1 {
				// Is Pluck
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
					reflectValueType.Kind() != reflect.Struct || // is not struct
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
					sch = nil
				}
			}
			// Not Pluck
			if sch != nil {
				// map each column to a schema field
				matchedFieldCount := make(map[string]int, len(columns))
				for idx, column := range columns {
					if field := sch.LookUpField(column); field != nil && field.Readable {
						fields[idx] = field
						if count, ok := matchedFieldCount[column]; ok {
							// handle duplicate fields
							for _, selectField := range sch.Fields {
								if selectField.DBName == column && selectField.Readable {
									if count == 0 {
										matchedFieldCount[column]++
										fields[idx] = selectField
										break
									}
									count--
								}
							}
						} else {
							matchedFieldCount[column] = 1
						}
					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
						if rel, ok := sch.Relationships.Relations[names[0]]; ok {
							subNameCount := len(names)
							// nested relation fields
							relFields := make([]*schema.Field, 0, subNameCount-1)
							relFields = append(relFields, rel.Field)
							for _, name := range names[1 : subNameCount-1] {
								rel = rel.FieldSchema.Relationships.Relations[name]
								relFields = append(relFields, rel.Field)
							}
							// the last name is the raw dbname
							dbName := names[subNameCount-1]
							if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
								fields[idx] = field
								if len(joinFields) == 0 {
									joinFields = make([][]*schema.Field, len(columns))
								}
								relFields = append(relFields, field)
								joinFields[idx] = relFields
								continue
							}
						}
						// unmatched columns are scanned into a throwaway buffer
						values[idx] = &sql.RawBytes{}
					} else {
						values[idx] = &sql.RawBytes{}
					}
				}
			}
		}
		switch reflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			var (
				elem reflect.Value
				isArrayKind = reflectValue.Kind() == reflect.Array
			)
			// outside update mode (or with an empty destination) start from an empty collection
			if !update || reflectValue.Len() == 0 {
				update = false
				if isArrayKind {
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
				} else {
					// if the slice cap is externally initialized, the externally initialized slice is directly used here
					if reflectValue.Cap() == 0 {
						db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
					} else {
						reflectValue.SetLen(0)
						db.Statement.ReflectValue.Set(reflectValue)
					}
				}
			}
			for initialized || rows.Next() {
			BEGIN:
				initialized = false
				if update {
					// update mode: write into the element at index RowsAffected; stop when out of elements
					if int(db.RowsAffected) >= reflectValue.Len() {
						return
					}
					elem = reflectValue.Index(int(db.RowsAffected))
					if onConflictDonothing {
						for _, field := range fields {
							// skip this element and retry the same row against the next one
							if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
								db.RowsAffected++
								goto BEGIN
							}
						}
					}
				} else {
					elem = reflect.New(reflectValueType)
				}
				db.scanIntoStruct(rows, elem, values, fields, joinFields)
				if !update {
					if !isPtr {
						elem = elem.Elem()
					}
					if isArrayKind {
						// rows beyond the array length are scanned but not stored
						if reflectValue.Len() >= int(db.RowsAffected) {
							reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
						}
					} else {
						reflectValue = reflect.Append(reflectValue, elem)
					}
				}
			}
			if !update {
				db.Statement.ReflectValue.Set(reflectValue)
			}
		case reflect.Struct, reflect.Ptr:
			// single row into a struct
			if initialized || rows.Next() {
				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
			}
		default:
			db.AddError(rows.Scan(dest))
		}
	}
	if err := rows.Err(); err != nil && err != db.Error {
		db.AddError(err)
	}
	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
		db.AddError(ErrRecordNotFound)
	}
}
================================================
FILE: gossh/gorm/schema/constraint.go
================================================
package schema
import (
"regexp"
"strings"
"gossh/gorm/clause"
)
// regEnLetterAndMidline matches strings made up only of word characters
// (letters, digits, underscore) and hyphens.
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
// CheckConstraint is a CHECK constraint declared on a field.
type CheckConstraint struct {
	Name       string
	Constraint string // length(phone) >= 10
	*Field
}

// GetName returns the constraint name.
func (chk *CheckConstraint) GetName() string {
	name := chk.Name
	return name
}

// Build returns the SQL fragment declaring the constraint, together with
// the values for its placeholders: the constraint name as a column and the
// constraint text as a raw expression.
func (chk *CheckConstraint) Build() (sql string, vars []any) {
	sql = "CONSTRAINT ? CHECK (?)"
	vars = []any{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
	return sql, vars
}
// ParseCheckConstraints collects the CHECK constraints declared through the
// "CHECK" tag setting of the schema's fields, keyed by constraint name.
//
// A tag of the form "name,expr" with a name made of word characters and
// hyphens yields a constraint with that name. Otherwise the name comes from
// the schema's namer and the whole tag is the expression, minus a leading
// comma when the first segment is empty.
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
	checks := make(map[string]CheckConstraint)
	for _, field := range schema.FieldsByDBName {
		chk := field.TagSettings["CHECK"]
		if chk == "" {
			continue
		}

		names := strings.Split(chk, ",")
		if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
			explicit := names[0]
			checks[explicit] = CheckConstraint{Name: explicit, Constraint: strings.Join(names[1:], ","), Field: field}
			continue
		}

		if names[0] == "" {
			chk = strings.Join(names[1:], ",")
		}
		generated := schema.namer.CheckerName(schema.Table, field.DBName)
		checks[generated] = CheckConstraint{Name: generated, Constraint: chk, Field: field}
	}
	return checks
}
// UniqueConstraint is a UNIQUE constraint on a single field.
type UniqueConstraint struct {
	Name  string
	Field *Field
}

// GetName returns the constraint name.
func (uni *UniqueConstraint) GetName() string {
	name := uni.Name
	return name
}

// Build returns the SQL fragment declaring the constraint, together with
// the values for its placeholders: the constraint name and the field's
// database column name, both as columns.
func (uni *UniqueConstraint) Build() (sql string, vars []any) {
	sql = "CONSTRAINT ? UNIQUE (?)"
	vars = []any{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
	return sql, vars
}
// ParseUniqueConstraints collects one UniqueConstraint per field marked
// Unique, keyed by the name the schema's namer generates for it.
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
	uniques := map[string]UniqueConstraint{}
	for _, field := range schema.Fields {
		if !field.Unique {
			continue
		}
		constraintName := schema.namer.UniqueName(schema.Table, field.DBName)
		uniques[constraintName] = UniqueConstraint{Name: constraintName, Field: field}
	}
	return uniques
}
================================================
FILE: gossh/gorm/schema/field.go
================================================
package schema
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
"gossh/gorm/clause"
"gossh/gorm/now"
"gossh/gorm/utils"
)
// special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type (
// DataType GORM data type
DataType string
// TimeType GORM time type
TimeType int64
)
// GORM time types
const (
UnixTime TimeType = 1
UnixSecond TimeType = 2
UnixMillisecond TimeType = 3
UnixNanosecond TimeType = 4
)
// GORM field data types. ParseField derives them from the reflect kind of the
// field, and also accepts them (case-insensitively) as values of the `TYPE` tag.
const (
Bool DataType = "bool"
Int DataType = "int"
Uint DataType = "uint"
Float DataType = "float"
String DataType = "string"
Time DataType = "time"
Bytes DataType = "bytes"
)
// DefaultAutoIncrementIncrement is the value ParseField gives
// Field.AutoIncrementIncrement unless an `AUTOINCREMENTINCREMENT` tag is present.
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field.
// It is produced by Schema.ParseField from a reflect.StructField and its `gorm` tag.
type Field struct {
// Name is the Go struct field name.
Name string
// DBName is initialised by ParseField from the `COLUMN` tag.
DBName string
// BindNames starts as the field's own name; fields of an embedded struct get
// the embedding field's name prepended.
BindNames []string
// DataType is the effective data type; a `TYPE` tag overrides it.
DataType DataType
// GORMDataType is the data type as determined before the `TYPE` tag is applied.
GORMDataType DataType
PrimaryKey bool
AutoIncrement bool
AutoIncrementIncrement int64
// Creatable, Updatable and Readable are the field permissions, controlled by
// the `-`, `->` and `<-` tags.
Creatable bool
Updatable bool
Readable bool
// AutoCreateTime / AutoUpdateTime are non-zero when the field tracks
// creation / update time; the value selects the stored representation.
AutoCreateTime TimeType
AutoUpdateTime TimeType
// HasDefaultValue is set for auto-increment fields and fields with a `DEFAULT` tag.
HasDefaultValue bool
// DefaultValue is the trimmed `DEFAULT` tag text.
DefaultValue string
// DefaultValueInterface is DefaultValue parsed according to the field's kind,
// when it is a parsable literal.
DefaultValueInterface any
NotNull bool
Unique bool
Comment string
// Size comes from the `SIZE` tag (-1 if it is not an integer) or, when zero,
// from the bit width of the numeric kind.
Size int
Precision int
Scale int
// IgnoreMigration is set by the tag values `-:all` and `-:migration`.
IgnoreMigration bool
// FieldType is the declared Go type; IndirectFieldType is the same type with
// all pointer levels removed.
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
Tag reflect.StructTag
// TagSettings holds the parsed `gorm` tag, keyed by setting name.
TagSettings map[string]string
// Schema is the schema the field belongs to.
Schema *Schema
// EmbeddedSchema is the parsed schema of an embedded struct field.
EmbeddedSchema *Schema
// OwnerSchema is set on fields that come from an embedded struct and points
// at that embedded struct's schema.
OwnerSchema *Schema
// ReflectValueOf, ValueOf and Set are accessor closures installed by
// setupValuerAndSetter.
ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(context.Context, reflect.Value) (value any, zero bool)
Set func(context.Context, reflect.Value, any) error
Serializer SerializerInterface
// NewValuePool is installed by setupNewValuePool.
NewValuePool FieldNewValuePool
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}
// BindName returns the field's bind names joined with ".".
func (field *Field) BindName() string {
	const separator = "."
	joined := strings.Join(field.BindNames, separator)
	return joined
}
// ParseField parses reflect.StructField to Field.
//
// It reads the `gorm` struct tag, infers the data type from the field's
// (indirected) Go type, parses default value / size / precision / scale,
// resolves a serializer, applies the permission tags and, for embedded
// structs, parses the embedded schema and rewrites its fields so that they
// belong to this schema. Problems are recorded in schema.err rather than
// returned.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
// Seed the field from the tag; all permissions default to true and an
// auto-increment field is considered to have a default value.
field := &Field{
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true,
Updatable: true,
Readable: true,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
}
// Strip every pointer level so IndirectFieldType is the element type.
for field.IndirectFieldType.Kind() == reflect.Ptr {
field.IndirectFieldType = field.IndirectFieldType.Elem()
}
// fieldValue is a pointer to a fresh zero value of the element type; it is
// used below purely for interface checks and kind detection.
fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
// A type that declares its own GORM data type is not inspected further.
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
fieldValue = reflect.ValueOf(v)
}
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
var (
rv = reflect.Indirect(v)
rvType = rv.Type()
)
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
// Inherit gorm tag settings from the inner struct's fields without
// overriding settings declared on the outer field.
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
// Descend into the inner fields, replacing fieldValue with a new value
// of the (pointer-stripped) inner field type.
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
if rvType != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
}
}
}
getRealFieldValue(fieldValue)
}
}
// Resolve the serializer: either the type itself implements
// SerializerInterface, or one is named by the `JSON` / `SERIALIZER` tag.
if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DataType = String
field.Serializer = v
} else {
serializerName := field.TagSettings["JSON"]
if serializerName == "" {
serializerName = field.TagSettings["SERIALIZER"]
}
if serializerName != "" {
if serializer, ok := GetSerializer(serializerName); ok {
// Set default data type to string for serializer
field.DataType = String
field.Serializer = serializer
} else {
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
}
}
}
// Parse errors of the numeric tags below are ignored, except for SIZE,
// where an unparsable value is recorded as -1.
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
}
if v, ok := field.TagSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
field.DefaultValue = v
}
if num, ok := field.TagSettings["SIZE"]; ok {
if field.Size, err = strconv.Atoi(num); err != nil {
field.Size = -1
}
}
if p, ok := field.TagSettings["PRECISION"]; ok {
field.Precision, _ = strconv.Atoi(p)
}
if s, ok := field.TagSettings["SCALE"]; ok {
field.Scale, _ = strconv.Atoi(s)
}
// default value is function or null or blank (primary keys)
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
// Infer the data type from the kind and, where the default is a literal,
// parse it into DefaultValueInterface.
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool:
field.DataType = Bool
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
}
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
}
}
case reflect.Float32, reflect.Float64:
field.DataType = Float
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
}
}
case reflect.String:
field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue {
// Strip surrounding single, then double, quotes from the literal.
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
field.DefaultValueInterface = field.DefaultValue
}
case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time
}
// An unparsable time default is silently left unset.
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
if t, err := now.Parse(field.DefaultValue); err == nil {
field.DefaultValueInterface = t
}
}
case reflect.Array, reflect.Slice:
// Only byte arrays/slices get a data type here, and only if a serializer
// has not already set one.
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
field.DataType = Bytes
}
}
// A type-declared GORM data type overrides the inferred one.
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType())
}
// Auto create time: enabled by a truthy AUTOCREATETIME tag, or implicitly
// for an untagged field named CreatedAt of type Time/Int/Uint.
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoCreateTime = UnixMillisecond
} else {
field.AutoCreateTime = UnixSecond
}
}
// Auto update time: same rules, keyed on AUTOUPDATETIME / UpdatedAt.
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoUpdateTime = UnixMillisecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
// Remember the inferred type before the TYPE tag may replace DataType.
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
// TYPE tag: known type names are normalised to lower case, anything else is
// kept verbatim.
if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) {
case Bool, Int, Uint, Float, String, Time, Bytes:
field.DataType = DataType(strings.ToLower(val))
default:
field.DataType = DataType(val)
}
}
// Without an explicit size, numeric kinds default to their bit width.
if field.Size == 0 {
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
field.Size = 64
case reflect.Int8, reflect.Uint8:
field.Size = 8
case reflect.Int16, reflect.Uint16:
field.Size = 16
case reflect.Int32, reflect.Uint32, reflect.Float32:
field.Size = 32
}
}
// setup permission
if val, ok := field.TagSettings["-"]; ok {
val = strings.ToLower(strings.TrimSpace(val))
switch val {
case "-":
field.Creatable = false
field.Updatable = false
field.Readable = false
field.DataType = ""
case "all":
field.Creatable = false
field.Updatable = false
field.Readable = false
field.DataType = ""
field.IgnoreMigration = true
case "migration":
field.IgnoreMigration = true
}
}
// `->`: read-only unless the value is "false", in which case not readable either.
if v, ok := field.TagSettings["->"]; ok {
field.Creatable = false
field.Updatable = false
if strings.ToLower(v) == "false" {
field.Readable = false
} else {
field.Readable = true
}
}
// `<-`: write permission; a value restricts it to the operations it mentions
// ("create" and/or "update").
if v, ok := field.TagSettings["<-"]; ok {
field.Creatable = true
field.Updatable = true
if v != "<-" {
if !strings.Contains(v, "create") {
field.Creatable = false
}
if !strings.Contains(v, "update") {
field.Updatable = false
}
}
}
// Normal anonymous field or having `EMBEDDED` tag
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind()
switch kind {
case reflect.Struct:
var err error
// The embedding field itself is not a column; its inner fields are.
field.Creatable = false
field.Updatable = false
field.Readable = false
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
schema.err = err
}
// NOTE(review): EmbeddedSchema is dereferenced even when getOrParse reported
// an error; this assumes it still returns a non-nil schema — confirm.
for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
} else {
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
}
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
ef.DBName = prefix + ef.DBName
}
// An embedded field stays a primary key only if it is tagged as one
// explicitly; otherwise the primary key flag and the auto-increment /
// default-value flags that came with it are cleared.
if ef.PrimaryKey {
if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false
}
if !ef.AutoIncrement && ef.DefaultValue == "" {
ef.HasDefaultValue = false
}
}
}
// The embedding field's tag settings overwrite those of the embedded field.
for k, v := range field.TagSettings {
ef.TagSettings[k] = v
}
}
case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128:
schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
}
}
return field
}
// setupValuerAndSetter installs the field's accessor closures when a struct
// is parsed: ValueOf (read value + zero check), ReflectValueOf (addressable
// reflect.Value of the field) and Set (assign with type conversion). It also
// sets up NewValuePool, and wraps ValueOf/Set when the field has a Serializer.
//
// Entries of StructField.Index may be negative: -i-1 encodes "field i, reached
// through a pointer to an embedded struct" (see ParseField).
func (field *Field) setupValuerAndSetter() {
// Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0]
switch {
// Fast path: a direct field with a positive index. Index 0 and nested or
// pointer-embedded fields take the general path below.
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (any, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (any, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
// A nil embedded pointer on the path means the field has no value.
if !v.IsNil() {
v = v.Elem()
} else {
return nil, true
}
}
}
fv, zero := v.Interface(), v.IsZero()
return fv, zero
}
}
// With a serializer, ValueOf returns a *serializer wrapping the raw value;
// the value's own SerializerValuerInterface is preferred over the field's.
if field.Serializer != nil {
oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (any, bool) {
value, zero := oldValuerOf(ctx, v)
s, ok := value.(SerializerValuerInterface)
if !ok {
s = field.Serializer
}
return &serializer{
Field: field,
SerializeValuer: s,
Destination: v,
Context: ctx,
fieldValue: value,
}, zero
}
}
// ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
// Unlike ValueOf, a nil embedded pointer is allocated here so the
// returned value can be set.
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
}
}
return v
}
}
// fallbackSetter handles values that the kind-specific setters below do not
// match directly: nil resets the field to its zero value; assignable and
// convertible values are stored (allocating a pointer field if needed);
// pointers are dereferenced and driver.Valuer values unwrapped, then passed
// back to setter; clause.Expr values are ignored; anything else is an error.
fallbackSetter := func(ctx context.Context, value reflect.Value, v any, setter func(context.Context, reflect.Value, any) error) (err error) {
if v == nil {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
// Optimal value type acquisition for v
reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) {
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
reflectV = reflect.Indirect(reflectV)
}
field.ReflectValueOf(ctx, value).Set(reflectV)
return
} else if reflectValType.ConvertibleTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
return
} else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(ctx, value)
fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(fieldType) {
if !fieldValue.IsValid() {
fieldValue = reflect.New(fieldType)
} else if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV)
return
} else if reflectValType.ConvertibleTo(fieldType) {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV.Convert(fieldType))
return
}
}
if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
return
} else {
err = setter(ctx, value, reflectV.Elem().Interface())
}
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = setter(ctx, value, v)
}
} else if _, ok := v.(clause.Expr); !ok {
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
}
}
return
}
// Set
// One setter per field kind. Each handles the common source types directly;
// a nil **T source leaves the field untouched; everything else goes through
// fallbackSetter.
switch field.FieldType.Kind() {
case reflect.Bool:
field.Set = func(ctx context.Context, value reflect.Value, v any) error {
switch data := v.(type) {
case **bool:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data)
}
case bool:
field.ReflectValueOf(ctx, value).SetBool(data)
case int64:
field.ReflectValueOf(ctx, value).SetBool(data > 0)
case string:
// An unparsable string is stored as false; the parse error is dropped.
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(ctx, value).SetBool(b)
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
switch data := v.(type) {
case **int64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data)
}
case **int:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case int64:
field.ReflectValueOf(ctx, value).SetInt(data)
case int:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int8:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int16:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint64:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float64:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case []byte:
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(ctx, value).SetInt(i)
} else {
return err
}
case time.Time:
// Timestamps are stored in the unit selected by the auto time settings,
// defaulting to seconds.
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(ctx, value).SetInt(0)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
switch data := v.(type) {
case **uint64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data)
}
case **uint:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case uint64:
field.ReflectValueOf(ctx, value).SetUint(data)
case uint:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int64:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int8:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int16:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float64:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case []byte:
return field.Set(ctx, value, string(data))
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
} else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}
case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(ctx, value).SetUint(i)
} else {
return err
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Float32, reflect.Float64:
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
switch data := v.(type) {
case **float64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data)
}
case **float32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
}
case float64:
field.ReflectValueOf(ctx, value).SetFloat(data)
case float32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int64:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int8:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int16:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint64:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case []byte:
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(ctx, value).SetFloat(i)
} else {
return err
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.String:
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
switch data := v.(type) {
case **string:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data)
}
case string:
field.ReflectValueOf(ctx, value).SetString(data)
case []byte:
field.ReflectValueOf(ctx, value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
case float64, float32:
// Floats are formatted with the field's Precision as the number of decimals.
field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
default:
// Non-primitive kinds: dispatch on the concrete field type.
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
case time.Time:
field.Set = func(ctx context.Context, value reflect.Value, v any) error {
switch data := v.(type) {
case **time.Time:
// NOTE(review): the error returned by the nested Set call is discarded.
if data != nil && *data != nil {
field.Set(ctx, value, *data)
}
case time.Time:
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case *time.Time:
if data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
} else {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
}
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case *time.Time:
field.Set = func(ctx context.Context, value reflect.Value, v any) error {
switch data := v.(type) {
case **time.Time:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
}
case time.Time:
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case string:
if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
// An empty string leaves a nil pointer field nil.
if v == "" {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
// The field type itself implements sql.Scanner: allocate the field if it
// is nil, then Scan the value, unwrapping a driver.Valuer first.
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
return field.Set(ctx, value, reflectV.Elem().Interface())
} else {
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = fieldValue.Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
// A pointer to the field type implements sql.Scanner: Scan into the
// field's address.
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
return field.Set(ctx, value, reflectV.Elem().Interface())
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
return fallbackSetter(ctx, value, v, field.Set)
}
}
}
}
// With a serializer, Set accepts a *serializer: a carried raw field value is
// stored through the plain setter, otherwise the serializer scans its raw
// value into the destination. Any other value goes to the plain setter.
if field.Serializer != nil {
var (
oldFieldSetter = field.Set
sameElemType bool
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
)
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
}
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
serializerType := serializerValue.Type()
field.Set = func(ctx context.Context, value reflect.Value, v any) (err error) {
if s, ok := v.(*serializer); ok {
if s.fieldValue != nil {
err = oldFieldSetter(ctx, value, s.fieldValue)
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
// When the field's type is the serializer type itself (or its element
// type), the scanned serializer is the field value.
if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
} else if sameType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
}
// Give s a fresh copy of the field's original serializer value.
si := reflect.New(serializerType)
si.Elem().Set(serializerValue)
s.Serializer = si.Interface().(SerializerInterface)
}
} else {
err = oldFieldSetter(ctx, value, v)
}
return
}
}
}
// setupNewValuePool installs field.NewValuePool. A field with a Serializer
// always gets a sync.Pool of *serializer values, each holding its own copy of
// the field's serializer. Any other field gets a pool from poolInitializer
// for pointers to its indirect type, unless a pool is already set.
func (field *Field) setupNewValuePool() {
	if field.Serializer == nil {
		if field.NewValuePool == nil {
			field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
		}
		return
	}

	prototype := reflect.Indirect(reflect.ValueOf(field.Serializer))
	prototypeType := prototype.Type()
	newSerializer := func() any {
		// Copy the prototype so pooled values do not share serializer state.
		fresh := reflect.New(prototypeType)
		fresh.Elem().Set(prototype)
		return &serializer{
			Field:      field,
			Serializer: fresh.Interface().(SerializerInterface),
		}
	}
	field.NewValuePool = &sync.Pool{New: newSerializer}
}
================================================
FILE: gossh/gorm/schema/index.go
================================================
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
)
// Index describes one database index assembled from the `index` / `uniqueIndex`
// settings of one or more fields that share the same index name.
type Index struct {
Name string
Class string // UNIQUE | FULLTEXT | SPATIAL
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
}
// IndexOption is one field's participation in an Index together with the
// per-field options parsed from its tag.
type IndexOption struct {
*Field
Expression string
Sort string // DESC, ASC
Collate string
Length int
// priority orders the fields inside an index (ascending); parseFieldIndexes
// defaults it to 10 when the PRIORITY setting is missing or not a number.
priority int
}
// ParseIndexes parses the `index` / `uniqueIndex` tag settings of every schema
// field and returns the resulting indexes keyed by index name.
//
// Settings from several fields that name the same index are merged: for
// Class, Type, Where, Comment and Option the first non-empty value wins, and
// the fields are ordered by ascending priority. Fields with equal priority
// keep the order in which they appear in schema.Fields. If a field's tag
// cannot be parsed the error is stored in schema.err and parsing stops; the
// indexes collected so far are still returned.
//
// As a side effect, the field of every single-field UNIQUE index gets its
// UniqueIndex set to the index name.
func (schema *Schema) ParseIndexes() map[string]Index {
	indexes := map[string]Index{}
	for _, field := range schema.Fields {
		if field.TagSettings["INDEX"] == "" && field.TagSettings["UNIQUEINDEX"] == "" {
			continue
		}
		fieldIndexes, err := parseFieldIndexes(field)
		if err != nil {
			schema.err = err
			break
		}
		for _, index := range fieldIndexes {
			idx := indexes[index.Name]
			idx.Name = index.Name
			if idx.Class == "" {
				idx.Class = index.Class
			}
			if idx.Type == "" {
				idx.Type = index.Type
			}
			if idx.Where == "" {
				idx.Where = index.Where
			}
			if idx.Comment == "" {
				idx.Comment = index.Comment
			}
			if idx.Option == "" {
				idx.Option = index.Option
			}
			idx.Fields = append(idx.Fields, index.Fields...)
			// A stable sort is required: most fields share the default priority,
			// and column order inside a composite index is significant, so ties
			// must keep field declaration order.
			sort.SliceStable(idx.Fields, func(i, j int) bool {
				return idx.Fields[i].priority < idx.Fields[j].priority
			})
			indexes[index.Name] = idx
		}
	}
	for _, index := range indexes {
		if index.Class == "UNIQUE" && len(index.Fields) == 1 {
			index.Fields[0].Field.UniqueIndex = index.Name
		}
	}
	return indexes
}
// LookIndex returns the index identified by name, or nil if there is none or
// the schema is nil.
//
// name is matched against index names first. If no index has that name, the
// first index (in ascending index-name order) containing a field whose Name
// equals name is returned. The lookup is deterministic: it does not depend on
// map iteration order when name matches more than one index.
//
// The returned pointer refers to a copy of the Index value.
func (schema *Schema) LookIndex(name string) *Index {
	if schema == nil {
		return nil
	}
	indexes := schema.ParseIndexes()

	// ParseIndexes keys the map by index name, so an exact name match is a
	// direct lookup.
	if index, ok := indexes[name]; ok {
		return &index
	}

	// Fall back to matching by field name, visiting indexes in sorted order so
	// that a field belonging to several indexes always resolves to the same one.
	indexNames := make([]string, 0, len(indexes))
	for indexName := range indexes {
		indexNames = append(indexNames, indexName)
	}
	sort.Strings(indexNames)
	for _, indexName := range indexNames {
		index := indexes[indexName]
		for _, field := range index.Fields {
			if field.Name == name {
				return &index
			}
		}
	}
	return nil
}
// parseFieldIndexes extracts every `index` / `uniqueIndex` declaration from
// the field's `gorm` tag and returns one single-field Index per declaration.
//
// A declaration has the form `index:NAME,opt1,opt2:val,...`; everything before
// the first comma is the index name. Without a name, one is generated by the
// schema's namer from the table name and either the field name or the
// `composite` option. An empty `composite` option is an error; the indexes
// parsed before it are returned alongside the error.
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
	for _, part := range strings.Split(field.Tag.Get("gorm"), ";") {
		if part == "" {
			continue
		}
		segments := strings.Split(part, ":")
		kind := strings.TrimSpace(strings.ToUpper(segments[0]))
		if kind != "INDEX" && kind != "UNIQUEINDEX" {
			continue
		}

		// Split "NAME,options..." at the first comma.
		tag := strings.Join(segments[1:], ":")
		name, options := tag, ""
		if comma := strings.Index(tag, ","); comma != -1 {
			name, options = tag[:comma], tag[comma+1:]
		}
		settings := ParseTagSetting(options, ",")
		length, _ := strconv.Atoi(settings["LENGTH"])

		if name == "" {
			const key = "COMPOSITE"
			subName := field.Name
			if composite, found := settings[key]; found {
				if len(composite) == 0 || composite == key {
					return indexes, fmt.Errorf(
						"The composite tag of %s.%s cannot be empty",
						field.Schema.Name,
						field.Name)
				}
				subName = composite
			}
			name = field.Schema.namer.IndexName(field.Schema.Table, subName)
		}

		if kind == "UNIQUEINDEX" || settings["UNIQUE"] != "" {
			settings["CLASS"] = "UNIQUE"
		}

		// A missing or non-numeric priority falls back to 10.
		priority, convErr := strconv.Atoi(settings["PRIORITY"])
		if convErr != nil {
			priority = 10
		}

		option := IndexOption{
			Field:      field,
			Expression: settings["EXPRESSION"],
			Sort:       settings["SORT"],
			Collate:    settings["COLLATE"],
			Length:     length,
			priority:   priority,
		}
		indexes = append(indexes, Index{
			Name:    name,
			Class:   settings["CLASS"],
			Type:    settings["TYPE"],
			Where:   settings["WHERE"],
			Comment: settings["COMMENT"],
			Option:  settings["OPTION"],
			Fields:  []IndexOption{option},
		})
	}
	return indexes, nil
}
================================================
FILE: gossh/gorm/schema/interfaces.go
================================================
package schema
import (
"gossh/gorm/clause"
)
// ConstraintInterface is the common interface of database constraints: a
// constraint reports its name and builds the SQL fragment, plus the variables
// that fragment refers to, used to declare it.
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []any)
}
// GormDataTypeInterface lets a Go type declare its own GORM data type.
// ParseField uses the returned string as the field's DataType in place of the
// type inferred from the reflect kind.
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool is a pool of reusable scan values for a field. It has the
// same Get/Put shape as *sync.Pool, which is assigned to it in setupNewValuePool.
type FieldNewValuePool interface {
Get() any
Put(any)
}
// CreateClausesInterface is implemented by types that supply clauses for the
// given field on create.
// NOTE(review): presumably consulted when building create statements — the
// call site is not in this file; confirm.
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}
================================================
FILE: gossh/gorm/schema/naming.go
================================================
package schema
import (
"crypto/sha1"
"encoding/hex"
"regexp"
"strings"
"unicode/utf8"
"gossh/gorm/inflection"
)
// Namer derives database identifiers (tables, columns, constraints, indexes)
// from Go names. NamingStrategy is the implementation provided in this file.
type Namer interface {
// TableName returns the table name for a struct name.
TableName(table string) string
// SchemaName returns a struct-style name for a table name.
SchemaName(table string) string
// ColumnName returns the column name for a field of the given table.
ColumnName(table, column string) string
// JoinTableName returns the table name for a many-to-many join table.
JoinTableName(joinTable string) string
// RelationshipFKName returns the foreign-key constraint name for a relationship.
RelationshipFKName(Relationship) string
// CheckerName returns the check-constraint name for a table column.
CheckerName(table, column string) string
// IndexName returns the index name for a table column.
IndexName(table, column string) string
// UniqueName returns the unique-constraint name for a table column.
UniqueName(table, column string) string
}
// Replacer rewrites a name before it is converted to a database name.
// *strings.Replacer satisfies this interface.
type Replacer interface {
Replace(name string) string
}
// Compile-time check that *NamingStrategy implements Namer.
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy is the default Namer: it controls how table, column and
// constraint names are derived from Go identifiers.
type NamingStrategy struct {
// TablePrefix is prepended to every generated table name and stripped again by SchemaName.
TablePrefix string
// SingularTable disables pluralization of table names when true.
SingularTable bool
// NameReplacer, when non-nil, is applied to a name before it is converted to a database name.
NameReplacer Replacer
// NoLowerCase makes toDBName return the (replaced) name as is, without snake-casing it.
NoLowerCase bool
// IdentifierMaxLength caps the length of generated constraint/index names; 0 means 64.
IdentifierMaxLength int
}
// TableName returns the table name for the given struct name: the database
// form of str, pluralized unless SingularTable is set, with TablePrefix in
// front.
func (ns NamingStrategy) TableName(str string) string {
	name := ns.toDBName(str)
	if !ns.SingularTable {
		name = inflection.Plural(name)
	}
	return ns.TablePrefix + name
}
// SchemaName derives a struct-style name from a table name. TablePrefix is
// removed first and the remainder is singularized unless SingularTable is set.
// The result is not guaranteed to be the inverse of TableName.
func (ns NamingStrategy) SchemaName(table string) string {
	base := strings.TrimPrefix(table, ns.TablePrefix)
	if !ns.SingularTable {
		base = inflection.Singular(base)
	}
	return ns.toSchemaName(base)
}
// ColumnName returns the database column name for a field name. The table
// argument is part of the Namer interface but is not used by this strategy.
func (ns NamingStrategy) ColumnName(table, column string) string {
	dbName := ns.toDBName(column)
	return dbName
}
// JoinTableName returns the table name for a many-to-many join table.
// A name that is already entirely lower case is kept verbatim (unless
// NoLowerCase is set); any other name goes through the same conversion as
// TableName. TablePrefix is always prepended.
func (ns NamingStrategy) JoinTableName(str string) string {
	var name string
	switch {
	case !ns.NoLowerCase && strings.ToLower(str) == str:
		name = str
	case ns.SingularTable:
		name = ns.toDBName(str)
	default:
		name = inflection.Plural(ns.toDBName(str))
	}
	return ns.TablePrefix + name
}
// RelationshipFKName returns the foreign-key constraint name for rel, built
// from the "fk" prefix, the owning table and the relation's database name.
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
	relName := ns.toDBName(rel.Name)
	return ns.formatName("fk", rel.Schema.Table, relName)
}
// CheckerName returns the check-constraint name for a column, built from the
// "chk" prefix, the table and the column exactly as given.
func (ns NamingStrategy) CheckerName(table, column string) string {
	const prefix = "chk"
	return ns.formatName(prefix, table, column)
}
// IndexName returns the index name for a column, built from the "idx"
// prefix, the table and the database form of the column name.
func (ns NamingStrategy) IndexName(table, column string) string {
	dbColumn := ns.toDBName(column)
	return ns.formatName("idx", table, dbColumn)
}
// UniqueName returns the unique-constraint name for a column, built from the
// "uni" prefix, the table and the database form of the column name.
func (ns NamingStrategy) UniqueName(table, column string) string {
	dbColumn := ns.toDBName(column)
	return ns.formatName("uni", table, dbColumn)
}
// formatName joins prefix, table and name with underscores, replaces every
// "." with "_" and, when the result has more runes than IdentifierMaxLength
// (64 when unset), shortens it to a truncated head followed by the first
// eight hex digits of the SHA-1 of the full name.
//
// The head is cut at IdentifierMaxLength-8 bytes and then moved back to the
// nearest rune boundary, so a multi-byte character is never split and the
// result is always valid UTF-8. For names whose cut point already falls on a
// rune boundary (all ASCII names in particular) the output is unchanged from
// the previous byte-slicing behavior. A limit smaller than eight keeps only
// the hash instead of panicking on a negative slice bound.
func (ns NamingStrategy) formatName(prefix, table, name string) string {
	formattedName := strings.ReplaceAll(strings.Join([]string{
		prefix, table, name,
	}, "_"), ".", "_")

	if ns.IdentifierMaxLength == 0 {
		ns.IdentifierMaxLength = 64
	}

	if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
		h := sha1.New()
		h.Write([]byte(formattedName))
		bs := h.Sum(nil)

		// The rune count exceeds the limit, so the byte length does too and
		// cut is always a valid index into formattedName.
		cut := ns.IdentifierMaxLength - 8
		if cut < 0 {
			cut = 0
		}
		// Step back off UTF-8 continuation bytes so the head ends on a
		// complete character.
		for cut > 0 && !utf8.RuneStart(formattedName[cut]) {
			cut--
		}

		formattedName = formattedName[:cut] + hex.EncodeToString(bs)[:8]
	}
	return formattedName
}
var (
	// commonInitialisms lists the initialisms that are treated as a single
	// word when converting names. The list follows golint:
	// https://github.com/golang/lint/blob/master/lint.go#L770
	commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}

	// commonInitialismsReplacer rewrites each initialism to its title-cased
	// form (for example "HTTP" to "Http"); it is built once in init.
	commonInitialismsReplacer *strings.Replacer
)

// init builds commonInitialismsReplacer from commonInitialisms, pairing every
// initialism with its title-cased spelling in list order.
func init() {
	pairs := make([]string, 0, 2*len(commonInitialisms))
	for i := range commonInitialisms {
		upper := commonInitialisms[i]
		pairs = append(pairs, upper, strings.Title(strings.ToLower(upper)))
	}
	commonInitialismsReplacer = strings.NewReplacer(pairs...)
}
// toDBName converts a Go identifier into its database name.
//
// An empty name yields "". When NameReplacer is set it is applied first; when
// NoLowerCase is set the (replaced) name is returned without further changes.
// Otherwise common initialisms are title-cased so that they count as a single
// word, every ASCII upper-case letter is lowered, and an underscore is
// inserted in front of a letter that starts a new word.
func (ns NamingStrategy) toDBName(name string) string {
if name == "" {
return ""
}
if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name)
// An empty replacement is ignored: the original name is returned as is,
// without any of the conversion below.
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
return name
}
var (
// e.g. "HTTP" becomes "Http" so the initialism is handled as one word.
value = commonInitialismsReplacer.Replace(name)
buf strings.Builder
lastCase, nextCase, nextNumber bool // upper case == true
curCase = value[0] <= 'Z' && value[0] >= 'A'
)
// Handle every character except the last one; the look-ahead at value[i+1]
// is why the final character is treated separately after the loop.
// NOTE(review): range yields runes while the look-ahead/look-behind index
// bytes; this lines up for ASCII input - confirm behavior for multi-byte names.
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
if curCase {
if lastCase && (nextCase || nextNumber) {
// Inside a run of capitals (or capitals followed by a digit): lower only.
buf.WriteRune(v + 32)
} else {
// Start of a new word: separate with '_' unless at the beginning or
// already next to an underscore.
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
buf.WriteByte('_')
}
buf.WriteRune(v + 32)
}
} else {
buf.WriteRune(v)
}
lastCase = curCase
curCase = nextCase
}
// Final character: a capital that follows a non-capital gets a separator.
if curCase {
if !lastCase && len(value) > 1 {
buf.WriteByte('_')
}
buf.WriteByte(value[len(value)-1] + 32)
} else {
buf.WriteByte(value[len(value)-1])
}
ret := buf.String()
return ret
}
// schemaNameInitialismRule restores one initialism in a schema name: pattern
// matches its title-cased spelling when followed by an upper-case letter, an
// underscore or the end of the string, and replacement is the upper-case
// spelling followed by that captured character.
type schemaNameInitialismRule struct {
	pattern     *regexp.Regexp
	replacement string
}

// schemaNameInitialismRules holds one rule per entry of commonInitialisms, in
// list order. The patterns are compiled once here instead of on every call
// to toSchemaName.
var schemaNameInitialismRules = func() []schemaNameInitialismRule {
	rules := make([]schemaNameInitialismRule, 0, len(commonInitialisms))
	for _, initialism := range commonInitialisms {
		rules = append(rules, schemaNameInitialismRule{
			pattern:     regexp.MustCompile(strings.Title(strings.ToLower(initialism)) + "([A-Z]|$|_)"),
			replacement: initialism + "$1",
		})
	}
	return rules
}()

// toSchemaName converts a snake_case name into a Go-style name: each
// underscore-separated word is title-cased and joined, then common
// initialisms are restored to upper case (for example "user_id" becomes
// "UserID").
func (ns NamingStrategy) toSchemaName(name string) string {
	result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
	for _, rule := range schemaNameInitialismRules {
		result = rule.pattern.ReplaceAllString(result, rule.replacement)
	}
	return result
}
================================================
FILE: gossh/gorm/schema/pool.go
================================================
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
	// normalPool maps a reflect.Type to the *sync.Pool that hands out new
	// pointers to values of that type.
	normalPool sync.Map

	// poolInitializer returns the shared pool for reflectType, creating it on
	// first use. The plain Load up front avoids allocating a throwaway
	// sync.Pool and closure on every call once the type is cached;
	// LoadOrStore still decides the winner when several goroutines race to
	// create the pool for the same type.
	poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
		if cached, ok := normalPool.Load(reflectType); ok {
			return cached.(FieldNewValuePool)
		}
		v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
			New: func() any {
				return reflect.New(reflectType).Interface()
			},
		})
		return v.(FieldNewValuePool)
	}
)
================================================
FILE: gossh/gorm/schema/relationship.go
================================================
package schema
import (
"context"
"fmt"
"reflect"
"strings"
"gossh/gorm/clause"
"gossh/gorm/inflection"
)
// RelationshipType identifies the kind of association between two schemas.
type RelationshipType string
const (
HasOne RelationshipType = "has_one" // HasOne is a has-one relationship
HasMany RelationshipType = "has_many" // HasMany is a has-many relationship
BelongsTo RelationshipType = "belongs_to" // BelongsTo is a belongs-to relationship
Many2Many RelationshipType = "many_to_many" // Many2Many is a many-to-many relationship
// has is an internal placeholder; parseRelation turns it into HasOne or
// HasMany depending on whether the field is a struct or a slice.
has RelationshipType = "has"
)
// Relationships collects the relationships of a schema, both grouped by type
// and indexed by name.
type Relationships struct {
HasOne []*Relationship
BelongsTo []*Relationship
HasMany []*Relationship
Many2Many []*Relationship
// Relations indexes relationships by name.
Relations map[string]*Relationship
// EmbeddedRelations holds the relationships of embedded structs, keyed by
// each element of the field's bind-name path (see setRelation).
EmbeddedRelations map[string]*Relationships
}
// Relationship describes one association declared by a struct field.
type Relationship struct {
// Name is the name of the field that declares the relationship.
Name string
Type RelationshipType
// Field is the struct field that declares the relationship.
Field *Field
// Polymorphic is set only for polymorphic relationships.
Polymorphic *Polymorphic
// References pairs primary keys (or constant values) with foreign keys.
References []*Reference
// Schema is the schema that owns the field.
Schema *Schema
// FieldSchema is the schema parsed from the field's type.
FieldSchema *Schema
// JoinTable is the join-table schema; set only for many-to-many relationships.
JoinTable *Schema
// foreignKeys and primaryKeys come from the FOREIGNKEY and REFERENCES tag settings.
foreignKeys, primaryKeys []string
}
// Polymorphic holds the settings of a polymorphic relationship.
type Polymorphic struct {
// PolymorphicID is the field on the related schema that stores the owner's key.
PolymorphicID *Field
// PolymorphicType is the field on the related schema that stores the owner's type value.
PolymorphicType *Field
// Value is the type value stored for this owner; it defaults to the owner's
// table name and can be overridden with the POLYMORPHICVALUE tag setting.
Value string
}
// Reference links one foreign-key field to either a primary-key field or a
// constant value.
type Reference struct {
// PrimaryKey is the referenced field; nil when PrimaryValue is used instead.
PrimaryKey *Field
// PrimaryValue is a constant the foreign key must equal (used for the
// polymorphic type column).
PrimaryValue string
ForeignKey *Field
// OwnPrimaryKey reports that PrimaryKey belongs to the schema that declares
// the relationship rather than to the related schema.
OwnPrimaryKey bool
}
// parseRelation builds the Relationship declared by field and registers it on
// schema.
//
// The schema of the field's type is parsed first; if that fails the error is
// stored in schema.err and nil is returned. The relationship kind is then
// taken from the tag settings (polymorphic, many2many, belongsTo) or, without
// such a tag, guessed from the field's kind. On any later error schema.err is
// set and the relation is returned without being registered.
func (schema *Schema) parseRelation(field *Field) *Relationship {
var (
err error
fieldValue = reflect.New(field.IndirectFieldType).Interface()
relation = &Relationship{
Name: field.Name,
Field: field,
Schema: schema,
foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
}
)
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
return nil
}
// Explicit tag settings take precedence over guessing.
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
schema.guessRelation(relation, field, guessBelongs)
} else {
// No tag: a struct field may be has-one or belongs-to, a slice field is
// treated as has-many.
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
schema.guessRelation(relation, field, guessGuess)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
}
}
// Resolve the internal "has" placeholder into HasOne or HasMany.
if relation.Type == has {
// don't add relations to embedded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
// Record the relation on the related schema under a "_<Schema>_<Name>" key.
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
}
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
relation.Type = HasOne
case reflect.Slice:
relation.Type = HasMany
}
}
// Register the relation only when nothing went wrong.
if schema.err == nil {
schema.setRelation(relation)
switch relation.Type {
case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
case HasMany:
schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
case BelongsTo:
schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
case Many2Many:
schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
}
}
return relation
}
// hasPolymorphicRelation reports whether the tag settings declare a
// polymorphic relation, which is the case when either
//  1. the `POLYMORPHIC` key is present, or
//  2. both the `POLYMORPHICTYPE` and `POLYMORPHICID` keys are present.
//
// Only the presence of the keys matters, not their values.
func hasPolymorphicRelation(tagSettings map[string]string) bool {
	present := func(key string) bool {
		_, ok := tagSettings[key]
		return ok
	}
	return present("POLYMORPHIC") || (present("POLYMORPHICTYPE") && present("POLYMORPHICID"))
}
// setRelation registers relation on the schema.
//
// It is always considered for the top-level Relations map: it is stored there
// when no relation of that name exists yet, or when the existing one was
// declared through an embedded field (bind path longer than one element).
// A relation that itself comes from an embedded field is additionally stored
// in the EmbeddedRelations tree, following the field's bind path and creating
// intermediate nodes as needed.
func (schema *Schema) setRelation(relation *Relationship) {
	existing := schema.Relationships.Relations[relation.Name]
	if existing == nil || len(existing.Field.BindNames) > 1 {
		schema.Relationships.Relations[relation.Name] = relation
	}

	path := relation.Field.BindNames
	if len(path) <= 1 {
		// Not an embedded relation; nothing more to record.
		return
	}

	// Walk every path element but the last, descending into (and creating)
	// the nested Relationships nodes.
	node := &schema.Relationships
	for _, segment := range path[:len(path)-1] {
		if node.EmbeddedRelations == nil {
			node.EmbeddedRelations = map[string]*Relationships{}
		}
		child := node.EmbeddedRelations[segment]
		if child == nil {
			child = &Relationships{}
			node.EmbeddedRelations[segment] = child
		}
		node = child
	}

	// The innermost node receives the relation itself.
	if node.Relations == nil {
		node.Relations = map[string]*Relationship{}
	}
	node.Relations[relation.Name] = relation
}
// buildPolymorphicRelation configures relation as a polymorphic "has"
// relationship declared by field.
//
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
//	type User struct {
//	  Toys []Toy `gorm:"polymorphic:Owner;"`
//	}
//	type Pet struct {
//	  Toy Toy `gorm:"polymorphic:Owner;"`
//	}
//	type Toy struct {
//	  OwnerID   int
//	  OwnerType string
//	}
//
// The type and ID fields on the related schema default to "<polymorphic>Type"
// and "<polymorphic>ID"; the POLYMORPHICTYPE and POLYMORPHICID tag settings
// override them, and POLYMORPHICVALUE overrides the stored type value, which
// otherwise is the owner's table name. Problems are reported through
// schema.err; the "missing field" errors name the field that was actually
// looked up, including overridden names.
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
	polymorphic := field.TagSettings["POLYMORPHIC"]

	relation.Polymorphic = &Polymorphic{
		Value: schema.Table,
	}

	var (
		typeName = polymorphic + "Type"
		typeID   = polymorphic + "ID"
	)

	if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
		typeName = strings.TrimSpace(value)
	}

	if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
		typeID = strings.TrimSpace(value)
	}

	relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
	relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeID]

	if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
		relation.Polymorphic.Value = strings.TrimSpace(value)
	}

	if relation.Polymorphic.PolymorphicType == nil {
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
			relation.FieldSchema, schema, field.Name, typeName)
	}

	if relation.Polymorphic.PolymorphicID == nil {
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
			relation.FieldSchema, schema, field.Name, typeID)
	}

	if schema.err == nil {
		// The type column must equal the constant polymorphic value.
		relation.References = append(relation.References, &Reference{
			PrimaryValue: relation.Polymorphic.Value,
			ForeignKey:   relation.Polymorphic.PolymorphicType,
		})

		// The ID column references the owner's primary key, or the single
		// field named by the foreignKey tag setting.
		primaryKeyField := schema.PrioritizedPrimaryField
		if len(relation.foreignKeys) > 0 {
			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
					schema, field.Name)
			}
		}

		if primaryKeyField == nil {
			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
				relation.FieldSchema, schema, field.Name)
			return
		}

		// use same data type for foreign keys
		if copyableDataType(primaryKeyField.DataType) {
			relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
		}
		relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
		if relation.Polymorphic.PolymorphicID.Size == 0 {
			relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
		}

		relation.References = append(relation.References, &Reference{
			PrimaryKey:    primaryKeyField,
			ForeignKey:    relation.Polymorphic.PolymorphicID,
			OwnPrimaryKey: true,
		})
	}

	relation.Type = has
}
// buildMany2ManyRelation configures relation as a many-to-many relationship
// through the join table named many2many.
//
// The join table's struct type is synthesized with reflect.StructOf: one
// field per key of the owning schema and one per key of the related schema.
// Keys default to the primary fields and can be overridden with the
// foreignKey / references tag settings; the join column names can be
// overridden with joinForeignKey / joinReferences. The synthesized type is
// parsed into relation.JoinTable, and the references of the relation and of
// the join table's two belongs-to relations are derived from it.
//
// Errors are reported through schema.err. If parsing the join table yields no
// schema at all, the function returns right after recording the error instead
// of dereferencing a nil JoinTable.
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
	relation.Type = Many2Many

	var (
		err             error
		joinTableFields []reflect.StructField
		fieldsMap       = map[string]*Field{}
		ownFieldsMap    = map[string]*Field{} // fix self join many2many
		referFieldsMap  = map[string]*Field{}
		joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
		joinReferences  = toColumns(field.TagSettings["JOINREFERENCES"])
	)

	ownForeignFields := schema.PrimaryFields
	refForeignFields := relation.FieldSchema.PrimaryFields

	// Explicit foreignKey setting: keys on the owning schema.
	if len(relation.foreignKeys) > 0 {
		ownForeignFields = []*Field{}
		for _, foreignKey := range relation.foreignKeys {
			if field := schema.LookUpField(foreignKey); field != nil {
				ownForeignFields = append(ownForeignFields, field)
			} else {
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
				return
			}
		}
	}

	// Explicit references setting: keys on the related schema.
	if len(relation.primaryKeys) > 0 {
		refForeignFields = []*Field{}
		for _, foreignKey := range relation.primaryKeys {
			if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
				refForeignFields = append(refForeignFields, field)
			} else {
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
				return
			}
		}
	}

	// Join-table fields that point back at the owning schema.
	for idx, ownField := range ownForeignFields {
		joinFieldName := strings.Title(schema.Name) + ownField.Name
		if len(joinForeignKeys) > idx {
			joinFieldName = strings.Title(joinForeignKeys[idx])
		}

		ownFieldsMap[joinFieldName] = ownField
		fieldsMap[joinFieldName] = ownField
		joinTableFields = append(joinTableFields, reflect.StructField{
			Name:    joinFieldName,
			PkgPath: ownField.StructField.PkgPath,
			Type:    ownField.StructField.Type,
			Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
				"column", "autoincrement", "index", "unique", "uniqueindex"),
		})
	}

	// Join-table fields that point at the related schema.
	for idx, relField := range refForeignFields {
		joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name

		// Self-referencing join: pick a different name for the second side.
		if _, ok := ownFieldsMap[joinFieldName]; ok {
			if field.Name != relation.FieldSchema.Name {
				joinFieldName = inflection.Singular(field.Name) + relField.Name
			} else {
				joinFieldName += "Reference"
			}
		}

		if len(joinReferences) > idx {
			joinFieldName = strings.Title(joinReferences[idx])
		}

		referFieldsMap[joinFieldName] = relField

		if _, ok := fieldsMap[joinFieldName]; !ok {
			fieldsMap[joinFieldName] = relField
			joinTableFields = append(joinTableFields, reflect.StructField{
				Name:    joinFieldName,
				PkgPath: relField.StructField.PkgPath,
				Type:    relField.StructField.Type,
				Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
					"column", "autoincrement", "index", "unique", "uniqueindex"),
			})
		}
	}

	// An ignored (`gorm:"-"`) field typed as the owning model is appended to
	// the synthesized struct.
	joinTableFields = append(joinTableFields, reflect.StructField{
		Name: strings.Title(schema.Name) + field.Name,
		Type: schema.ModelType,
		Tag:  `gorm:"-"`,
	})

	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
		schema.namer); err != nil {
		schema.err = err
	}
	if relation.JoinTable == nil {
		// Parsing produced no schema; there is nothing left to configure.
		return
	}
	relation.JoinTable.Name = many2many
	relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
	relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))

	relName := relation.Schema.Name
	relRefName := relation.FieldSchema.Name
	if relName == relRefName {
		relRefName = relation.Field.Name
	}

	// Make sure the join table has a belongs-to relation for each side,
	// starting from empty references.
	if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
		relation.JoinTable.Relationships.Relations[relName] = &Relationship{
			Name:        relName,
			Type:        BelongsTo,
			Schema:      relation.JoinTable,
			FieldSchema: relation.Schema,
		}
	} else {
		relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
	}

	if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
		relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
			Name:        relRefName,
			Type:        BelongsTo,
			Schema:      relation.JoinTable,
			FieldSchema: relation.FieldSchema,
		}
	} else {
		relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
	}

	// build references
	for _, f := range relation.JoinTable.Fields {
		if f.Creatable || f.Readable || f.Updatable {
			// use same data type for foreign keys
			if copyableDataType(fieldsMap[f.Name].DataType) {
				f.DataType = fieldsMap[f.Name].DataType
			}
			f.GORMDataType = fieldsMap[f.Name].GORMDataType
			if f.Size == 0 {
				f.Size = fieldsMap[f.Name].Size
			}
			relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)

			if of, ok := ownFieldsMap[f.Name]; ok {
				joinRel := relation.JoinTable.Relationships.Relations[relName]
				joinRel.Field = relation.Field
				joinRel.References = append(joinRel.References, &Reference{
					PrimaryKey: of,
					ForeignKey: f,
				})

				relation.References = append(relation.References, &Reference{
					PrimaryKey:    of,
					ForeignKey:    f,
					OwnPrimaryKey: true,
				})
			}

			if rf, ok := referFieldsMap[f.Name]; ok {
				joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
				if joinRefRel.Field == nil {
					joinRefRel.Field = relation.Field
				}
				joinRefRel.References = append(joinRefRel.References, &Reference{
					PrimaryKey: rf,
					ForeignKey: f,
				})

				relation.References = append(relation.References, &Reference{
					PrimaryKey: rf,
					ForeignKey: f,
				})
			}
		}
	}
}
// guessLevel selects which interpretation guessRelation tries for a field.
// When an attempt fails, guessRelation retries with the next level in the
// order listed below and reports an error after guessEmbeddedHas.
type guessLevel int
const (
// guessGuess: decide between belongs-to and has from the field's schema.
guessGuess guessLevel = iota
// guessBelongs: belongs-to, foreign key on the declaring schema.
guessBelongs
// guessEmbeddedBelongs: belongs-to, foreign key on the field's owner schema.
guessEmbeddedBelongs
// guessHas: has, foreign key on the related schema.
guessHas
// guessEmbeddedHas: has, primary key on the field's owner schema.
guessEmbeddedHas
)
// guessRelation tries to work out the references of relation for field under
// the interpretation given by cgl.
//
// It picks which schema holds the primary keys and which the foreign keys,
// resolves the foreign-key fields (from the foreignKey tag setting or by
// conventional names such as "<Name><PrimaryKey>" and "<Name>ID"), resolves
// the matching primary-key fields, and on success fills relation.References
// and sets relation.Type to has or BelongsTo. When an attempt fails it calls
// itself with the next guess level; once all levels are exhausted schema.err
// is set.
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
gl = cgl
)
// guessGuess starts as belongs-to for a self-referencing field, otherwise as has.
if gl == guessGuess {
if field.Schema == relation.FieldSchema {
gl = guessBelongs
} else {
gl = guessHas
}
}
// reguessOrErr retries with the level after cgl (the level this call was
// entered with), or records an error when there is none left.
reguessOrErr := func() {
switch cgl {
case guessGuess:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}
// Choose the primary-key side and the foreign-key side for this level.
switch gl {
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
// Foreign keys were named explicitly; every one must exist on the foreign side.
for _, foreignKey := range relation.foreignKeys {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
// No explicit foreign keys: derive candidate names from the primary keys.
primarySchemaName := primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
primaryFields = append(primaryFields, f)
}
}
} else {
primaryFields = primarySchema.PrimaryFields
}
// NOTE(review): primaryFields may alias primarySchema.PrimaryFields here and
// the matches below append to it while it is being ranged over - confirm
// this cannot write into the schema's backing array or misalign
// primaryFields[idx] with foreignFields[idx] further down.
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
lookUpNames := []string{lookUpName}
// With a single primary key, also accept "<prefix>ID", "<prefix>Id" and
// the column-name form of "<prefix>ID".
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
// Prefer a field found through the declaring field's bind path ...
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
// ... and fall back to a lookup by DB name or field name.
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
}
}
// Reconcile the primary fields with the foreign fields that were found.
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
case len(relation.primaryKeys) > 0:
// Explicit references must exist and agree with what was resolved so far.
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
primaryFields = append(primaryFields, f)
} else if f != primaryFields[idx] {
reguessOrErr()
return
}
} else {
reguessOrErr()
return
}
}
case len(primaryFields) == 0:
// Only foreign keys are known: pair them with the primary side's primary key(s).
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
} else {
reguessOrErr()
return
}
}
// build references
for idx, foreignField := range foreignFields {
// use same data type for foreign keys
if copyableDataType(primaryFields[idx].DataType) {
foreignField.DataType = primaryFields[idx].DataType
}
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = has
} else {
relation.Type = BelongsTo
}
}
// Constraint describes a foreign-key constraint derived from a relationship.
type Constraint struct {
Name string
// Field is the struct field whose relationship produced the constraint.
Field *Field
// Schema is the schema that holds the foreign-key columns.
Schema *Schema
ForeignKeys []*Field
// ReferenceSchema is the schema whose columns are referenced.
ReferenceSchema *Schema
References []*Field
// OnDelete and OnUpdate are appended after "ON DELETE" / "ON UPDATE" when non-empty.
OnDelete string
OnUpdate string
}
// GetName returns the name of the foreign-key constraint.
func (constraint *Constraint) GetName() string {
	return constraint.Name
}
// Build renders the constraint as a SQL fragment with placeholders and
// returns the values for those placeholders, in order: the constraint name,
// the foreign-key columns, the referenced table and the referenced columns.
// ON DELETE / ON UPDATE actions are appended to the SQL text when set.
func (constraint *Constraint) Build() (sql string, vars []any) {
	// columnsOf turns fields into a list of column expressions.
	columnsOf := func(fields []*Field) []any {
		columns := make([]any, 0, len(fields))
		for _, f := range fields {
			columns = append(columns, clause.Column{Name: f.DBName})
		}
		return columns
	}

	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
	if action := constraint.OnDelete; action != "" {
		sql += " ON DELETE " + action
	}
	if action := constraint.OnUpdate; action != "" {
		sql += " ON UPDATE " + action
	}

	foreignKeys := columnsOf(constraint.ForeignKeys)
	references := columnsOf(constraint.References)

	vars = []any{
		clause.Table{Name: constraint.Name},
		foreignKeys,
		clause.Table{Name: constraint.ReferenceSchema.Table},
		references,
	}
	return sql, vars
}
// ParseConstraint derives the foreign-key constraint for the relationship
// from the field's CONSTRAINT tag setting.
//
// It returns nil when the setting is "-", and also for a belongs-to
// relationship when the related schema already has another relationship back
// to this schema with exactly the same references. The constraint name is
// the part of the setting before the first comma when it matches
// regEnLetterAndMidline; otherwise the namer generates it.
func (rel *Relationship) ParseConstraint() *Constraint {
	tag := rel.Field.TagSettings["CONSTRAINT"]
	if tag == "-" {
		return nil
	}

	if rel.Type == BelongsTo {
		for _, other := range rel.FieldSchema.Relationships.Relations {
			if other == rel || other.FieldSchema != rel.Schema || len(rel.References) != len(other.References) {
				continue
			}

			identical := true
			for i, otherRef := range other.References {
				ownRef := rel.References[i]
				if ownRef.PrimaryKey != otherRef.PrimaryKey ||
					ownRef.ForeignKey != otherRef.ForeignKey ||
					ownRef.PrimaryValue != otherRef.PrimaryValue {
					identical = false
				}
			}

			if identical {
				return nil
			}
		}
	}

	settings := ParseTagSetting(tag, ",")

	// regEnLetterAndMidline is compiled once outside this function because
	// ParseConstraint is typically called in loops.
	var name string
	if comma := strings.Index(tag, ","); comma != -1 && regEnLetterAndMidline.MatchString(tag[:comma]) {
		name = tag[:comma]
	} else {
		name = rel.Schema.namer.RelationshipFKName(*rel)
	}

	constraint := &Constraint{
		Name:     name,
		Field:    rel.Field,
		OnUpdate: settings["ONUPDATE"],
		OnDelete: settings["ONDELETE"],
	}

	for _, ref := range rel.References {
		// Skip value-only references and, for join tables, references that
		// are not on the owning side.
		if ref.PrimaryKey == nil || (rel.JoinTable != nil && !ref.OwnPrimaryKey) {
			continue
		}

		constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
		constraint.References = append(constraint.References, ref.PrimaryKey)

		if ref.OwnPrimaryKey {
			constraint.Schema = ref.ForeignKey.Schema
			constraint.ReferenceSchema = rel.Schema
		} else {
			constraint.Schema = rel.Schema
			constraint.ReferenceSchema = ref.PrimaryKey.Schema
		}
	}

	return constraint
}
// ToQueryConditions builds the conditions that select the rows related to
// reflectValue through this relationship.
//
// Key-based references are collected and turned into a single IN condition
// on the join table (when there is one) or on the related table; that IN
// condition is always appended last. References that carry a constant value
// become equality conditions, and for join tables the references to the
// related schema become column-to-column equalities.
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
	table := rel.FieldSchema.Table
	viaJoinTable := rel.JoinTable != nil
	if viaJoinTable {
		table = rel.JoinTable.Table
	}

	foreignFields := []*Field{}
	relForeignKeys := []string{}

	for _, ref := range rel.References {
		switch {
		case ref.OwnPrimaryKey:
			// Match the foreign-key column against this side's primary key values.
			foreignFields = append(foreignFields, ref.PrimaryKey)
			relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
		case ref.PrimaryValue != "":
			// Constant reference, e.g. the polymorphic type column.
			conds = append(conds, clause.Eq{
				Column: clause.Column{Table: table, Name: ref.ForeignKey.DBName},
				Value:  ref.PrimaryValue,
			})
		case viaJoinTable:
			// Tie the join table to the related table.
			conds = append(conds, clause.Eq{
				Column: clause.Column{Table: table, Name: ref.ForeignKey.DBName},
				Value:  clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
			})
		default:
			// Match the related primary-key column against this side's foreign key values.
			relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
			foreignFields = append(foreignFields, ref.ForeignKey)
		}
	}

	_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
	column, values := ToQueryValues(table, relForeignKeys, foreignValues)

	conds = append(conds, clause.IN{Column: column, Values: values})
	return conds
}
// copyableDataType reports whether a primary key's data type may be copied
// onto a foreign key: it must mention neither "auto_increment" nor
// "primary key" (compared case-insensitively).
func copyableDataType(str DataType) bool {
	lowered := strings.ToLower(string(str))
	return !strings.Contains(lowered, "auto_increment") && !strings.Contains(lowered, "primary key")
}
================================================
FILE: gossh/gorm/schema/schema.go
================================================
package schema
import (
"context"
"errors"
"fmt"
"go/ast"
"reflect"
"strings"
"sync"
"gossh/gorm/clause"
"gossh/gorm/logger"
)
// callbackType names a model lifecycle callback. The values mirror the
// Before*/After* flags on Schema.
// NOTE(review): presumably these are the method names looked up on models
// while parsing - confirm against the rest of ParseWithSpecialTableName.
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType is the error wrapped by the schema parser when the
// value to parse is nil or does not resolve to a struct type.
var ErrUnsupportedDataType = errors.New("unsupported data type")
// Schema is the parsed description of a model struct: its table, its fields
// under several indexes, its relationships and its callbacks.
type Schema struct {
// Name is the model type's name.
Name string
ModelType reflect.Type
// Table is the database table name.
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
PrimaryFieldDBNames []string
Fields []*Field
// FieldsByName, FieldsByBindName and FieldsByDBName index the fields by Go
// name, by bind path and by column name.
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
// One flag per callbackType value.
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
// err is the error recorded while parsing; it is returned together with a cached schema.
err error
// initialized is closed once parsing has finished; readers of a cached
// schema wait on it.
initialized chan struct{}
namer Namer
// cacheStore is the cache this schema was parsed with.
cacheStore *sync.Map
}
// String returns a readable identifier for the schema: "<pkgpath>.<type>"
// for named model types, and "<name>(<table>)" when the model type has no
// name.
func (schema Schema) String() string {
	if typeName := schema.ModelType.Name(); typeName != "" {
		return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), typeName)
	}
	return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
// MakeSlice returns a pointer (as a reflect.Value) to a new, empty slice of
// pointers to the model type, with room for 20 elements.
func (schema Schema) MakeSlice() reflect.Value {
	sliceType := reflect.SliceOf(reflect.PtrTo(schema.ModelType))
	results := reflect.New(sliceType)
	results.Elem().Set(reflect.MakeSlice(sliceType, 0, 20))
	return results
}
// LookUpField finds a field by name. The name is tried as a database column
// name first and as a Go field name second; nil is returned when neither
// index has an entry for it.
func (schema Schema) LookUpField(name string) *Field {
	for _, index := range [...]map[string]*Field{schema.FieldsByDBName, schema.FieldsByName} {
		if field, found := index[name]; found {
			return field
		}
	}
	return nil
}
// LookUpFieldByBindName looks for the field called name that is closest to
// the given bind path inside embedded structs.
//
// The last element of bindNames is never part of the lookup. The remaining
// path is shortened one element at a time, from longest to empty, and each
// prefix is joined with "." and followed by "." + name to form the key tried
// in FieldsByBindName. The first hit is returned; nil is returned when
// bindNames is empty or nothing matches.
//
//	type Struct struct {
//	  Embedded struct {
//	    ID string // candidate key "Embedded.ID" for LookUpFieldByBindName([]string{"Embedded", "Toy"}, "ID")
//	  }
//	}
//
// NOTE(review): with an empty prefix the key tried is "." + name, not name -
// confirm whether top-level fields are meant to be found here.
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
	suffix := "." + name
	for depth := len(bindNames) - 1; depth >= 0; depth-- {
		key := strings.Join(bindNames[:depth], ".") + suffix
		if field, found := schema.FieldsByBindName[key]; found {
			return field
		}
	}
	return nil
}
// Tabler is implemented by models that supply their own table name. The
// parser uses it in place of the name derived by the Namer.
type Tabler interface {
TableName() string
}
// TablerWithNamer is like Tabler but receives the Namer in use. It is checked
// after Tabler, so it takes precedence when a model implements both.
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse returns the Schema for dest, using cacheStore as the cache and namer
// to derive names. It is ParseWithSpecialTableName without a table-name
// override.
func Parse(dest any, cacheStore *sync.Map, namer Namer) (*Schema, error) {
	const noSpecialTableName = ""
	return ParseWithSpecialTableName(dest, cacheStore, namer, noSpecialTableName)
}
// ParseWithSpecialTableName returns the Schema describing dest, optionally
// bound to specialTableName instead of the table name derived from the model.
//
// dest may be a struct, or any nesting of pointers, slices and arrays that
// ends in a struct; anything else yields ErrUnsupportedDataType. Parsed
// schemas are cached in cacheStore, keyed by the model type, or by a string
// built from the model type and specialTableName when the latter is set.
// A cache hit blocks until the goroutine that created the entry has finished
// initializing it, then returns that schema together with its recorded error.
func ParseWithSpecialTableName(dest any, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
	if dest == nil {
		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
	}

	value := reflect.ValueOf(dest)
	if value.Kind() == reflect.Ptr && value.IsNil() {
		value = reflect.New(value.Type().Elem())
	}
	modelType := reflect.Indirect(value).Type()

	if modelType.Kind() == reflect.Interface {
		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
	}

	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}

	if modelType.Kind() != reflect.Struct {
		if modelType.PkgPath() == "" {
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
		}
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
	}

	// Cache the Schema for performance,
	// Use the modelType or modelType + schemaTable (if it present) as cache key.
	var schemaCacheKey any
	if specialTableName != "" {
		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
	} else {
		schemaCacheKey = modelType
	}

	// Load exist schema cache, return if exists
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	// Resolve the table name. Later sources override earlier ones:
	// namer default, Tabler, TablerWithNamer, embeddedNamer, specialTableName.
	modelValue := reflect.New(modelType)
	tableName := namer.TableName(modelType.Name())
	if tabler, ok := modelValue.Interface().(Tabler); ok {
		tableName = tabler.TableName()
	}
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
		tableName = tabler.TableName(namer)
	}
	if en, ok := namer.(embeddedNamer); ok {
		tableName = en.Table
	}
	if specialTableName != "" && specialTableName != tableName {
		tableName = specialTableName
	}

	schema := &Schema{
		Name:             modelType.Name(),
		ModelType:        modelType,
		Table:            tableName,
		FieldsByName:     map[string]*Field{},
		FieldsByBindName: map[string]*Field{},
		FieldsByDBName:   map[string]*Field{},
		Relationships:    Relationships{Relations: map[string]*Relationship{}},
		cacheStore:       cacheStore,
		namer:            namer,
		initialized:      make(chan struct{}),
	}
	// When the schema initialization is completed, the channel will be closed
	defer close(schema.initialized)

	// Load exist schema cache, return if exists
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	// Collect exported fields; an embedded schema contributes its own fields
	// in place of the embedding field.
	for i := 0; i < modelType.NumField(); i++ {
		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
			} else {
				schema.Fields = append(schema.Fields, field)
			}
		}
	}

	for _, field := range schema.Fields {
		if field.DBName == "" && field.DataType != "" {
			field.DBName = namer.ColumnName(schema.Table, field.Name)
		}

		bindName := field.BindName()
		if field.DBName != "" {
			// nonexistence or shortest path or first appear prioritized if has permission
			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
					schema.DBNames = append(schema.DBNames, field.DBName)
				}
				schema.FieldsByDBName[field.DBName] = field
				schema.FieldsByName[field.Name] = field
				schema.FieldsByBindName[bindName] = field

				// The replaced field no longer counts as a primary field.
				if v != nil && v.PrimaryKey {
					for idx, f := range schema.PrimaryFields {
						if f == v {
							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
						}
					}
				}

				if field.PrimaryKey {
					schema.PrimaryFields = append(schema.PrimaryFields, field)
				}
			}
		}

		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
			schema.FieldsByName[field.Name] = field
		}
		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
			schema.FieldsByBindName[bindName] = field
		}

		field.setupValuerAndSetter()
	}

	// A field looked up as "id" (or "ID") is the preferred primary key; it is
	// promoted to primary key when the model declares none.
	prioritizedPrimaryField := schema.LookUpField("id")
	if prioritizedPrimaryField == nil {
		prioritizedPrimaryField = schema.LookUpField("ID")
	}

	if prioritizedPrimaryField != nil {
		if prioritizedPrimaryField.PrimaryKey {
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
		} else if len(schema.PrimaryFields) == 0 {
			prioritizedPrimaryField.PrimaryKey = true
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
			schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
		}
	}

	if schema.PrioritizedPrimaryField == nil {
		if len(schema.PrimaryFields) == 1 {
			schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
		} else if len(schema.PrimaryFields) > 1 {
			// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
			for _, field := range schema.PrimaryFields {
				if field.AutoIncrement {
					schema.PrioritizedPrimaryField = field
					break
				}
			}
		}
	}

	for _, field := range schema.PrimaryFields {
		schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
	}

	for _, field := range schema.Fields {
		if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
		}
	}

	// An integer prioritized primary key without an explicit AUTOINCREMENT
	// tag setting is treated as auto-increment with a database default.
	if field := schema.PrioritizedPrimaryField; field != nil {
		switch field.GORMDataType {
		case Int, Uint:
			if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
				if !field.HasDefaultValue || field.DefaultValueInterface != nil {
					schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
				}

				field.HasDefaultValue = true
				field.AutoIncrement = true
			}
		}
	}

	// Record which hook methods the model defines by setting the Schema
	// bool field that carries the same name as the callback.
	callbackTypes := []callbackType{
		callbackTypeBeforeCreate, callbackTypeAfterCreate,
		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
		callbackTypeBeforeSave, callbackTypeAfterSave,
		callbackTypeBeforeDelete, callbackTypeAfterDelete,
		callbackTypeAfterFind,
	}
	for _, cbName := range callbackTypes {
		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
			switch methodValue.Type().String() {
			case "func(*gorm.DB) error": // TODO hack
				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
			default:
				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
			}
		}
	}

	// Cache the schema
	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	defer func() {
		if schema.err != nil {
			logger.Default.Error(context.Background(), schema.err.Error())
			// Evict the failed schema under the key it was stored with. With a
			// special table name that key is the formatted string, not
			// modelType, so deleting modelType would leave the broken entry
			// cached and could drop an unrelated one.
			cacheStore.Delete(schemaCacheKey)
		}
	}()

	// Relations and per-field clause hooks are only resolved when the cache
	// store is not marked with embeddedCacheKey.
	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
		for _, field := range schema.Fields {
			if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
				if schema.parseRelation(field); schema.err != nil {
					return schema, schema.err
				} else {
					schema.FieldsByName[field.Name] = field
					schema.FieldsByBindName[field.BindName()] = field
				}
			}

			fieldValue := reflect.New(field.IndirectFieldType)
			fieldInterface := fieldValue.Interface()
			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
			}

			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
			}

			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
			}

			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
			}
		}
	}

	return schema, schema.err
}
// callBackToMethodValue returns the method of modelType whose name equals
// cbType. The result is the zero reflect.Value when the method does not exist
// or when cbType is not one of the callback types listed below.
//
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
// Each case passes its own constant rather than cbType so that the argument
// to MethodByName stays a compile-time constant; do not collapse into one call.
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
// Unknown callback type: hand back an invalid Value.
return reflect.ValueOf(nil)
}
}
// getOrParse returns the schema cached for dest's underlying struct type, or
// parses it via Parse when nothing is cached under that type.
//
// A cached schema is returned as-is with a nil error; this function does not
// wait on the schema's initialized channel and does not surface its err.
func getOrParse(dest any, cacheStore *sync.Map, namer Namer) (*Schema, error) {
	modelType := reflect.ValueOf(dest).Type()

	// Peel off pointer, slice and array layers until the element type is reached.
unwrap:
	for {
		switch modelType.Kind() {
		case reflect.Slice, reflect.Array, reflect.Ptr:
			modelType = modelType.Elem()
		default:
			break unwrap
		}
	}

	if modelType.Kind() != reflect.Struct {
		pkgPath := modelType.PkgPath()
		if pkgPath == "" {
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
		}
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, pkgPath, modelType.Name())
	}

	cached, hit := cacheStore.Load(modelType)
	if !hit {
		return Parse(dest, cacheStore, namer)
	}
	return cached.(*Schema), nil
}
================================================
FILE: gossh/gorm/schema/serializer.go
================================================
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
// serializerMap holds the registered serializers, keyed by lower-cased name.
// It is written by RegisterSerializer and read by GetSerializer.
var serializerMap = sync.Map{}
// RegisterSerializer makes serializer available under name. Names are
// case-insensitive: the key is lower-cased before it is stored.
func RegisterSerializer(name string, serializer SerializerInterface) {
	key := strings.ToLower(name)
	serializerMap.Store(key, serializer)
}
// GetSerializer looks up the serializer registered under name, ignoring case.
// ok is false when nothing is registered under that name or when the stored
// value does not implement SerializerInterface.
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
	stored, found := serializerMap.Load(strings.ToLower(name))
	if !found {
		return nil, false
	}
	serializer, ok = stored.(SerializerInterface)
	return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
// serializer adapts a field serializer to the sql.Scanner and driver.Valuer
// interfaces. It captures the raw database value on Scan and delegates Value
// to its SerializeValuer.
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
// value is the raw database value recorded by Scan.
value any
// fieldValue is the field value handed to SerializeValuer.Value.
fieldValue any
}
// Scan implements sql.Scanner interface
// It only records value for later use; it never returns an error.
func (s *serializer) Scan(value any) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
// It returns whatever SerializeValuer.Value produces for the stored context,
// field, destination and field value.
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface is the full contract of a field serializer: Scan
// converts a database value into the field of dst, and the embedded
// SerializerValuerInterface converts a field value back into a database value.
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue any) error
SerializerValuerInterface
}
// SerializerValuerInterface is the write half of a serializer: Value turns
// fieldValue into the value handed to the database driver.
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue any) (any, error)
}
// JSONSerializer stores a field as its JSON encoding.
type JSONSerializer struct{}

// Scan implements serializer interface.
//
// dbValue must be nil, a string or a []byte; any other type is rejected with
// an error and the field is left untouched. A nil or empty payload resets the
// field to the zero value of its type. When decoding fails, the field is still
// assigned whatever was decoded and the decode error is returned.
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue any) (err error) {
	target := reflect.New(field.FieldType)

	if dbValue != nil {
		var raw []byte
		switch typed := dbValue.(type) {
		case string:
			raw = []byte(typed)
		case []byte:
			raw = typed
		default:
			return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
		}

		if len(raw) != 0 {
			err = json.Unmarshal(raw, target.Interface())
		}
	}

	field.ReflectValueOf(ctx, dst).Set(target.Elem())
	return err
}
// Value implements serializer interface.
//
// It returns the JSON encoding of fieldValue as a string. A value that
// encodes to JSON null is returned as nil, or as the empty string when the
// field carries a non-empty "NOT NULL" tag setting.
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue any) (any, error) {
	encoded, err := json.Marshal(fieldValue)
	text := string(encoded)
	if text != "null" {
		return text, err
	}
	if field.TagSettings["NOT NULL"] != "" {
		return "", nil
	}
	return nil, err
}
// UnixSecondSerializer maps an integer field holding Unix seconds to a
// database time value.
type UnixSecondSerializer struct{}

// Scan implements serializer interface.
//
// dbValue is read through sql.NullTime. A valid time is stored in the field
// as its Unix-second count; a NULL leaves the field unchanged.
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue any) (err error) {
	var parsed sql.NullTime
	if err = parsed.Scan(dbValue); err != nil {
		return err
	}
	if !parsed.Valid {
		return nil
	}
	return field.Set(ctx, dst, parsed.Time.Unix())
}
// Value implements serializer interface.
//
// fieldValue must be one of int, int16, int32, int64, uint, uint16, uint32,
// uint64 or a pointer to one of them; it is interpreted as Unix seconds and
// returned as a UTC time.Time. A nil pointer yields (nil, nil). Any other
// type, or an unsigned value that does not fit in an int64, yields an error.
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue any) (result any, err error) {
	rv := reflect.ValueOf(fieldValue)
	switch v := fieldValue.(type) {
	case int64, int, uint, uint64, int32, uint32, int16, uint16:
		return unixSecondsToTime(rv)
	case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
		if rv.IsZero() {
			return nil, nil
		}
		return unixSecondsToTime(rv.Elem())
	default:
		err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
	}
	return
}

// unixSecondsToTime converts an integer reflect.Value holding Unix seconds
// into a UTC time.Time. Unsigned kinds are read with Uint because
// reflect.Value.Int panics on them.
func unixSecondsToTime(rv reflect.Value) (any, error) {
	switch rv.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		const maxInt64 = 1<<63 - 1
		seconds := rv.Uint()
		if seconds > maxInt64 {
			return nil, fmt.Errorf("unix seconds value %d overflows int64", seconds)
		}
		return time.Unix(int64(seconds), 0).UTC(), nil
	default:
		return time.Unix(rv.Int(), 0).UTC(), nil
	}
}
// GobSerializer stores a field as its encoding/gob representation.
type GobSerializer struct{}

// Scan implements serializer interface.
//
// dbValue must be nil or a []byte; any other type is rejected with an error
// and the field is left untouched. A nil or empty payload resets the field to
// the zero value of its type. When decoding fails, the field is still assigned
// and the decode error is returned.
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue any) (err error) {
	target := reflect.New(field.FieldType)

	if dbValue != nil {
		payload, isBytes := dbValue.([]byte)
		if !isBytes {
			return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
		}
		if len(payload) > 0 {
			err = gob.NewDecoder(bytes.NewBuffer(payload)).Decode(target.Interface())
		}
	}

	field.ReflectValueOf(ctx, dst).Set(target.Elem())
	return err
}
// Value implements serializer interface.
//
// It gob-encodes fieldValue and returns the encoded bytes together with any
// encoding error.
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue any) (any, error) {
	var encoded bytes.Buffer
	encodeErr := gob.NewEncoder(&encoded).Encode(fieldValue)
	return encoded.Bytes(), encodeErr
}
================================================
FILE: gossh/gorm/schema/utils.go
================================================
package schema
import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
"gossh/gorm/clause"
"gossh/gorm/utils"
)
// embeddedCacheKey is the key the schema parser looks up in a schema's cache
// store; when an entry is present, relation parsing and clause collection are
// skipped for that schema.
var embeddedCacheKey = "embedded_cache_store"
// ParseTagSetting splits a struct-tag setting string into a map.
//
// str is split on sep into segments of the form "key" or "key:value". Keys
// are upper-cased and trimmed of surrounding whitespace. A bare key maps to
// itself; for "key:value" everything after the first colon is kept verbatim
// as the value. A segment ending in a backslash escapes the separator: the
// backslash is dropped and the segment is joined with the following one
// using sep. A trailing backslash with no following segment is kept as a
// literal character instead of reading past the end of the input.
func ParseTagSetting(str string, sep string) map[string]string {
	settings := map[string]string{}
	names := strings.Split(str, sep)

	for i := 0; i < len(names); i++ {
		j := i
		// Merge escaped separators into segment j, consuming the segments
		// that follow it. Stop when there is nothing left to merge.
		for len(names[j]) > 0 && names[j][len(names[j])-1] == '\\' && i+1 < len(names) {
			i++
			names[j] = names[j][:len(names[j])-1] + sep + names[i]
			names[i] = ""
		}

		values := strings.Split(names[j], ":")
		k := strings.TrimSpace(strings.ToUpper(values[0]))

		if len(values) >= 2 {
			settings[k] = strings.Join(values[1:], ":")
		} else if k != "" {
			settings[k] = k
		}
	}

	return settings
}
// toColumns splits a comma-separated list into its entries, trimming
// surrounding whitespace from each one. An empty input yields a nil slice.
func toColumns(val string) (results []string) {
	if val == "" {
		return nil
	}
	parts := strings.Split(val, ",")
	for i := range parts {
		results = append(results, strings.TrimSpace(parts[i]))
	}
	return results
}
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
for _, name := range names {
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
}
return tag
}
// appendSettingFromTag returns tag unchanged when its gorm value already
// contains value as a substring. Otherwise it returns a new tag consisting
// only of a gorm key whose value is value followed by ";" and the previous
// gorm value.
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
	existing := tag.Get("gorm")
	if !strings.Contains(existing, value) {
		return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, existing))
	}
	return tag
}
// GetRelationsValues follows the chain of relationships in rels starting from
// reflectValue and returns the values reached by the last one.
//
// For every relationship a fresh slice of pointers to the related model type
// is built from the current value (a struct, or each element of a slice);
// owners whose relation field is zero are skipped. That slice then becomes
// the input for the next relationship. With no relationships the zero
// reflect.Value is returned.
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
	for _, rel := range rels {
		elemPtrType := reflect.PtrTo(rel.FieldSchema.ModelType)
		reflectResults = reflect.MakeSlice(reflect.SliceOf(elemPtrType), 0, 1)

		// collect appends the related value(s) of one owner to reflectResults.
		collect := func(owner reflect.Value) {
			if _, isZero := rel.Field.ValueOf(ctx, owner); isZero {
				return
			}

			related := reflect.Indirect(rel.Field.ReflectValueOf(ctx, owner))
			switch related.Kind() {
			case reflect.Struct:
				reflectResults = reflect.Append(reflectResults, related.Addr())
			case reflect.Slice, reflect.Array:
				for i := 0; i < related.Len(); i++ {
					elem := related.Index(i)
					if elem.Kind() != reflect.Ptr {
						elem = elem.Addr()
					}
					reflectResults = reflect.Append(reflectResults, elem)
				}
			}
		}

		switch reflectValue.Kind() {
		case reflect.Struct:
			collect(reflectValue)
		case reflect.Slice:
			for i := 0; i < reflectValue.Len(); i++ {
				collect(reflectValue.Index(i))
			}
		}

		reflectValue = reflectResults
	}

	return reflectResults
}
// GetIdentityFieldValuesMap reads the given fields from reflectValue and
// groups the source values by those field values.
//
// reflectValue may be a struct, a slice or an array, optionally behind one
// pointer or interface. The first result maps the string key of a field-value
// tuple to the values carrying it; the second lists each distinct tuple once,
// in order of first appearance. Values whose fields are all zero are ignored;
// for a single struct in that situation both results are nil. Slice elements
// already seen (compared by address when addressable) are visited only once.
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]any) {
	var (
		results     = [][]any{}
		dataResults = map[string][]reflect.Value{}
		loaded      = map[any]bool{}
	)

	switch reflectValue.Kind() {
	case reflect.Ptr, reflect.Interface:
		reflectValue = reflectValue.Elem()
	}

	switch reflectValue.Kind() {
	case reflect.Struct:
		tuple := make([]any, len(fields))
		allZero := true
		for idx, field := range fields {
			fieldValue, isZero := field.ValueOf(ctx, reflectValue)
			tuple[idx] = fieldValue
			if !isZero {
				allZero = false
			}
		}
		if allZero {
			return nil, nil
		}

		results = [][]any{tuple}
		dataResults[utils.ToStringKey(tuple...)] = []reflect.Value{reflectValue}

	case reflect.Slice, reflect.Array:
		for i := 0; i < reflectValue.Len(); i++ {
			elem := reflectValue.Index(i)

			// Identify the element by its address when it has one, so equal
			// but distinct elements are not collapsed.
			elemKey := elem.Interface()
			if elem.Kind() != reflect.Ptr && elem.CanAddr() {
				elemKey = elem.Addr().Interface()
			}
			if _, seen := loaded[elemKey]; seen {
				continue
			}
			loaded[elemKey] = true

			tuple := make([]any, len(fields))
			allZero := true
			for idx, field := range fields {
				fieldValue, isZero := field.ValueOf(ctx, elem)
				tuple[idx] = fieldValue
				if !isZero {
					allZero = false
				}
			}
			if allZero {
				continue
			}

			dataKey := utils.ToStringKey(tuple...)
			if _, known := dataResults[dataKey]; known {
				dataResults[dataKey] = append(dataResults[dataKey], elem)
			} else {
				results = append(results, tuple)
				dataResults[dataKey] = []reflect.Value{elem}
			}
		}
	}

	return dataResults, results
}
// GetIdentityFieldValuesMapFromValues applies GetIdentityFieldValuesMap to
// every entry of values (after dereferencing a pointer) and merges the
// outcomes: map entries with the same key are concatenated, and the tuple
// lists are appended in input order without de-duplication across entries.
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []any, fields []*Field) (map[string][]reflect.Value, [][]any) {
	var (
		merged  = map[string][]reflect.Value{}
		ordered = [][]any{}
	)

	for i := range values {
		source := reflect.Indirect(reflect.ValueOf(values[i]))
		byKey, tuples := GetIdentityFieldValuesMap(ctx, source, fields)
		for key, matched := range byKey {
			merged[key] = append(merged[key], matched...)
		}
		ordered = append(ordered, tuples...)
	}

	return merged, ordered
}
// ToQueryValues turns foreign-key names and their value tuples into a column
// expression plus a flat list of query values.
//
// With exactly one key the first result is a single clause.Column and each
// query value is the first element of the corresponding tuple. Otherwise the
// first result is a []clause.Column (one per key) and each query value is the
// whole tuple. Columns are always qualified with table.
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]any) (any, []any) {
	queryValues := make([]any, len(foreignValues))

	if len(foreignKeys) != 1 {
		columns := make([]clause.Column, 0, len(foreignKeys))
		for _, key := range foreignKeys {
			columns = append(columns, clause.Column{Table: table, Name: key})
		}
		for i := range foreignValues {
			queryValues[i] = foreignValues[i]
		}
		return columns, queryValues
	}

	for i := range foreignValues {
		queryValues[i] = foreignValues[i][0]
	}
	return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
}
// embeddedNamer wraps a Namer together with a fixed table name. When the
// schema parser receives one as its namer, Table is used as the schema's
// table name in place of the name derived from the model.
type embeddedNamer struct {
Table string
Namer
}
================================================
FILE: gossh/gorm/soft_delete.go
================================================
package gorm
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"gossh/gorm/clause"
"gossh/gorm/now"
"gossh/gorm/schema"
)
// DeletedAt is a nullable timestamp with the layout of sql.NullTime. It adds
// JSON encoding and the soft-delete clause hooks defined in this file.
type DeletedAt sql.NullTime

// Scan implements the Scanner interface by delegating to sql.NullTime.
func (n *DeletedAt) Scan(value any) error {
	nullTime := (*sql.NullTime)(n)
	return nullTime.Scan(value)
}

// Value implements the driver Valuer interface. An invalid (NULL) value is
// reported as nil; otherwise the stored time is returned.
func (n DeletedAt) Value() (driver.Value, error) {
	if n.Valid {
		return n.Time, nil
	}
	return nil, nil
}

// MarshalJSON encodes a valid value as its time and an invalid one as null.
func (n DeletedAt) MarshalJSON() ([]byte, error) {
	if !n.Valid {
		return json.Marshal(nil)
	}
	return json.Marshal(n.Time)
}

// UnmarshalJSON decodes a JSON time into n. The literal null marks n as
// invalid without touching Time. Valid is set to true only when decoding
// succeeds; on failure the decode error is returned and Valid is left as is.
func (n *DeletedAt) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		n.Valid = false
		return nil
	}
	if err := json.Unmarshal(b, &n.Time); err != nil {
		return err
	}
	n.Valid = true
	return nil
}
// QueryClauses returns the single clause attached to queries on a model with
// a DeletedAt field: a SoftDeleteQueryClause for f, carrying the zero value
// parsed from f's tag settings.
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
// parseZeroValueTag reads the "ZEROVALUE" tag setting of f. The setting is
// returned as a valid NullString only when it is present and now.Parse
// accepts it; otherwise an invalid (NULL) NullString is returned.
func parseZeroValueTag(f *schema.Field) sql.NullString {
	raw, tagged := f.TagSettings["ZEROVALUE"]
	if !tagged {
		return sql.NullString{Valid: false}
	}
	if _, err := now.Parse(raw); err != nil {
		return sql.NullString{Valid: false}
	}
	return sql.NullString{String: raw, Valid: true}
}
// SoftDeleteQueryClause restricts a statement to rows whose soft-delete
// column equals ZeroValue. All of its work happens in ModifyStatement; the
// clause.Interface methods below are intentionally empty.
type SoftDeleteQueryClause struct {
// ZeroValue is the column value that marks a row as not deleted.
ZeroValue sql.NullString
// Field is the soft-delete field whose DBName is used in the condition.
Field *schema.Field
}
// Name returns an empty clause name.
func (sd SoftDeleteQueryClause) Name() string {
return ""
}
// Build writes nothing.
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
}
// MergeClause leaves the given clause untouched.
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
}
// ModifyStatement adds the "soft-delete column = ZeroValue" condition to
// stmt, at most once per statement and never for unscoped statements.
//
// Before adding it, an existing WHERE clause that contains an OR group with
// a single expression is wrapped into one AND group. The marker clause
// "soft_delete_enabled" records that the condition has been applied.
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
	if _, applied := stmt.Clauses["soft_delete_enabled"]; applied || stmt.Statement.Unscoped {
		return
	}

	if existing, hasWhere := stmt.Clauses["WHERE"]; hasWhere {
		where, isWhere := existing.Expression.(clause.Where)
		if isWhere && len(where.Exprs) >= 1 {
			for _, expr := range where.Exprs {
				orCond, isOr := expr.(clause.OrConditions)
				if !isOr || len(orCond.Exprs) != 1 {
					continue
				}
				where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
				existing.Expression = where
				stmt.Clauses["WHERE"] = existing
				break
			}
		}
	}

	stmt.AddClause(clause.Where{Exprs: []clause.Expression{
		clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
	}})
	stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
// UpdateClauses returns the single clause attached to updates on a model with
// a DeletedAt field: a SoftDeleteUpdateClause for f, carrying the zero value
// parsed from f's tag settings.
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
// SoftDeleteUpdateClause applies the soft-delete filter to update statements.
// All of its work happens in ModifyStatement; the clause.Interface methods
// below are intentionally empty.
type SoftDeleteUpdateClause struct {
// ZeroValue is the column value that marks a row as not deleted.
ZeroValue sql.NullString
// Field is the soft-delete field the filter is built on.
Field *schema.Field
}
// Name returns an empty clause name.
func (sd SoftDeleteUpdateClause) Name() string {
return ""
}
// Build writes nothing.
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
}
// MergeClause leaves the given clause untouched.
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
}
// ModifyStatement applies the soft-delete query filter to stmt, but only
// while no SQL has been written yet and the statement is not unscoped.
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
	if stmt.SQL.Len() != 0 || stmt.Statement.Unscoped {
		return
	}
	SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
// DeleteClauses returns the single clause attached to deletes on a model with
// a DeletedAt field: a SoftDeleteDeleteClause for f, carrying the zero value
// parsed from f's tag settings.
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
// SoftDeleteDeleteClause turns a delete into an update of the soft-delete
// column. All of its work happens in ModifyStatement; the clause.Interface
// methods below are intentionally empty.
type SoftDeleteDeleteClause struct {
// ZeroValue is the column value that marks a row as not deleted.
ZeroValue sql.NullString
// Field is the soft-delete field that receives the deletion time.
Field *schema.Field
}
// Name returns an empty clause name.
func (sd SoftDeleteDeleteClause) Name() string {
return ""
}
// Build writes nothing.
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
}
// MergeClause leaves the given clause untouched.
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}
// ModifyStatement rewrites a delete into a soft delete, provided no SQL has
// been written yet and the statement is not unscoped.
//
// It sets the soft-delete column to the current time, narrows the statement
// to the primary-key values found in the statement's value (and in its Model
// when that differs from Dest), applies the soft-delete query filter, and
// builds the statement with the update callback's clauses.
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
	if stmt.SQL.Len() != 0 || stmt.Statement.Unscoped {
		return
	}

	deletedAt := stmt.DB.NowFunc()
	stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: deletedAt}})
	stmt.SetColumn(sd.Field.DBName, deletedAt, true)

	if stmt.Schema != nil {
		// restrictToIdentity adds an IN condition on the primary key(s) for
		// the identity values found in source, if there are any.
		restrictToIdentity := func(source reflect.Value) {
			_, identities := schema.GetIdentityFieldValuesMap(stmt.Context, source, stmt.Schema.PrimaryFields)
			column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, identities)
			if len(values) > 0 {
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
			}
		}

		restrictToIdentity(stmt.ReflectValue)
		if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
			restrictToIdentity(reflect.ValueOf(stmt.Model))
		}
	}

	SoftDeleteQueryClause(sd).ModifyStatement(stmt)
	stmt.AddClauseIfNotExists(clause.Update{})
	stmt.Build(stmt.DB.Callback().Update().Clauses...)
}
================================================
FILE: gossh/gorm/statement.go
================================================
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/schema"
"gossh/gorm/utils"
)
// Statement holds the state of one SQL statement under construction: the
// target table and model, the clauses collected so far, and the SQL text and
// bind variables produced while building. It embeds *DB.
type Statement struct {
*DB
// TableExpr, when non-nil, is built instead of Table wherever QuoteTo
// renders the current-table placeholder of a clause.Table.
TableExpr *clause.Expr
Table string
Model any
Unscoped bool
Dest any
ReflectValue reflect.Value
// Clauses maps a clause name to its merged clause (see AddClause).
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Joins []join
Preloads map[string][]any
Settings sync.Map
ConnPool ConnPool
Schema *schema.Schema
Context context.Context
RaiseErrorOnNotFound bool
SkipHooks bool
// SQL accumulates the statement text written via WriteString/WriteByte.
SQL strings.Builder
// Vars collects the bind values appended by AddVar.
Vars []any
CurDestIndex int
attrs []any
assigns []any
scopes []func(*DB) *DB
}
// join is the element type of Statement.Joins and describes one join.
// NOTE(review): how the individual fields are consumed is not visible in
// this part of the file; see the code that builds joins for their meaning.
type join struct {
Name string
Conds []any
On *clause.Where
Selects []string
Omits []string
JoinType clause.JoinType
}
// StatementModifier is implemented by clauses that change the statement
// directly. When a clause passed to Statement.AddClause implements it,
// ModifyStatement is called instead of merging the clause by name.
type StatementModifier interface {
ModifyStatement(*Statement)
}
// WriteString appends str to the statement's SQL text and returns the result
// of the underlying strings.Builder write.
func (stmt *Statement) WriteString(str string) (int, error) {
return stmt.SQL.WriteString(str)
}
// WriteByte appends the single byte c to the statement's SQL text.
func (stmt *Statement) WriteByte(c byte) error {
return stmt.SQL.WriteByte(c)
}
// WriteQuoted appends the quoted form of value to the statement's SQL text;
// see QuoteTo for how each value type is rendered.
func (stmt *Statement) WriteQuoted(value any) {
stmt.QuoteTo(&stmt.SQL, value)
}
// QuoteTo writes the quoted form of field to writer.
//
// Supported inputs are clause.Table, clause.Column, []clause.Column,
// clause.Expr, string and []string; any other value is formatted with
// fmt.Sprint and quoted by the dialector. Slices are rendered as a
// parenthesised, comma-separated list.
func (stmt *Statement) QuoteTo(writer clause.Writer, field any) {
// write emits str verbatim when raw is set, otherwise through the
// dialector's quoting.
write := func(raw bool, str string) {
if raw {
writer.WriteString(str)
} else {
stmt.DB.Dialector.QuoteTo(writer, str)
}
}
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
// The current-table placeholder resolves to TableExpr when set,
// otherwise to the statement's table name.
// NOTE(review): TableExpr is built with stmt as the builder, not
// with writer, so its output presumably goes to stmt.SQL — confirm.
if stmt.TableExpr != nil {
stmt.TableExpr.Build(stmt)
} else {
write(v.Raw, stmt.Table)
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteByte(' ')
write(v.Raw, v.Alias)
}
case clause.Column:
// Optional table qualifier, with the current-table placeholder
// replaced by the statement's table name.
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
if v.Name == clause.PrimaryKey {
// The primary-key placeholder needs a schema: prefer the prioritized
// primary field, fall back to the first DB column, and record an
// error on the DB when neither is available.
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteString(" AS ")
write(v.Raw, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case clause.Expr:
// NOTE(review): like TableExpr above, the expression is built with stmt
// as the builder rather than writer.
v.Build(stmt)
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
writer.WriteByte(')')
default:
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
}
}
// Quote returns the quoted form of field as a string; it renders through
// QuoteTo into a temporary builder.
func (stmt *Statement) Quote(field any) string {
	quoted := new(strings.Builder)
	stmt.QuoteTo(quoted, field)
	return quoted.String()
}
// AddVar renders each of vars into writer, separated by commas, and records
// bind values in stmt.Vars where a placeholder is emitted.
//
// How a value is rendered depends on its type: columns and tables are
// quoted, clauses and expressions are built in place, slices become a
// parenthesised list (or "(NULL)" when empty), a *DB is rendered as a
// sub-query, and everything else is appended to stmt.Vars with a placeholder
// written by the dialector.
func (stmt *Statement) AddVar(writer clause.Writer, vars ...any) {
for idx, v := range vars {
if idx > 0 {
writer.WriteByte(',')
}
switch v := v.(type) {
case sql.NamedArg:
// Only the value is recorded; no placeholder is written here.
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
// A nil pointer Valuer is rendered as nil instead of calling GormValue on it.
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
stmt.AddVar(writer, nil)
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []byte:
// Byte slices are bound as a single value, not expanded as a list.
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []any:
if len(v) > 0 {
writer.WriteByte('(')
stmt.AddVar(writer, v...)
writer.WriteByte(')')
} else {
writer.WriteString("(NULL)")
}
case *DB:
// Sub-query: render v in a dry-run session and splice its SQL and
// bind values into this statement.
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
// v already has SQL text. Swap each dialect-specific placeholder
// back to "?" and rebuild the expression so that its placeholders
// continue after the ones already in stmt.Vars.
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]any, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
// No SQL yet: seed the sub-statement with the outer bind values and
// let the query callbacks generate the SQL.
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default:
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
// Byte-element slices/arrays are bound as a single value.
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
if i > 0 {
writer.WriteByte(',')
}
stmt.AddVar(writer, rv.Index(i).Interface())
}
writer.WriteByte(')')
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
}
}
}
}
// AddClause adds v to the statement. A clause that implements
// StatementModifier is given the statement to modify directly; any other
// clause is merged into the entry stored under its name in stmt.Clauses.
func (stmt *Statement) AddClause(v clause.Interface) {
	if modifier, ok := v.(StatementModifier); ok {
		modifier.ModifyStatement(stmt)
		return
	}

	name := v.Name()
	merged := stmt.Clauses[name]
	merged.Name = name
	v.MergeClause(&merged)
	stmt.Clauses[name] = merged
}
// AddClauseIfNotExists adds v only when the statement has no clause under
// v's name yet, or has one whose Expression is nil.
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
	existing, present := stmt.Clauses[v.Name()]
	if present && existing.Expression != nil {
		return
	}
	stmt.AddClause(v)
}
// BuildCondition build condition
func (stmt *Statement) BuildCondition(query any, args ...any) []clause.Expression {
if s, ok := query.(string); ok {
// if it is a number, then treats it as primary key
if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 {
return nil
}
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
}
if strings.Contains(strings.TrimSpace(s), " ") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
}
}
}
conds := make([]clause.Expression, 0, 4)
args = append([]any{query}, args...)
for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
}
}
case map[any]any:
for i, j := range v {
conds = append(conds, clause.Eq{Column: i, Value: j})
}
case map[string]string:
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
case map[string]any:
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else {
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]any, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
}
default:
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
if vs, ok := v.(string); ok {
selectedColumns[vs] = true
}
}
}
restricted := len(selectedColumns) != 0
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
}
}
if restricted {
break
}
} else if !reflectValue.IsValid() {
stmt.AddError(ErrInvalidData)
} else if len(conds) == 0 {
if len(args) == 1 {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]any, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
}
}
}
if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
}
// Build writes the SQL for the given clause names, in the order supplied.
//
// Names that have no entry in stmt.Clauses are skipped. Consecutive clauses
// that are written are separated by a single space. When a builder is
// registered under the same name in stmt.DB.ClauseBuilders it is used instead
// of the clause's own Build method.
func (stmt *Statement) Build(clauses ...string) {
	wroteAny := false
	for _, clauseName := range clauses {
		cl, present := stmt.Clauses[clauseName]
		if !present {
			continue
		}
		if wroteAny {
			stmt.WriteByte(' ')
		}
		wroteAny = true

		builder, custom := stmt.DB.ClauseBuilders[clauseName]
		if !custom {
			cl.Build(stmt)
			continue
		}
		builder(cl, stmt)
	}
}
// Parse resolves the schema of value onto the statement. It is shorthand for
// ParseWithSpecialTableName with an empty special table name.
func (stmt *Statement) Parse(value any) (err error) {
	err = stmt.ParseWithSpecialTableName(value, "")
	return err
}
// ParseWithSpecialTableName parses value into stmt.Schema (passing
// specialTableName through to the schema parser) and, when the statement has
// no table set yet, derives stmt.Table from the parsed schema.
//
// A schema table of the form "a.b" (exactly one dot) sets stmt.TableExpr to
// the quoted full name and stmt.Table to the part after the dot; any other
// form is copied to stmt.Table unchanged. The parse error, if any, is
// returned and the table fields are left untouched in that case.
func (stmt *Statement) ParseWithSpecialTableName(value any, specialTableName string) (err error) {
	stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName)
	if err != nil || stmt.Table != "" {
		return err
	}

	parts := strings.Split(stmt.Schema.Table, ".")
	if len(parts) != 2 {
		stmt.Table = stmt.Schema.Table
		return nil
	}

	stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
	stmt.Table = parts[1]
	return nil
}
// clone returns a copy of the statement that can be modified without
// affecting the receiver's containers.
//
// Scalar and pointer fields are copied as-is. The Clauses and Preloads maps,
// the Joins and scopes slices and the Settings store are copied entry by
// entry into fresh containers. The SQL text and Vars are copied only when SQL
// has already been written.
func (stmt *Statement) clone() *Statement {
	cloned := &Statement{
		TableExpr:            stmt.TableExpr,
		Table:                stmt.Table,
		Model:                stmt.Model,
		Unscoped:             stmt.Unscoped,
		Dest:                 stmt.Dest,
		ReflectValue:         stmt.ReflectValue,
		Clauses:              make(map[string]clause.Clause, len(stmt.Clauses)),
		Distinct:             stmt.Distinct,
		Selects:              stmt.Selects,
		Omits:                stmt.Omits,
		Preloads:             make(map[string][]any, len(stmt.Preloads)),
		ConnPool:             stmt.ConnPool,
		Schema:               stmt.Schema,
		Context:              stmt.Context,
		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
		SkipHooks:            stmt.SkipHooks,
	}

	if stmt.SQL.Len() > 0 {
		cloned.SQL.WriteString(stmt.SQL.String())
		cloned.Vars = append(make([]any, 0, len(stmt.Vars)), stmt.Vars...)
	}

	for name, c := range stmt.Clauses {
		cloned.Clauses[name] = c
	}
	for name, args := range stmt.Preloads {
		cloned.Preloads[name] = args
	}

	if n := len(stmt.Joins); n > 0 {
		cloned.Joins = make([]join, n)
		copy(cloned.Joins, stmt.Joins)
	}
	if n := len(stmt.scopes); n > 0 {
		cloned.scopes = make([]func(*DB) *DB, n)
		copy(cloned.scopes, stmt.scopes)
	}

	stmt.Settings.Range(func(key, value any) bool {
		cloned.Settings.Store(key, value)
		return true
	})

	return cloned
}
// SetColumn sets the value of the column/field called name on the data held
// by the statement.
//
// When stmt.Dest is a map[string]any (or a slice of such maps) the key name is
// assigned directly. Otherwise the field is looked up on stmt.Schema and set
// through reflection on stmt.Dest and stmt.ReflectValue. If stmt.ReflectValue
// is a slice or array, passing any fromCallbacks value updates every element;
// without it only the element at stmt.CurDestIndex is updated.
//
// Problems are recorded with stmt.AddError rather than returned:
// ErrInvalidField when there is no schema or the field is unknown,
// ErrInvalidData when Dest differs from ReflectValue and is not a struct,
// ErrInvalidValue when ReflectValue is a struct that is not addressable.
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value any, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]any); ok {
v[name] = value
} else if v, ok := stmt.Dest.([]map[string]any); ok {
for _, m := range v {
m[name] = value
}
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
// Unwrap every pointer level so destValue is the underlying value.
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Dest and ReflectValue are different values: update Dest as well.
if stmt.ReflectValue != destValue {
if !destValue.CanAddr() {
// Copy the value into freshly allocated (addressable) storage and
// make stmt.Dest point at that copy so the field can be set.
destValueCanAddr := reflect.New(destValue.Type())
destValueCanAddr.Elem().Set(destValue)
stmt.Dest = destValueCanAddr.Interface()
destValue = destValueCanAddr.Elem()
}
switch destValue.Kind() {
case reflect.Struct:
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 {
// Callback mode: apply the value to every element.
for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
// Hook mode: only the element currently being processed.
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
if !stmt.ReflectValue.CanAddr() {
stmt.AddError(ErrInvalidValue)
return
}
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
}
} else {
stmt.AddError(ErrInvalidField)
}
}
// Changed reports whether any of the given fields differs between the model
// value (stmt.ReflectValue, or its element at stmt.CurDestIndex for slices and
// arrays) and the update payload in stmt.Dest.
//
// With no field names, every field in stmt.Schema.FieldsByDBName is checked.
// Names that cannot be resolved with LookUpField are ignored. Fields excluded
// by the statement's select/omit columns never count as changed.
func (stmt *Statement) Changed(fields ...string) bool {
modelValue := stmt.ReflectValue
switch modelValue.Kind() {
case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
}
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
// changed compares one field of the model against the same field in Dest.
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
// Consider the field only if it is explicitly selected, or if it is not
// mentioned at all and the statement is not restricted to selected columns.
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if mv, mok := stmt.Dest.(map[string]any); mok {
// Map payload: look the value up by field name first, then by DB name.
// A field absent from the map is reported as unchanged.
if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(stmt.Context, destValue)
// Explicitly selected fields are compared even when the payload value
// is zero; otherwise zero payload values are ignored.
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return false
}
if len(fields) == 0 {
for _, field := range stmt.Schema.FieldsByDBName {
if changed(field) {
return true
}
}
} else {
for _, name := range fields {
if field := stmt.Schema.LookUpField(name); field != nil {
if changed(field) {
return true
}
}
}
}
return false
}
// matchName splits a "table.column" style reference into its table and column
// parts. The table prefix is optional, each part may be wrapped in a single
// non-word character (such as a quote), and the column may be "*". Input that
// does not fit this shape yields two empty strings.
//
// The pattern is compiled once, when the package is initialised.
var matchName = func() func(tableColumn string) (table, column string) {
	pattern := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)

	return func(tableColumn string) (string, string) {
		groups := pattern.FindStringSubmatch(tableColumn)
		if len(groups) != 4 {
			return "", ""
		}
		// groups[1] = table, groups[2] = "*" (if used), groups[3] = column name.
		if groups[2] != "" {
			return groups[1], groups[2]
		}
		return groups[1], groups[3]
	}
}()
// SelectAndOmitColumns merges the statement's Selects and Omits into one map
// keyed by column name: true means selected, false means omitted.
//
// Selects are applied first and Omits second, so an omit overrides a select
// of the same column. When a schema is present, fields that are not creatable
// (with requireCreate) or not updatable (with requireUpdate) are forced to
// false afterwards.
//
// The second result reports whether the statement is restricted to the
// selected columns: it is true when there is at least one select and the last
// processed bare "*" entry (if any) was not a select.
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
	results := map[string]bool{}
	notRestricted := false

	// apply records one select/omit entry, expanding wildcards and resolving
	// schema field names to DB names where possible.
	apply := func(column string, result bool) {
		if stmt.Schema == nil {
			results[column] = result
			return
		}
		if column == "*" {
			notRestricted = result
			for _, dbName := range stmt.Schema.DBNames {
				results[dbName] = result
			}
			return
		}
		if column == clause.Associations {
			for _, rel := range stmt.Schema.Relationships.Relations {
				results[rel.Name] = result
			}
			return
		}
		if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
			results[field.DBName] = result
			return
		}

		table, col := matchName(column)
		if col == "" || (table != stmt.Table && table != "") {
			// Not a recognisable reference to this table: keep it verbatim.
			results[column] = result
			return
		}
		if col != "*" {
			results[col] = result
			return
		}
		for _, dbName := range stmt.Schema.DBNames {
			results[dbName] = result
		}
	}

	for _, column := range stmt.Selects {
		apply(column, true)
	}
	for _, column := range stmt.Omits {
		apply(column, false)
	}

	if stmt.Schema != nil {
		for _, field := range stmt.Schema.FieldsByName {
			blocked := (requireCreate && !field.Creatable) || (requireUpdate && !field.Updatable)
			if !blocked {
				continue
			}
			key := field.DBName
			if key == "" {
				key = field.Name
			}
			results[key] = false
		}
	}

	return results, !notRestricted && len(stmt.Selects) > 0
}
================================================
FILE: gossh/gorm/utils/tests/dummy_dialecter.go
================================================
package tests
import (
"gossh/gorm"
"gossh/gorm/callbacks"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/schema"
)
// DummyDialector is a minimal dialector for tests; it does not talk to a
// real database.
type DummyDialector struct {
// TranslatedErr is what Translate returns, whatever error it is given.
TranslatedErr error
}
// Name returns the dialector name, always "dummy".
func (DummyDialector) Name() string {
return "dummy"
}
// Initialize registers gorm's default callbacks on db with the clause lists
// used by the dummy dialect. It always returns nil.
func (DummyDialector) Initialize(db *gorm.DB) error {
	cfg := &callbacks.Config{
		CreateClauses:        []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
		UpdateClauses:        []string{"UPDATE", "SET", "WHERE", "RETURNING"},
		DeleteClauses:        []string{"DELETE", "FROM", "WHERE", "RETURNING"},
		LastInsertIDReversed: true,
	}
	callbacks.RegisterDefaultCallbacks(db, cfg)
	return nil
}
// DefaultValueOf returns the literal SQL expression DEFAULT for every field;
// the field argument is not inspected.
func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression {
return clause.Expr{SQL: "DEFAULT"}
}
// Migrator returns nil: the dummy dialect provides no migrator.
func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator {
return nil
}
// BindVarTo writes the bind-variable placeholder '?' to writer; stmt and v
// are not used.
func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {
writer.WriteByte('?')
}
// QuoteTo writes str to writer as a backtick-quoted identifier.
//
// The input is processed byte by byte. A '.' separates identifier parts, so
// each part gets its own pair of backticks. Backticks already present in the
// input are counted rather than copied straight through, so that a part which
// arrives already quoted is not quoted a second time and a doubled backtick is
// written out as "``".
// NOTE(review): the exact treatment of unbalanced backticks follows from the
// counter arithmetic below — confirm against the dialect tests before relying
// on it.
func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
var (
// underQuoted: an opening backtick has been written for the current part.
// selfQuoted: the current part started with its own backtick in the input.
underQuoted, selfQuoted bool
// continuousBacktick counts backticks seen in a row and not yet written.
continuousBacktick int8
// shiftDelimiter counts bytes consumed since the last part boundary.
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '`':
continuousBacktick++
// Two backticks in a row are emitted as an escaped backtick.
if continuousBacktick == 2 {
writer.WriteString("``")
continuousBacktick = 0
}
case '.':
// Close the current part unless its own closing backtick is still to come.
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteByte('`')
}
writer.WriteByte(v)
// Skip the shiftDelimiter increment below: the next byte starts a new part.
continue
default:
// First ordinary byte of a part: open the quote.
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteByte('`')
underQuoted = true
// A pending backtick here was the part's own opening quote; consume it.
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
// Any backticks still pending are written escaped.
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString("``")
}
writer.WriteByte(v)
}
shiftDelimiter++
}
// A trailing backtick that is not the part's own closing quote is escaped.
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
// Close the final part.
writer.WriteByte('`')
}
// Explain returns sql with vars rendered by logger.ExplainSQL, passing a nil
// numeric-placeholder pattern and a double quote as the escaper.
func (DummyDialector) Explain(sql string, vars ...any) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
// DataTypeOf returns an empty string for every field: the dummy dialect
// defines no column types.
func (DummyDialector) DataTypeOf(*schema.Field) string {
return ""
}
// Translate ignores err and returns the configured d.TranslatedErr.
func (d DummyDialector) Translate(err error) error {
return d.TranslatedErr
}
================================================
FILE: gossh/gorm/utils/tests/models.go
================================================
package tests
import (
"database/sql"
"time"
"gossh/gorm"
)
// User is the main test model and exercises most association kinds.
//
// A User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic).
// A User works in a Company (belongs to), has a Manager (belongs to - single-table), and also manages a Team (has many - single-table).
// A User speaks many languages (many to many) and has many friends (many to many - single-table).
// Each Pet also has one Toy (has one - polymorphic).
// NamedPet is a reference to a named `Pet` (has one).
type User struct {
gorm.Model
Name string
Age uint
Birthday *time.Time
Account Account
Pets []*Pet
NamedPet *Pet
// Polymorphic owner columns are derived from the "Owner" prefix.
Toys []Toy `gorm:"polymorphic:Owner"`
// Polymorphic association with explicitly named type and id columns.
Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"`
CompanyID *int
Company Company
ManagerID *uint
Manager *User
// Team reuses ManagerID as its foreign key (self-referencing has many).
Team []User `gorm:"foreignkey:ManagerID"`
Languages []Language `gorm:"many2many:UserSpeak;"`
Friends []*User `gorm:"many2many:user_friends;"`
Active bool
}
// Account is the "has one" side of User.Account; UserID is nullable.
type Account struct {
gorm.Model
UserID sql.NullInt64
Number string
}
// Pet belongs to a User through the optional UserID and owns one Toy through
// a polymorphic association.
type Pet struct {
gorm.Model
UserID *uint
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
// Toy is the target of the polymorphic "Owner" associations declared on User
// and Pet; OwnerID and OwnerType identify the owning record.
type Toy struct {
gorm.Model
Name string
OwnerID string
OwnerType string
}
// Tools is the target of User.Tools, a polymorphic association whose id and
// type columns are named explicitly (CustomID and Type).
type Tools struct {
gorm.Model
Name string
CustomID string
Type string
}
// Company is the "belongs to" target of User.Company. It does not embed
// gorm.Model and only declares ID and Name.
type Company struct {
ID int
Name string
}
// Language is the many-to-many target of User.Languages; Code is its primary
// key.
type Language struct {
Code string `gorm:"primarykey"`
Name string
}
// Coupon has many CouponProduct rows linked by CouponId, with an
// ON DELETE CASCADE constraint declared on the association.
type Coupon struct {
ID int `gorm:"primarykey; size:255"`
AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"`
AmountOff uint32 `gorm:"column:amount_off"`
PercentOff float32 `gorm:"column:percent_off"`
}
// CouponProduct links a coupon to a product; CouponId and ProductId together
// form a composite primary key.
type CouponProduct struct {
CouponId int `gorm:"primarykey;size:255"`
ProductId string `gorm:"primarykey;size:255"`
Desc string
}
// Order optionally references a Coupon.
// NOTE(review): CouponID is a string while Coupon.ID is an int — presumably
// deliberate, to test mismatched key types; confirm before changing.
type Order struct {
gorm.Model
Num string
Coupon *Coupon
CouponID string
}
// Parent has many Children and additionally points at one favourite Child
// through FavChildID, so Parent and Child reference each other.
type Parent struct {
gorm.Model
FavChildID uint
FavChild *Child
Children []*Child
}
// Child optionally belongs to a Parent through ParentID.
type Child struct {
gorm.Model
Name string
ParentID *uint
Parent *Parent
}
================================================
FILE: gossh/gorm/utils/tests/utils.go
================================================
package tests
import (
"database/sql/driver"
"fmt"
"go/ast"
"reflect"
"testing"
"time"
"gossh/gorm/utils"
)
// AssertObjEqual compares the named struct fields of r (the value under test)
// and e (the expected value), running one subtest per field name.
//
// Both arguments may be structs or pointers to structs. If exactly one of
// them is nil the mismatch is reported and no fields are compared; if both
// are nil there is nothing to compare and the function returns quietly.
func AssertObjEqual(t *testing.T, r, e any, names ...string) {
	if len(names) == 0 {
		return
	}

	// Dereference once: neither value changes between field names.
	rv := reflect.Indirect(reflect.ValueOf(r))
	ev := reflect.Indirect(reflect.ValueOf(e))

	if rv.IsValid() != ev.IsValid() {
		// e is the expectation and r the actual value.
		t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), e, r)
		return
	}
	if !rv.IsValid() {
		// Both nil: FieldByName would panic on the zero reflect.Value.
		return
	}

	for _, name := range names {
		got := rv.FieldByName(name).Interface()
		expect := ev.FieldByName(name).Interface()
		t.Run(name, func(t *testing.T) {
			AssertEqual(t, got, expect)
		})
	}
}
// AssertEqual reports a test error when got and expect are not equivalent.
//
// Equivalence is looser than reflect.DeepEqual:
//   - values with the same fmt.Sprint rendering are accepted;
//   - driver.Valuer values are compared by their Value();
//   - pointers are dereferenced;
//   - slices are compared element by element and structs field by field
//     (exported fields only), each in its own subtest;
//   - otherwise one side is converted to the other's type before comparing,
//     and time.Time values are compared at one-second precision.
func AssertEqual(t *testing.T, got, expect any) {
	if !reflect.DeepEqual(got, expect) {
		// isEqual does the final comparison once got and expect share a type.
		isEqual := func() {
			if curTime, ok := got.(time.Time); ok {
				format := "2006-01-02T15:04:05Z07:00"

				// Accept the pair if either rounding or truncating to the second
				// gives the same timestamp.
				if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) {
					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime)
				}
			} else if fmt.Sprint(got) != fmt.Sprint(expect) {
				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
			}
		}

		if fmt.Sprint(got) == fmt.Sprint(expect) {
			return
		}

		if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() {
			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
			return
		}

		// Compare database values rather than their wrappers.
		if valuer, ok := got.(driver.Valuer); ok {
			got, _ = valuer.Value()
		}

		if valuer, ok := expect.(driver.Valuer); ok {
			expect, _ = valuer.Value()
		}

		if got != nil {
			got = reflect.Indirect(reflect.ValueOf(got)).Interface()
		}

		if expect != nil {
			expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
		}

		if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
			return
		}

		if reflect.ValueOf(got).Kind() == reflect.Slice {
			if reflect.ValueOf(expect).Kind() == reflect.Slice {
				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() {
					for i := 0; i < reflect.ValueOf(got).Len(); i++ {
						name := fmt.Sprintf("%s #%v", reflect.ValueOf(got).Type().Name(), i)
						t.Run(name, func(t *testing.T) {
							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface())
						})
					}
				} else {
					name := reflect.ValueOf(got).Type().Elem().Name()
					t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
				}
				return
			}
		}

		if reflect.ValueOf(got).Kind() == reflect.Struct {
			if reflect.ValueOf(expect).Kind() == reflect.Struct {
				if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
					exported := false
					for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
						if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
							exported = true
							field := reflect.ValueOf(got).Field(i)
							t.Run(fieldStruct.Name, func(t *testing.T) {
								AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
							})
						}
					}

					// Structs without exported fields fall through to the
					// conversion-based comparison below.
					if exported {
						return
					}
				}
			}
		}

		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
			isEqual()
		} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
			// Convert expect (not got) to got's type. Converting got to its own
			// type made the two sides identical and the comparison vacuous.
			expect = reflect.ValueOf(expect).Convert(reflect.ValueOf(got).Type()).Interface()
			isEqual()
		} else {
			t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
			return
		}
	}
}
// Now returns a pointer to a freshly allocated time.Time holding the current
// time. Each call returns a distinct pointer.
func Now() *time.Time {
	current := new(time.Time)
	*current = time.Now()
	return current
}
================================================
FILE: gossh/gorm/utils/utils.go
================================================
package utils
import (
"database/sql/driver"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"unicode"
)
// gormSourceDir is the slash-terminated directory that FileWithLineNum treats
// as gorm's own source tree. It is computed once in init.
var gormSourceDir string
// init derives gormSourceDir from the location of this source file as
// reported by runtime.Caller.
func init() {
_, file, _, _ := runtime.Caller(0)
// compatible solution to get gorm source directory with various operating systems
gormSourceDir = sourceDir(file)
}
// sourceDir returns the root directory of the gorm sources for the given
// source file path, using forward slashes and a trailing "/".
//
// The root is normally the directory two levels above file. If the directory
// one level further up is named "gorm.io", that directory is returned instead.
func sourceDir(file string) string {
	pkgRoot := filepath.Dir(filepath.Dir(file))

	root := pkgRoot
	if parent := filepath.Dir(pkgRoot); filepath.Base(parent) == "gorm.io" {
		root = parent
	}

	return filepath.ToSlash(root) + "/"
}
// FileWithLineNum returns "file:line" for the nearest calling frame that lies
// outside the gorm source tree, or "" if none is found.
//
// Frames 2 through 14 of the call stack are inspected. A frame is skipped when
// it cannot be resolved, when its file ends in ".gen.go", or when its file is
// under gormSourceDir and is not a "_test.go" file.
func FileWithLineNum() string {
	// Frames 0 and 1 are this function and its direct caller inside gorm.
	for depth := 2; depth < 15; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			continue
		}
		if strings.HasSuffix(file, ".gen.go") {
			continue
		}
		if strings.HasPrefix(file, gormSourceDir) && !strings.HasSuffix(file, "_test.go") {
			continue
		}
		return file + ":" + strconv.FormatInt(int64(line), 10)
	}

	return ""
}
// IsValidDBNameChar reports whether c is a separator, i.e. a rune that can
// NOT appear in a database name. Despite the name, it returns false for
// letters, numbers and the characters '.', '*', '_', '$' and '@', and true for
// everything else.
func IsValidDBNameChar(c rune) bool {
	switch c {
	case '.', '*', '_', '$', '@':
		return false
	}
	return !(unicode.IsLetter(c) || unicode.IsNumber(c))
}
// CheckTruth reports whether at least one of vals is "truthy": non-empty and
// not equal to "false" (compared case-insensitively).
func CheckTruth(vals ...string) bool {
	for i := range vals {
		candidate := vals[i]
		if candidate == "" {
			continue
		}
		if strings.EqualFold(candidate, "false") {
			continue
		}
		return true
	}
	return false
}
// ToStringKey renders values as one string key, joining the parts with "_".
//
// A driver.Valuer is first replaced by its Value() (the error is ignored).
// Strings and byte slices are used verbatim and uint is formatted in base 10.
// Any other value that is invalid or the zero value of its type becomes
// "nil"; the rest are dereferenced if they are pointers and formatted with
// fmt.Sprint.
func ToStringKey(values ...any) string {
	parts := make([]string, 0, len(values))

	for _, raw := range values {
		if valuer, ok := raw.(driver.Valuer); ok {
			raw, _ = valuer.Value()
		}

		var part string
		switch typed := raw.(type) {
		case string:
			part = typed
		case []byte:
			part = string(typed)
		case uint:
			part = strconv.FormatUint(uint64(typed), 10)
		default:
			part = "nil"
			if rv := reflect.ValueOf(typed); rv.IsValid() && !rv.IsZero() {
				part = fmt.Sprint(reflect.Indirect(rv).Interface())
			}
		}
		parts = append(parts, part)
	}

	return strings.Join(parts, "_")
}
// Contains reports whether elem is one of elems.
func Contains(elems []string, elem string) bool {
	for i := 0; i < len(elems); i++ {
		if elems[i] == elem {
			return true
		}
	}
	return false
}
// AssertEqual reports whether x and y are equal.
//
// Values that are reflect.DeepEqual are equal. Otherwise, if either side is
// nil or a nil pointer the result is false. Failing that, any side that
// implements driver.Valuer is replaced by its Value() (errors ignored) and
// the results are compared with reflect.DeepEqual.
func AssertEqual(x, y any) bool {
	if reflect.DeepEqual(x, y) {
		return true
	}
	if x == nil || y == nil {
		return false
	}

	for _, rv := range [...]reflect.Value{reflect.ValueOf(x), reflect.ValueOf(y)} {
		if rv.Kind() == reflect.Ptr && rv.IsNil() {
			return false
		}
	}

	// unwrap resolves a driver.Valuer to its underlying database value.
	unwrap := func(v any) any {
		if valuer, ok := v.(driver.Valuer); ok {
			v, _ = valuer.Value()
		}
		return v
	}

	return reflect.DeepEqual(unwrap(x), unwrap(y))
}
// ToString formats value as a string when it is a string or one of the
// built-in signed or unsigned integer types (base 10). Any other type,
// including named types based on them, yields "".
func ToString(value any) string {
	switch v := value.(type) {
	case string:
		return v
	case int, int8, int16, int32, int64:
		return strconv.FormatInt(reflect.ValueOf(v).Int(), 10)
	case uint, uint8, uint16, uint32, uint64:
		return strconv.FormatUint(reflect.ValueOf(v).Uint(), 10)
	default:
		return ""
	}
}
// nestedRelationSplit separates the levels of a nested relation name.
const nestedRelationSplit = "__"

// NestedRelationName joins prefix and name into a nested relation name such
// as `Manager__Company`.
func NestedRelationName(prefix, name string) string {
	return strings.Join([]string{prefix, name}, nestedRelationSplit)
}

// SplitNestedRelationName splits a nested relation name into its levels,
// e.g. `Manager__Company` becomes `[]string{"Manager","Company"}`.
func SplitNestedRelationName(name string) []string {
	levels := strings.Split(name, nestedRelationSplit)
	return levels
}

// JoinNestedRelationNames joins relation levels into a nested relation name
// such as `Manager__Company`.
func JoinNestedRelationNames(relationNames []string) string {
	joined := strings.Join(relationNames, nestedRelationSplit)
	return joined
}
================================================
FILE: gossh/main.go
================================================
package main
import (
"embed"
"errors"
"fmt"
"gossh/app/config"
"gossh/app/middleware"
"gossh/app/model"
"gossh/app/service"
"gossh/gin"
"io/fs"
"log/slog"
"net/http"
"os"
"path"
"path/filepath"
"strings"
)
// dir holds the contents of the webroot directory, embedded into the binary
// through the go:embed directive below (a Go 1.16+ feature).
//
//go:embed webroot
var dir embed.FS
// StaticFile serves ordinary static resources out of an embedded file system.
type StaticFile struct {
	// embedFS is the embedded file system holding the static resources.
	embedFS embed.FS
	// path is the directory inside embedFS that acts as the web root, i.e. the
	// path named in the go:embed directive.
	path string
}

// Open is the core lookup for static resources: it maps the requested name to
// a file below w.path inside the embedded file system.
//
// The name is cleaned as an absolute slash path first, so ".." segments
// cannot climb above w.path. On platforms whose separator is not '/', names
// containing the platform separator are rejected. Backslashes in the joined
// path are converted to '/', the only separator embed.FS understands.
func (w StaticFile) Open(name string) (fs.File, error) {
	if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
		return nil, errors.New("http: invalid character in file path")
	}

	cleaned := path.Clean("/" + name)
	joined := filepath.Join(w.path, filepath.FromSlash(cleaned))
	embedded := strings.ReplaceAll(joined, "\\", "/")

	return w.embedFS.Open(embedded)
}
// init runs the start-up sequence before main: it calls config.InitConfig,
// model.InitDatabase, service.InitSessionClean and service.InitSshServer in
// that order, then prints the configured WebBaseDir.
func init() {
config.InitConfig()
model.InitDatabase()
service.InitSessionClean()
service.InitSshServer()
fmt.Printf("WebBaseDir:[%s]\n", config.DefaultConfig.WebBaseDir)
}
// main builds the gin engine, registers all routes under the configured
// WebBaseDir and starts the web server.
//
// Routes are split into an open group (static files, login, system
// initialisation) and an authenticated group guarded by the SysInit, JWTAuth
// and PremCheck middleware. The server uses HTTPS when both the configured
// certificate and key files can be opened, and plain HTTP otherwise. A server
// start-up error is logged and the process exits with status 1.
func main() {
	gin.SetMode(gin.ReleaseMode)
	var engine = gin.Default()
	engine.Use(middleware.DbCheck(), middleware.NetFilter())
	engine.GET("/web_base_dir", func(c *gin.Context) { c.JSON(200, gin.H{"code": 0, "web_base_dir": config.DefaultConfig.WebBaseDir}) })

	// Unknown paths are redirected to the web application.
	engine.NoRoute(func(c *gin.Context) {
		c.Redirect(http.StatusMovedPermanently, config.DefaultConfig.WebBaseDir+"/app")
	})

	// Routes that do not require authentication.
	var open = engine.Group(config.DefaultConfig.WebBaseDir)
	open.StaticFS("/app", http.FS(StaticFile{embedFS: dir, path: "webroot"}))
	open.POST("/api/login", service.UserLogin)
	open.POST("/api/sys/db_conn_check", service.DbConnCheck)
	open.GET("/api/sys/is_init", service.GetIsInit)
	open.POST("/api/sys/init", service.SysInit)

	// Routes that require authentication.
	var auth = engine.Group(config.DefaultConfig.WebBaseDir,
		middleware.SysInit(),
		middleware.JWTAuth(),
		middleware.PremCheck(engine),
	)

	{ // SSH connection configuration
		auth.GET("/api/conn_conf", service.ConfFindAll)
		auth.GET("/api/conn_conf/:id", service.ConfFindByID)
		auth.POST("/api/conn_conf", service.ConfCreate)
		auth.PUT("/api/conn_conf", service.ConfUpdateById)
		auth.DELETE("/api/conn_conf/:id", service.ConfDeleteById)
	}

	{ // Saved commands
		auth.GET("/api/cmd_note", service.CmdNoteFindAll)
		auth.GET("/api/cmd_note/:id", service.CmdNoteFindByID)
		auth.POST("/api/cmd_note", service.CmdNoteCreate)
		auth.PUT("/api/cmd_note", service.CmdNoteUpdateById)
		auth.DELETE("/api/cmd_note/:id", service.CmdNoteDeleteById)
	}

	{ // Policy configuration
		auth.GET("/api/policy_conf", service.PolicyConfFindAll)
		auth.GET("/api/policy_conf/:id", service.PolicyConfFindByID)
		auth.POST("/api/policy_conf", service.PolicyConfCreate)
		auth.PUT("/api/policy_conf", service.PolicyConfUpdateById)
		auth.DELETE("/api/policy_conf/:id", service.PolicyConfDeleteById)
	}

	{ // Access control
		auth.GET("/api/net_filter", service.NetFilterFindAll)
		auth.GET("/api/net_filter/:id", service.NetFilterFindByID)
		auth.POST("/api/net_filter", service.NetFilterCreate)
		auth.PUT("/api/net_filter", service.NetFilterUpdateById)
		auth.DELETE("/api/net_filter/:id", service.NetFilterDeleteById)
	}

	{ // Web user management
		auth.GET("/api/user", service.UserFindAll)
		auth.GET("/api/user/:id", service.UserFindByID)
		auth.POST("/api/user", service.UserCreate)
		auth.PUT("/api/user", service.UserUpdateById)
		auth.DELETE("/api/user/:id", service.UserDeleteById)
		auth.PATCH("/api/user/check_name_exists", service.CheckUserNameExists)
		auth.PATCH("/api/user/pwd", service.ModifyPasswd)
	}

	{ // SSHD user management
		auth.GET("/api/sshd_user", service.SshdUserFindAll)
		auth.GET("/api/sshd_user/:id", service.SshdUserFindByID)
		auth.POST("/api/sshd_user", service.SshdUserCreate)
		auth.PUT("/api/sshd_user", service.SshdUserUpdateById)
		auth.DELETE("/api/sshd_user/:id", service.SshdUserDeleteById)
		auth.PATCH("/api/sshd_user/check_name_exists", service.CheckSshdUserNameExists)
	}

	{ // SSHD certificate management
		auth.GET("/api/sshd_cert", service.SshdCertFindAll)
		auth.GET("/api/sshd_cert_text", service.GetSshdCertAuthorizedKeys)
		auth.GET("/api/sshd_cert/:id", service.SshdCertFindByID)
		auth.POST("/api/sshd_cert", service.SshdCertCreate)
		auth.PUT("/api/sshd_cert", service.SshdCertUpdateById)
		auth.DELETE("/api/sshd_cert/:id", service.SshdCertDeleteById)
		auth.PATCH("/api/sshd_cert/check_name_exists", service.CheckSshdCertNameExists)
	}

	{ // Audit log
		auth.POST("/api/login_audit", service.LoginAuditSearch)
	}

	{ // SSH connections
		auth.GET("/api/conn_manage/online_client", service.GetOnlineClient)
		auth.PUT("/api/conn_manage/refresh_conn_time", service.RefreshConnTime)
		auth.POST("/api/sftp/create_dir", service.SftpCreateDir)
		auth.POST("/api/sftp/list", service.SftpList)
		auth.GET("/api/sftp/download", service.SftpDownLoad)
		auth.PUT("/api/sftp/upload", service.SftpUpload)
		auth.DELETE("/api/sftp/delete", service.SftpDelete)
		auth.GET("/api/ssh/conn", service.NewSshConn)
		auth.PATCH("/api/ssh/conn", service.ResizeWindow)
		auth.POST("/api/ssh/exec", service.ExecCommand)
		auth.POST("/api/ssh/disconnect", service.Disconnect)
		auth.POST("/api/ssh/create_session", service.CreateSessionId)
	}

	{ // System configuration
		auth.GET("/api/sys/config", service.GetRunConf)
		auth.POST("/api/sys/config", service.SetRunConf)
	}

	address := fmt.Sprintf("%s:%s", config.DefaultConfig.Address, config.DefaultConfig.Port)

	// canOpen reports whether the file can be opened for reading. The probe
	// handle is closed immediately so no descriptor stays open for the
	// lifetime of the server.
	canOpen := func(name string) bool {
		f, err := os.Open(name)
		if err != nil {
			return false
		}
		_ = f.Close() // read-only probe; a close error carries no information here
		return true
	}
	certOK := canOpen(config.DefaultConfig.CertFile)
	keyOK := canOpen(config.DefaultConfig.KeyFile)

	// If both the certificate and the private key are available, serve over
	// HTTPS; otherwise fall back to HTTP.
	if certOK && keyOK {
		slog.Info("https_server_start", "address", address)
		err := engine.RunTLS(address, config.DefaultConfig.CertFile, config.DefaultConfig.KeyFile)
		if err != nil {
			slog.Error("RunServeTLSError:", "msg", err.Error())
			os.Exit(1)
		}
	} else {
		slog.Info("http_server_start", "address", address)
		err := engine.Run(address)
		if err != nil {
			slog.Error("RunServeError:", "msg", err.Error())
			os.Exit(1)
		}
	}
}
================================================
FILE: gossh/mysql/atomic_bool.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.19
// +build go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an alias for sync/atomic.Bool; this file's build constraint
// selects it on Go 1.19 and newer.
type atomicBool = atomic.Bool
================================================
FILE: gossh/mysql/atomic_bool_go118.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !go1.19
// +build !go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an implementation of atomic.Bool for older versions of Go.
// It is a wrapper around uint32 for use as a boolean value with
// atomic access.
type atomicBool struct {
// NOTE(review): presumably marks the struct as not safe to copy — confirm
// against the definition of noCopy.
_ noCopy
// value is 0 for false; Store and Swap write 1 for true.
value uint32
}
// Load atomically reads the value and reports whether it is true.
func (ab *atomicBool) Load() bool {
	current := atomic.LoadUint32(&ab.value)
	return current != 0
}
// Store atomically sets the boolean to value, whatever it held before.
func (ab *atomicBool) Store(value bool) {
	var raw uint32
	if value {
		raw = 1
	}
	atomic.StoreUint32(&ab.value, raw)
}
// Swap atomically sets the boolean to value and returns the previous value.
func (ab *atomicBool) Swap(value bool) bool {
	var raw uint32
	if value {
		raw = 1
	}
	previous := atomic.SwapUint32(&ab.value, raw)
	return previous != 0
}
================================================
FILE: gossh/mysql/auth.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"
)
// Registry of server RSA public keys, keyed by the name they were registered
// under.
var (
// serverPubKeyLock guards serverPubKeyRegistry.
serverPubKeyLock sync.RWMutex
// serverPubKeyRegistry stays nil until the first key is registered.
serverPubKeyRegistry map[string]*rsa.PublicKey
)
// RegisterServerPubKey registers a server RSA public key which can be used to
// send data in a secure manner to the server without receiving the public key
// in a potentially insecure way from the server first.
// Registered keys can afterwards be used by adding serverPubKey=<name> to the
// DSN. Registering under an existing name replaces the earlier key.
//
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified.
//
//	data, err := os.ReadFile("mykey.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	block, _ := pem.Decode(data)
//	if block == nil || block.Type != "PUBLIC KEY" {
//		log.Fatal("failed to decode PEM block containing public key")
//	}
//
//	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
//		mysql.RegisterServerPubKey("mykey", rsaPubKey)
//	} else {
//		log.Fatal("not a RSA public key")
//	}
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	// The registry is created lazily on first registration.
	if serverPubKeyRegistry == nil {
		serverPubKeyRegistry = map[string]*rsa.PublicKey{name: pubKey}
		return
	}
	serverPubKeyRegistry[name] = pubKey
}
// DeregisterServerPubKey removes the public key registered with the given
// name. Removing a name that was never registered is a no-op.
func DeregisterServerPubKey(name string) {
	serverPubKeyLock.Lock()
	defer serverPubKeyLock.Unlock()

	// delete is a no-op on a nil map, so no nil check is needed.
	delete(serverPubKeyRegistry, name)
}
// getServerPubKey returns the public key registered under name, or nil when
// no key is registered with that name.
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
	serverPubKeyLock.RLock()
	defer serverPubKeyLock.RUnlock()

	// Reading a missing key (or a nil map) yields the nil zero value.
	return serverPubKeyRegistry[name]
}
// myRnd is the pseudo random number generator used by the pre-4.1
// ("old password") authentication method.
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus applied to both seeds.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd creates a generator whose two seeds are reduced modulo
// myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	rng := new(myRnd)
	rng.seed1 = seed1 % myRndMaxVal
	rng.seed2 = seed2 % myRndMaxVal
	return rng
}

// NextByte advances the generator and returns the next value, which lies in
// the range 0..30.
//
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	// seed2 is derived from the freshly updated seed1.
	first := (r.seed1*3 + r.seed2) % myRndMaxVal
	second := (first + r.seed2 + 33) % myRndMaxVal
	r.seed1, r.seed2 = first, second

	return byte(uint64(first) * 31 / myRndMaxVal)
}
// pwHash computes the two-word binary hash of password using the insecure
// pre 4.1 method. Spaces and tabs in the password do not contribute to the
// hash, and the sign bit of both result words is cleared.
func pwHash(password []byte) (result [2]uint32) {
	nr, nr2 := uint32(1345345333), uint32(0x12345671)
	add := uint32(7)
	for i := 0; i < len(password); i++ {
		ch := password[i]
		switch ch {
		case ' ', '\t':
			// spaces and tabs are skipped
			continue
		}
		val := uint32(ch)
		nr ^= ((nr&63)+add)*val + (nr << 8)
		nr2 += (nr2 << 8) ^ nr
		add += val
	}
	// Remove the sign bit: keep only the low 31 bits, (1<<31)-1.
	result[0] = nr & 0x7FFFFFFF
	result[1] = nr2 & 0x7FFFFFFF
	return result
}
// scrambleOldPassword builds the 8-byte auth token of the insecure pre 4.1
// method from the first 8 bytes of scramble and the password.
func scrambleOldPassword(scramble []byte, password string) []byte {
	challenge := pwHash(scramble[:8])
	secret := pwHash([]byte(password))
	// Seed the generator with the XOR of both hashes.
	rng := newMyRnd(secret[0]^challenge[0], secret[1]^challenge[1])
	token := make([]byte, 8)
	for i := 0; i < len(token); i++ {
		token[i] = rng.NextByte() + 64
	}
	// One further generator byte masks the whole token.
	mask := rng.NextByte()
	for i := 0; i < len(token); i++ {
		token[i] ^= mask
	}
	return token
}
// scramblePassword computes the auth token of the 4.1+ method (SHA1):
//
//	SHA1(scramble + SHA1(SHA1(password))) XOR SHA1(password)
//
// It returns nil for an empty password.
func scramblePassword(scramble []byte, password string) []byte {
	if password == "" {
		return nil
	}
	// stage1 = SHA1(password), stage2 = SHA1(stage1)
	stage1 := sha1.Sum([]byte(password))
	stage2 := sha1.Sum(stage1[:])
	// token = SHA1(scramble + stage2)
	outer := sha1.New()
	outer.Write(scramble)
	outer.Write(stage2[:])
	token := outer.Sum(nil)
	// token = token XOR stage1
	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
// scrambleSHA256Password computes the auth token of the MySQL 8+ method
// (SHA256):
//
//	XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
//
// It returns nil for an empty password.
func scrambleSHA256Password(scramble []byte, password string) []byte {
	if password == "" {
		return nil
	}
	first := sha256.Sum256([]byte(password))
	second := sha256.Sum256(first[:])
	// The scramble is appended after the double hash.
	mix := sha256.New()
	mix.Write(second[:])
	mix.Write(scramble)
	token := mix.Sum(nil)
	for i := range token {
		token[i] ^= first[i]
	}
	return token
}
// encryptPassword encrypts password for transmission to the server.
//
// The NUL-terminated password is XORed with seed (repeated as needed) and
// the result is encrypted with RSA-OAEP (SHA1) under pub.
//
// It returns an error, instead of panicking, when seed is empty or pub is
// nil; both values can originate from data sent by the server. Any error
// from rsa.EncryptOAEP is returned unchanged.
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
	// An empty seed would make the modulo below divide by zero.
	if len(seed) == 0 {
		return nil, fmt.Errorf("cannot encrypt password: empty seed")
	}
	// rsa.EncryptOAEP dereferences the key, so reject a nil key up front.
	if pub == nil {
		return nil, fmt.Errorf("cannot encrypt password: missing RSA public key")
	}
	// plain is one byte longer than the password: the trailing NUL terminator.
	plain := make([]byte, len(password)+1)
	copy(plain, password)
	for i := range plain {
		plain[i] ^= seed[i%len(seed)]
	}
	return rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, plain, nil)
}
// sendEncryptedPassword encrypts the configured password with seed and pub
// and writes the result as an auth switch packet.
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
	payload, encErr := encryptPassword(mc.cfg.Passwd, seed, pub)
	if encErr == nil {
		return mc.writeAuthSwitchPacket(payload)
	}
	return encErr
}
// auth computes the authentication response for the given plugin from the
// server's authData and the configured password.
//
// Plugins that are disabled in the configuration yield their dedicated error
// (ErrOldPassword, ErrCleartextPassword, ErrNativePassword); an unsupported
// plugin is logged and yields ErrUnknownPlugin.
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
	passwd := mc.cfg.Passwd
	switch plugin {
	case "caching_sha2_password":
		return scrambleSHA256Password(authData, passwd), nil
	case "mysql_old_password":
		if !mc.cfg.AllowOldPasswords {
			return nil, ErrOldPassword
		}
		if passwd == "" {
			return nil, nil
		}
		// Note: there are edge cases where this should work but doesn't;
		// this is currently "wontfix":
		// https://github.com/go-sql-driver/mysql/issues/184
		scrambled := scrambleOldPassword(authData[:8], passwd)
		return append(scrambled, 0), nil
	case "mysql_clear_password":
		if !mc.cfg.AllowCleartextPasswords {
			return nil, ErrCleartextPassword
		}
		// The password is sent NUL-terminated and unhashed.
		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
		return append([]byte(passwd), 0), nil
	case "mysql_native_password":
		if !mc.cfg.AllowNativePasswords {
			return nil, ErrNativePassword
		}
		// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
		// Native password authentication only need and will need 20-byte challenge.
		return scramblePassword(authData[:20], passwd), nil
	case "sha256_password":
		switch {
		case passwd == "":
			return []byte{0}, nil
		case mc.cfg.TLS != nil:
			// unlike caching_sha2_password, sha256_password does not accept
			// cleartext password on unix transport, so cleartext is only
			// written when TLS is configured.
			return append([]byte(passwd), 0), nil
		case mc.cfg.pubKey == nil:
			// request public key from server
			return []byte{1}, nil
		}
		// encrypted password
		return encryptPassword(passwd, authData, mc.cfg.pubKey)
	default:
		mc.log("unknown auth plugin:", plugin)
		return nil, ErrUnknownPlugin
	}
}
// handleAuthResult drives the rest of the authentication exchange after the
// initial auth response has been sent.
//
// It reads the server's result packet, follows at most one auth plugin switch
// (re-authenticating through mc.auth with the new plugin) and then completes
// the plugin specific exchange for caching_sha2_password and sha256_password,
// which includes requesting the server's RSA public key when none is
// configured.
//
// oldAuthData is the challenge the initial response was computed from; when a
// plugin switch carries new auth data, that data is copied into oldAuthData.
// plugin is the plugin the initial response was computed for.
//
// It returns nil once authentication succeeded, ErrMalformPkt when the server
// replies in a way that does not fit the exchange (including a second plugin
// switch), and otherwise the error of the failing packet I/O or key parsing
// step. A server public key that is not an RSA key yields an error rather
// than a panic.
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
	// Read Result Packet
	authData, newPlugin, err := mc.readAuthResult()
	if err != nil {
		return err
	}
	// handle auth plugin switch, if requested
	if newPlugin != "" {
		// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
		// sent and we have to keep using the cipher sent in the init packet.
		if authData == nil {
			authData = oldAuthData
		} else {
			// copy data from read buffer to owned slice
			copy(oldAuthData, authData)
		}
		plugin = newPlugin
		authResp, err := mc.auth(authData, plugin)
		if err != nil {
			return err
		}
		if err = mc.writeAuthSwitchPacket(authResp); err != nil {
			return err
		}
		// Read Result Packet
		authData, newPlugin, err = mc.readAuthResult()
		if err != nil {
			return err
		}
		// Do not allow to change the auth plugin more than once
		if newPlugin != "" {
			return ErrMalformPkt
		}
	}
	switch plugin {
	// https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/
	case "caching_sha2_password":
		switch len(authData) {
		case 0:
			return nil // auth successful
		case 1:
			switch authData[0] {
			case cachingSha2PasswordFastAuthSuccess:
				if err = mc.resultUnchanged().readResultOK(); err == nil {
					return nil // auth successful
				}
			case cachingSha2PasswordPerformFullAuthentication:
				if mc.cfg.TLS != nil || mc.cfg.Net == "unix" {
					// write cleartext auth packet
					err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
					if err != nil {
						return err
					}
				} else {
					pubKey := mc.cfg.pubKey
					if pubKey == nil {
						// request public key from server
						data, err := mc.buf.takeSmallBuffer(4 + 1)
						if err != nil {
							return err
						}
						data[4] = cachingSha2PasswordRequestPublicKey
						err = mc.writePacket(data)
						if err != nil {
							return err
						}
						if data, err = mc.readPacket(); err != nil {
							return err
						}
						if data[0] != iAuthMoreData {
							return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")
						}
						// parse public key
						block, rest := pem.Decode(data[1:])
						if block == nil {
							return fmt.Errorf("no pem data found, data: %s", rest)
						}
						pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
						if err != nil {
							return err
						}
						// The key comes from the server, so it may be a valid
						// PKIX key of another kind; do not assert blindly.
						rsaKey, ok := pkix.(*rsa.PublicKey)
						if !ok {
							return fmt.Errorf("server public key is not an RSA key: %T", pkix)
						}
						pubKey = rsaKey
					}
					// send encrypted password
					err = mc.sendEncryptedPassword(oldAuthData, pubKey)
					if err != nil {
						return err
					}
				}
				return mc.resultUnchanged().readResultOK()
			default:
				return ErrMalformPkt
			}
		default:
			return ErrMalformPkt
		}
	case "sha256_password":
		switch len(authData) {
		case 0:
			return nil // auth successful
		default:
			// Non-empty auth data is expected to be the PEM encoded public key.
			block, _ := pem.Decode(authData)
			if block == nil {
				return fmt.Errorf("no Pem data found, data: %s", authData)
			}
			pub, err := x509.ParsePKIXPublicKey(block.Bytes)
			if err != nil {
				return err
			}
			// Same as above: reject non-RSA keys instead of panicking.
			rsaKey, ok := pub.(*rsa.PublicKey)
			if !ok {
				return fmt.Errorf("server public key is not an RSA key: %T", pub)
			}
			// send encrypted password
			err = mc.sendEncryptedPassword(oldAuthData, rsaKey)
			if err != nil {
				return err
			}
			return mc.resultUnchanged().readResultOK()
		}
	default:
		return nil // auth successful
	}
	// Only reached when the fast-auth OK packet could not be read.
	return err
}
================================================
FILE: gossh/mysql/buffer.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"io"
"net"
"time"
)
// defaultBufSize is the size in bytes of the buffer allocated by newBuffer;
// fill also rounds grown buffers up to a multiple of it.
const defaultBufSize = 4096
// maxCachedBufSize is the largest grown buffer that fill stores back into the
// buffer's backing storage (dbuf); larger buffers are not cached there.
const maxCachedBufSize = 256 * 1024
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish.
// Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme.
type buffer struct {
buf []byte // buf is a byte buffer whose length and capacity are equal.
nc net.Conn // nc is the connection fill reads from
idx int // idx is the offset in buf of the next unread byte
length int // length is the number of unread bytes in buf; non-zero means the buffer is busy
timeout time.Duration // timeout, when > 0, is applied as a read deadline before every read in fill
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipcnt is the current buffer counter for double-buffering; its lowest bit selects the dbuf entry
}
// newBuffer returns a buffer that reads from nc. Its active buffer is a fresh
// slice of defaultBufSize bytes, which also becomes the first backing slice;
// the second backing slice stays nil until it is needed.
func newBuffer(nc net.Conn) buffer {
	var b buffer
	b.nc = nc
	b.buf = make([]byte, defaultBufSize)
	b.dbuf[0] = b.buf
	return b
}
// flip schedules a switch from the active buffer to the background buffer.
// It is a delayed flip: only the buffer counter is advanced here, and the
// actual switch happens the next time `buffer.fill` runs.
func (b *buffer) flip() {
	b.flipcnt++
}
// fill reads from the connection until at least _need_ bytes are buffered.
// It returns io.ErrUnexpectedEOF when the connection ends before that, and
// any other read or deadline error unchanged.
func (b *buffer) fill(need int) error {
	have := b.length
	// Pick the double-buffering target: after a flip this is the background
	// buffer, which receives a copy of the unread data before being filled
	// from the network; otherwise it is the current buffer and the unread
	// data is just moved to its front.
	slot := b.flipcnt & 1
	target := b.dbuf[slot]
	if len(target) < need {
		// Grow to the next multiple of the default size so the whole packet fits.
		target = make([]byte, (need/defaultBufSize+1)*defaultBufSize)
		// Keep moderately sized buffers as backing storage to prevent extra
		// allocations on applications that perform large reads.
		if len(target) <= maxCachedBufSize {
			b.dbuf[slot] = target
		}
	}
	if have > 0 {
		copy(target[:have], b.buf[b.idx:])
	}
	b.buf, b.idx = target, 0
	for {
		if b.timeout > 0 {
			if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
				return err
			}
		}
		got, err := b.nc.Read(b.buf[have:])
		have += got
		if err != nil && err != io.EOF {
			return err
		}
		if have >= need {
			b.length = have
			return nil
		}
		if err == io.EOF {
			// the stream ended before enough data arrived
			return io.ErrUnexpectedEOF
		}
	}
}
// readNext returns the next need bytes from the buffer, refilling it from the
// connection first when fewer bytes are buffered.
// The returned slice is only guaranteed to be valid until the next read.
func (b *buffer) readNext(need int) ([]byte, error) {
	if need > b.length {
		// not enough buffered data: refill
		if err := b.fill(need); err != nil {
			return nil, err
		}
	}
	start, end := b.idx, b.idx+need
	b.idx = end
	b.length -= need
	return b.buf[start:end], nil
}
// takeBuffer returns a buffer with the requested size.
// A slice of the existing buffer is returned when it is large enough;
// otherwise a bigger buffer is allocated, and kept as the new internal buffer
// only when it is smaller than maxPacketSize.
// It returns ErrBusyBuffer while unread data is still buffered.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) {
	switch {
	case b.length > 0:
		return nil, ErrBusyBuffer
	case length <= cap(b.buf):
		// cheap general case: reuse the existing buffer
		return b.buf[:length], nil
	case length < maxPacketSize:
		b.buf = make([]byte, length)
		return b.buf, nil
	default:
		// buffer is larger than we want to store.
		return make([]byte, length), nil
	}
}
// takeSmallBuffer is a shortcut for takeBuffer that can be used if length is
// known to be smaller than defaultBufSize: it slices the existing buffer
// without any size check.
// It returns ErrBusyBuffer while unread data is still buffered.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
	if b.length <= 0 {
		return b.buf[:length], nil
	}
	return nil, ErrBusyBuffer
}
// takeCompleteBuffer hands out the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// It returns ErrBusyBuffer while unread data is still buffered.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
	if b.length <= 0 {
		return b.buf, nil
	}
	return nil, ErrBusyBuffer
}
// store keeps buf, an updated buffer, as the internal buffer if it is
// suitable: its capacity must exceed the current buffer's capacity without
// exceeding maxPacketSize. Unsuitable buffers are silently ignored.
// It returns ErrBusyBuffer while unread data is still buffered.
func (b *buffer) store(buf []byte) error {
	if b.length > 0 {
		return ErrBusyBuffer
	}
	if c := cap(buf); c <= maxPacketSize && c > cap(b.buf) {
		b.buf = buf[:c]
	}
	return nil
}
================================================
FILE: gossh/mysql/collations.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
// defaultCollation names the driver's default collation; the name is a key of
// the collations map.
// NOTE(review): presumably applied when the DSN sets no collation — confirm
// against the config/handshake code.
const defaultCollation = "utf8mb4_general_ci"
// binaryCollationID is the ID of the "binary" collation (see the collations map).
const binaryCollationID = 63
// collations maps the name of each available collation to its internal ID.
// To update this map use the following MySQL query:
//
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
//
// The handshake packet has only 1 byte for collation_id, so collations with
// ID > 255 can't be used.
//
// ucs2, utf16, and utf32 can't be used for connection charset.
// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset
// They are commented out to reduce this map.
var collations = map[string]byte{
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
//"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
//"utf16_general_ci": 54,
//"utf16_bin": 55,
//"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
//"utf32_general_ci": 60,
//"utf32_bin": 61,
//"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"utf8_tolower_ci": 76,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
//"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
//"utf16_unicode_ci": 101,
//"utf16_icelandic_ci": 102,
//"utf16_latvian_ci": 103,
//"utf16_romanian_ci": 104,
//"utf16_slovenian_ci": 105,
//"utf16_polish_ci": 106,
//"utf16_estonian_ci": 107,
//"utf16_spanish_ci": 108,
//"utf16_swedish_ci": 109,
//"utf16_turkish_ci": 110,
//"utf16_czech_ci": 111,
//"utf16_danish_ci": 112,
//"utf16_lithuanian_ci": 113,
//"utf16_slovak_ci": 114,
//"utf16_spanish2_ci": 115,
//"utf16_roman_ci": 116,
//"utf16_persian_ci": 117,
//"utf16_esperanto_ci": 118,
//"utf16_hungarian_ci": 119,
//"utf16_sinhala_ci": 120,
//"utf16_german2_ci": 121,
//"utf16_croatian_ci": 122,
//"utf16_unicode_520_ci": 123,
//"utf16_vietnamese_ci": 124,
//"ucs2_unicode_ci": 128,
//"ucs2_icelandic_ci": 129,
//"ucs2_latvian_ci": 130,
//"ucs2_romanian_ci": 131,
//"ucs2_slovenian_ci": 132,
//"ucs2_polish_ci": 133,
//"ucs2_estonian_ci": 134,
//"ucs2_spanish_ci": 135,
//"ucs2_swedish_ci": 136,
//"ucs2_turkish_ci": 137,
//"ucs2_czech_ci": 138,
//"ucs2_danish_ci": 139,
//"ucs2_lithuanian_ci": 140,
//"ucs2_slovak_ci": 141,
//"ucs2_spanish2_ci": 142,
//"ucs2_roman_ci": 143,
//"ucs2_persian_ci": 144,
//"ucs2_esperanto_ci": 145,
//"ucs2_hungarian_ci": 146,
//"ucs2_sinhala_ci": 147,
//"ucs2_german2_ci": 148,
//"ucs2_croatian_ci": 149,
//"ucs2_unicode_520_ci": 150,
//"ucs2_vietnamese_ci": 151,
//"ucs2_general_mysql500_ci": 159,
//"utf32_unicode_ci": 160,
//"utf32_icelandic_ci": 161,
//"utf32_latvian_ci": 162,
//"utf32_romanian_ci": 163,
//"utf32_slovenian_ci": 164,
//"utf32_polish_ci": 165,
//"utf32_estonian_ci": 166,
//"utf32_spanish_ci": 167,
//"utf32_swedish_ci": 168,
//"utf32_turkish_ci": 169,
//"utf32_czech_ci": 170,
//"utf32_danish_ci": 171,
//"utf32_lithuanian_ci": 172,
//"utf32_slovak_ci": 173,
//"utf32_spanish2_ci": 174,
//"utf32_roman_ci": 175,
//"utf32_persian_ci": 176,
//"utf32_esperanto_ci": 177,
//"utf32_hungarian_ci": 178,
//"utf32_sinhala_ci": 179,
//"utf32_german2_ci": 180,
//"utf32_croatian_ci": 181,
//"utf32_unicode_520_ci": 182,
//"utf32_vietnamese_ci": 183,
"utf8_unicode_ci": 192,
"utf8_icelandic_ci": 193,
"utf8_latvian_ci": 194,
"utf8_romanian_ci": 195,
"utf8_slovenian_ci": 196,
"utf8_polish_ci": 197,
"utf8_estonian_ci": 198,
"utf8_spanish_ci": 199,
"utf8_swedish_ci": 200,
"utf8_turkish_ci": 201,
"utf8_czech_ci": 202,
"utf8_danish_ci": 203,
"utf8_lithuanian_ci": 204,
"utf8_slovak_ci": 205,
"utf8_spanish2_ci": 206,
"utf8_roman_ci": 207,
"utf8_persian_ci": 208,
"utf8_esperanto_ci": 209,
"utf8_hungarian_ci": 210,
"utf8_sinhala_ci": 211,
"utf8_german2_ci": 212,
"utf8_croatian_ci": 213,
"utf8_unicode_520_ci": 214,
"utf8_vietnamese_ci": 215,
"utf8_general_mysql500_ci": 223,
"utf8mb4_unicode_ci": 224,
"utf8mb4_icelandic_ci": 225,
"utf8mb4_latvian_ci": 226,
"utf8mb4_romanian_ci": 227,
"utf8mb4_slovenian_ci": 228,
"utf8mb4_polish_ci": 229,
"utf8mb4_estonian_ci": 230,
"utf8mb4_spanish_ci": 231,
"utf8mb4_swedish_ci": 232,
"utf8mb4_turkish_ci": 233,
"utf8mb4_czech_ci": 234,
"utf8mb4_danish_ci": 235,
"utf8mb4_lithuanian_ci": 236,
"utf8mb4_slovak_ci": 237,
"utf8mb4_spanish2_ci": 238,
"utf8mb4_roman_ci": 239,
"utf8mb4_persian_ci": 240,
"utf8mb4_esperanto_ci": 241,
"utf8mb4_hungarian_ci": 242,
"utf8mb4_sinhala_ci": 243,
"utf8mb4_german2_ci": 244,
"utf8mb4_croatian_ci": 245,
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
"gb18030_chinese_ci": 248,
"gb18030_bin": 249,
"gb18030_unicode_520_ci": 250,
"utf8mb4_0900_ai_ci": 255,
}
// unsafeCollations is a denylist of collations for which it is unsafe to
// interpolate parameters.
// These multibyte encodings may contain 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
"gb18030_chinese_ci": true,
"gb18030_bin": true,
"gb18030_unicode_520_ci": true,
}
================================================
FILE: gossh/mysql/conncheck.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package mysql
import (
"errors"
"io"
"net"
"syscall"
)
// errUnexpectedRead is reported by connCheck when data could be read from a
// connection that was only being probed.
var errUnexpectedRead = errors.New("unexpected read from socket")

// connCheck probes conn with a single one-byte read on its raw file
// descriptor. It returns io.EOF when the read returns zero bytes without an
// error, errUnexpectedRead when a byte could be read, nil when the read
// would block, and the read error otherwise.
// Connections that do not implement syscall.Conn are not checked and yield nil.
func connCheck(conn net.Conn) error {
	sc, isSyscallConn := conn.(syscall.Conn)
	if !isSyscallConn {
		return nil
	}
	raw, err := sc.SyscallConn()
	if err != nil {
		return err
	}
	var probeErr error
	probe := func(fd uintptr) bool {
		var one [1]byte
		n, readErr := syscall.Read(int(fd), one[:])
		if n == 0 && readErr == nil {
			probeErr = io.EOF
		} else if n > 0 {
			probeErr = errUnexpectedRead
		} else if readErr == syscall.EAGAIN || readErr == syscall.EWOULDBLOCK {
			probeErr = nil
		} else {
			probeErr = readErr
		}
		// Always report completion so the probe is not retried.
		return true
	}
	if err := raw.Read(probe); err != nil {
		return err
	}
	return probeErr
}
================================================
FILE: gossh/mysql/conncheck_dummy.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package mysql
import "net"
// connCheck is the fallback for platforms without a real connection check
// (see the build constraint of this file): it never inspects the connection
// and always returns nil.
func connCheck(_ net.Conn) error {
	return nil
}
================================================
FILE: gossh/mysql/connection.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"io"
"net"
"strconv"
"strings"
"time"
)
// mysqlConn holds the state of a single connection to a MySQL server.
type mysqlConn struct {
buf buffer // packet buffer; also lent out as scratch space (e.g. for query interpolation)
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection; this is the one cleanup closes.
result mysqlResult // managed by clearResult() and handleOkPacket().
cfg *Config // connection configuration (password, params, logger, ...)
connector *connector
maxAllowedPacket int // upper bound for the size of an interpolated query
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag // server status flags; statusNoBackslashEscapes selects the escaping mode
sequence uint8
parseTime bool
// for context support (Go 1.8+)
watching bool // true while a context is being watched for this connection
watcher chan<- context.Context
closech chan struct{} // closed by cleanup
finished chan<- struct{} // finish signals on this channel that the current operation is done
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
// log prints v through the logger configured for this connection.
func (mc *mysqlConn) log(v ...any) {
	logger := mc.cfg.Logger
	logger.Print(v...)
}
// handleParams applies the parameters set in the DSN after the connection is
// established.
//
// "charset" is handled with SET NAMES, trying each comma separated charset
// until one succeeds; all other parameters are treated as system variables
// and accumulated into a single SET command. The first failing command's
// error is returned.
func (mc *mysqlConn) handleParams() (err error) {
	var setStmt strings.Builder
	for name, value := range mc.cfg.Params {
		if name == "charset" {
			// Charset: character_set_connection, character_set_client, character_set_results
			for _, charset := range strings.Split(value, ",") {
				// A charset may not exist, so an error only counts when
				// every candidate failed.
				if mc.cfg.Collation != "" {
					err = mc.exec("SET NAMES " + charset + " COLLATE " + mc.cfg.Collation)
				} else {
					err = mc.exec("SET NAMES " + charset)
				}
				if err == nil {
					break
				}
			}
			if err != nil {
				return err
			}
			continue
		}
		// Other system vars accumulated in a single SET command
		if setStmt.Len() > 0 {
			setStmt.WriteString(", ")
		} else {
			// Heuristic: reserve 30 bytes for each other key=value to reduce reallocations
			setStmt.Grow(4 + len(name) + 3 + len(value) + 30*(len(mc.cfg.Params)-1))
			setStmt.WriteString("SET ")
		}
		setStmt.WriteString(name)
		setStmt.WriteString(" = ")
		setStmt.WriteString(value)
	}
	if setStmt.Len() == 0 {
		return nil
	}
	return mc.exec(setStmt.String())
}
// markBadConn translates errBadConnNoWrite into driver.ErrBadConn. Every
// other error, and any error passed on a nil connection, is returned
// unchanged.
func (mc *mysqlConn) markBadConn(err error) error {
	if mc != nil && err == errBadConnNoWrite {
		return driver.ErrBadConn
	}
	return err
}
// Begin starts a regular (not read-only) transaction.
func (mc *mysqlConn) Begin() (driver.Tx, error) {
	const readOnly = false
	return mc.begin(readOnly)
}
// begin starts a transaction, read-only when readOnly is set, and returns a
// transaction handle bound to this connection. A closed connection is logged
// and reported as driver.ErrBadConn.
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
	if mc.closed.Load() {
		mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	stmt := "START TRANSACTION"
	if readOnly {
		stmt = "START TRANSACTION READ ONLY"
	}
	if err := mc.exec(stmt); err != nil {
		return nil, mc.markBadConn(err)
	}
	return &mysqlTx{mc}, nil
}
// Close sends the quit command to the server, unless the connection is
// already closed, and then releases the connection. It is idempotent; the
// returned error is the one from writing the quit command.
func (mc *mysqlConn) Close() (err error) {
	if alreadyClosed := mc.closed.Load(); !alreadyClosed {
		err = mc.writeCommandPacket(comQuit)
	}
	mc.cleanup()
	mc.clearResult()
	return err
}
// cleanup marks the connection closed, closes closech and closes the
// underlying network connection. Do not call this function after successful
// authentication, call Close instead. This function is called before auth or
// on auth failure because MySQL will have already closed the network
// connection.
//
// This function can be called from multiple goroutines, so it can not call
// mc.clearResult(); the caller should do that if it is in a safe goroutine.
func (mc *mysqlConn) cleanup() {
	// Only the first caller proceeds, which makes cleanup idempotent.
	if wasClosed := mc.closed.Swap(true); wasClosed {
		return
	}
	close(mc.closech)
	if raw := mc.rawConn; raw != nil {
		if err := raw.Close(); err != nil {
			mc.log(err)
		}
	}
}
// error reports why the connection is unusable: nil while it is open, the
// cancellation error if one was recorded, and ErrInvalidConn otherwise.
func (mc *mysqlConn) error() error {
	if !mc.closed.Load() {
		return nil
	}
	if cause := mc.canceled.Value(); cause != nil {
		return cause
	}
	return ErrInvalidConn
}
// Prepare sends query to the server as a prepared statement and returns the
// resulting statement. A closed connection or a failed send is logged and
// reported as driver.ErrBadConn.
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
	if mc.closed.Load() {
		mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	// Send command
	if err := mc.writeCommandPacketStr(comStmtPrepare, query); err != nil {
		// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
		mc.log(err)
		return nil, driver.ErrBadConn
	}
	stmt := &mysqlStmt{mc: mc}
	// Read Result
	columnCount, err := stmt.readPrepareResultPacket()
	if err != nil {
		return stmt, err
	}
	// Skip the parameter packets, then the column packets, if there are any.
	if stmt.paramCount > 0 {
		if err = mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}
	if columnCount > 0 {
		err = mc.readUntilEOF()
	}
	return stmt, err
}
// interpolateParams substitutes every ? placeholder in query with the SQL
// literal of the matching argument and returns the resulting query text.
//
// It returns driver.ErrSkip when the placeholder count does not match
// len(args), when an argument has an unsupported type, or when the result
// would exceed maxAllowedPacket. ErrInvalidConn is returned when the
// connection buffer cannot be taken.
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
	// Number of ? should be same to len(args)
	if len(args) != strings.Count(query, "?") {
		return "", driver.ErrSkip
	}
	out, err := mc.buf.takeCompleteBuffer()
	if err != nil {
		// can not take the buffer. Something must be wrong with the connection
		mc.log(err)
		return "", ErrInvalidConn
	}
	out = out[:0]
	// The server status decides how quoted values have to be escaped.
	backslashEscapes := mc.status&statusNoBackslashEscapes == 0
	next := 0 // index of the next argument to substitute
	for pos := 0; pos < len(query); pos++ {
		rel := strings.IndexByte(query[pos:], '?')
		if rel < 0 {
			// no further placeholder: copy the remainder verbatim
			out = append(out, query[pos:]...)
			break
		}
		out = append(out, query[pos:pos+rel]...)
		pos += rel
		arg := args[next]
		next++
		if arg == nil {
			out = append(out, "NULL"...)
			continue
		}
		switch v := arg.(type) {
		case int64:
			out = strconv.AppendInt(out, v, 10)
		case uint64:
			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
			out = strconv.AppendUint(out, v, 10)
		case float64:
			out = strconv.AppendFloat(out, v, 'g', -1, 64)
		case bool:
			digit := byte('0')
			if v {
				digit = '1'
			}
			out = append(out, digit)
		case time.Time:
			if v.IsZero() {
				out = append(out, "'0000-00-00'"...)
				break
			}
			out = append(out, '\'')
			out, err = appendDateTime(out, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
			if err != nil {
				return "", err
			}
			out = append(out, '\'')
		case json.RawMessage:
			out = append(out, '\'')
			if backslashEscapes {
				out = escapeBytesBackslash(out, v)
			} else {
				out = escapeBytesQuotes(out, v)
			}
			out = append(out, '\'')
		case []byte:
			if v == nil {
				out = append(out, "NULL"...)
				break
			}
			out = append(out, "_binary'"...)
			if backslashEscapes {
				out = escapeBytesBackslash(out, v)
			} else {
				out = escapeBytesQuotes(out, v)
			}
			out = append(out, '\'')
		case string:
			out = append(out, '\'')
			if backslashEscapes {
				out = escapeStringBackslash(out, v)
			} else {
				out = escapeStringQuotes(out, v)
			}
			out = append(out, '\'')
		default:
			return "", driver.ErrSkip
		}
		// Give up as soon as the query no longer fits into a packet.
		if len(out)+4 > mc.maxAllowedPacket {
			return "", driver.ErrSkip
		}
	}
	if next != len(args) {
		return "", driver.ErrSkip
	}
	return string(out), nil
}
// Exec runs query and returns a copy of the connection's result.
// Arguments are only supported through client-side interpolation; without
// InterpolateParams it returns driver.ErrSkip when args are given.
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if mc.closed.Load() {
		mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	if len(args) > 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
		interpolated, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = interpolated
	}
	if err := mc.exec(query); err != nil {
		return nil, mc.markBadConn(err)
	}
	// Hand out a copy rather than a pointer to the connection's own result.
	res := mc.result
	return &res, nil
}
// exec sends query as a text protocol command and consumes the complete
// response, discarding any rows it produces.
func (mc *mysqlConn) exec(query string) error {
	handleOk := mc.clearResult()
	// Send command
	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
		return mc.markBadConn(err)
	}
	// Read Result
	resLen, err := handleOk.readResultSetHeaderPacket()
	if err != nil {
		return err
	}
	if resLen > 0 {
		// Two EOF-terminated sections follow: the columns, then the rows.
		for section := 0; section < 2; section++ {
			if err := mc.readUntilEOF(); err != nil {
				return err
			}
		}
	}
	return handleOk.discardResults()
}
// Query runs query with args and returns the resulting rows.
//
// On failure it returns a literal nil driver.Rows together with the error.
// Passing the result of mc.query straight through would wrap a nil *textRows
// in a non-nil interface value, so callers comparing the rows against nil
// would be misled.
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	rows, err := mc.query(query, args)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// query sends query with the text protocol and returns rows positioned
// before the first row. Arguments are only supported through client-side
// interpolation; without InterpolateParams it returns driver.ErrSkip when
// args are given.
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
	handleOk := mc.clearResult()
	if mc.closed.Load() {
		mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	if len(args) > 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		// try client-side prepare to reduce roundtrip
		interpolated, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = interpolated
	}
	// Send command
	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
		return nil, mc.markBadConn(err)
	}
	// Read Result
	resLen, err := handleOk.readResultSetHeaderPacket()
	if err != nil {
		return nil, mc.markBadConn(err)
	}
	rows := new(textRows)
	rows.mc = mc
	if resLen != 0 {
		// Columns
		rows.rs.columns, err = mc.readColumns(resLen)
		return rows, err
	}
	// No columns in this result: mark it done and move on to the next one.
	rows.rs.done = true
	if err := rows.NextResultSet(); err != nil && err != io.EOF {
		return nil, err
	}
	return rows, nil
}
// getSystemVar returns the value of the MySQL system variable name, read
// with a "SELECT @@name" query.
// The returned byte slice is only valid until the next read.
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
	// Send command
	handleOk := mc.clearResult()
	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
		return nil, err
	}
	// Read Result
	resLen, err := handleOk.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}
	rows := new(textRows)
	rows.mc = mc
	// The value is read as a single varchar column.
	rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
	if resLen > 0 {
		// Columns
		if err := mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}
	dest := make([]driver.Value, resLen)
	if err := rows.readRow(dest); err != nil {
		return nil, err
	}
	return dest[0].([]byte), mc.readUntilEOF()
}
// cancel is called when the query has been canceled: it records err as the
// cancellation cause before tearing the connection down.
func (mc *mysqlConn) cancel(err error) {
	mc.canceled.Set(err)
	mc.cleanup()
}
// finish is called when the query has succeeded. If a context is being
// watched, it signals completion on the finished channel and stops watching;
// when the connection is closed first, it just returns.
func (mc *mysqlConn) finish() {
	if mc.finished == nil || !mc.watching {
		return
	}
	select {
	case <-mc.closech:
	case mc.finished <- struct{}{}:
		mc.watching = false
	}
}
// Ping implements driver.Pinger interface.
// It sends a ping command and waits for the OK packet; a closed connection
// is logged and reported as driver.ErrBadConn.
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
	if mc.closed.Load() {
		mc.log(ErrInvalidConn)
		return driver.ErrBadConn
	}
	if watchErr := mc.watchCancel(ctx); watchErr != nil {
		return watchErr
	}
	defer mc.finish()
	handleOk := mc.clearResult()
	if writeErr := mc.writeCommandPacket(comPing); writeErr != nil {
		return mc.markBadConn(writeErr)
	}
	return handleOk.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface.
// A non-default isolation level is applied with SET TRANSACTION ISOLATION
// LEVEL before the transaction is started.
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if mc.closed.Load() {
		return nil, driver.ErrBadConn
	}
	if err := mc.watchCancel(ctx); err != nil {
		return nil, err
	}
	defer mc.finish()
	if requested := sql.IsolationLevel(opts.Isolation); requested != sql.LevelDefault {
		level, err := mapIsolationLevel(opts.Isolation)
		if err != nil {
			return nil, err
		}
		if err := mc.exec("SET TRANSACTION ISOLATION LEVEL " + level); err != nil {
			return nil, err
		}
	}
	return mc.begin(opts.ReadOnly)
}
// QueryContext runs query with args on the connection while watching ctx.
//
// On success the returned rows carry mc.finish in their finish field, so
// finish is not called here. On a query error mc.finish is called before
// returning.
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	values, convErr := namedValueToValue(args)
	if convErr != nil {
		return nil, convErr
	}

	if watchErr := mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}

	result, queryErr := mc.query(query, values)
	if queryErr != nil {
		mc.finish()
		return nil, queryErr
	}

	result.finish = mc.finish
	return result, nil
}
// ExecContext executes query with args on the connection while watching ctx.
// mc.finish is called when the call returns.
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	values, convErr := namedValueToValue(args)
	if convErr != nil {
		return nil, convErr
	}

	if watchErr := mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}
	defer mc.finish()

	return mc.Exec(query, values)
}
// PrepareContext prepares query on the connection while watching ctx.
//
// mc.finish is called as soon as Prepare returns. If ctx is found done after
// a successful Prepare, the new statement is closed and ctx.Err() is
// returned instead of the statement.
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	if watchErr := mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}

	prepared, prepErr := mc.Prepare(query)
	mc.finish()
	if prepErr != nil {
		return nil, prepErr
	}

	// Non-blocking check of the context.
	select {
	case <-ctx.Done():
		prepared.Close()
		return nil, ctx.Err()
	default:
		return prepared, nil
	}
}
// QueryContext runs the prepared statement with args while watching ctx on
// the statement's connection.
//
// On success the returned rows carry the connection's finish method in their
// finish field, so finish is not called here. On a query error finish is
// called before returning.
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	values, convErr := namedValueToValue(args)
	if convErr != nil {
		return nil, convErr
	}

	if watchErr := stmt.mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}

	result, queryErr := stmt.query(values)
	if queryErr != nil {
		stmt.mc.finish()
		return nil, queryErr
	}

	result.finish = stmt.mc.finish
	return result, nil
}
// ExecContext executes the prepared statement with args while watching ctx
// on the statement's connection. The connection's finish method is called
// when the call returns.
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	values, convErr := namedValueToValue(args)
	if convErr != nil {
		return nil, convErr
	}

	if watchErr := stmt.mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}
	defer stmt.mc.finish()

	return stmt.Exec(values)
}
// watchCancel hands ctx to the goroutine started by startWatcher, so that
// mc.cancel is invoked if ctx becomes done before finish is called.
//
// It returns ctx.Err() when ctx is already done and nil in every other case.
// A nil return does not guarantee that ctx is being watched: nothing is sent
// when ctx can never be canceled (ctx.Done() == nil) or when mc.watcher is
// nil.
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}
// Mark the connection as watched before handing ctx over; finish clears
// the flag again once its send on mc.finished succeeds.
mc.watching = true
mc.watcher <- ctx
return nil
}
// startWatcher creates the watcher and finished channels, stores them on mc
// and starts a goroutine that waits for contexts sent by watchCancel.
//
// For each received ctx the goroutine blocks until one of the following:
// ctx is done (it then calls mc.cancel with ctx.Err()), a value arrives on
// finished, or mc.closech becomes readable (the goroutine then returns).
func (mc *mysqlConn) startWatcher() {
// The watcher channel has a buffer of one context.
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
// Wait for the next context to watch, or for the connection to close.
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
// Watch ctx until it is done, the operation finishes or the connection closes.
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
// CheckNamedValue replaces nv.Value with the result of the driver's
// converter and returns the converter's error. nv.Value is overwritten even
// when the conversion fails.
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
	converted, err := converter{}.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
//
// It returns driver.ErrBadConn for a closed connection. When
// cfg.CheckConnLiveness is set it also runs a stale connection check and
// returns driver.ErrBadConn if that check (or setting the read deadline for
// it) fails.
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
	if mc.closed.Load() {
		return driver.ErrBadConn
	}
	if !mc.cfg.CheckConnLiveness {
		return nil
	}

	// Perform a stale connection check. We only perform this check for
	// the first query on a connection that has been checked out of the
	// connection pool: a fresh connection from the pool is more likely
	// to be stale, and it has not performed any previous writes that
	// could cause data corruption, so it's safe to return ErrBadConn
	// if the check fails.
	target := mc.netConn
	if mc.rawConn != nil {
		target = mc.rawConn
	}

	var checkErr error
	if readTimeout := mc.cfg.ReadTimeout; readTimeout != 0 {
		checkErr = target.SetReadDeadline(time.Now().Add(readTimeout))
	}
	if checkErr == nil {
		checkErr = connCheck(target)
	}
	if checkErr == nil {
		return nil
	}

	mc.log("closing bad idle connection: ", checkErr)
	return driver.ErrBadConn
}
// IsValid implements driver.Validator interface
// (From Go 1.15)
//
// A connection is valid as long as it has not been marked closed.
func (mc *mysqlConn) IsValid() bool {
	if mc.closed.Load() {
		return false
	}
	return true
}
================================================
FILE: gossh/mysql/connector.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql/driver"
"net"
"os"
"strconv"
"strings"
)
// connector holds what is needed to open new connections: the configuration
// and the connection attributes already encoded by
// encodeConnectionAttributes (see newConnector).
type connector struct {
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}
// encodeConnectionAttributes builds the connection attribute payload as a
// sequence of length-encoded key/value strings.
//
// The default attributes (client name, OS, platform, pid and, when cfg.Addr
// splits into host and port, the server host) come first, followed by the
// user-defined "key:value" pairs from the comma-separated
// cfg.ConnectionAttributes. Entries without a ':' are skipped.
func encodeConnectionAttributes(cfg *Config) string {
	encoded := make([]byte, 0)
	put := func(key, val string) {
		encoded = appendLengthEncodedString(encoded, key)
		encoded = appendLengthEncodedString(encoded, val)
	}

	// default connection attributes
	put(connAttrClientName, connAttrClientNameValue)
	put(connAttrOS, connAttrOSValue)
	put(connAttrPlatform, connAttrPlatformValue)
	put(connAttrPid, strconv.Itoa(os.Getpid()))
	// A split error leaves host empty, in which case the attribute is omitted.
	if host, _, _ := net.SplitHostPort(cfg.Addr); host != "" {
		put(connAttrServerHost, host)
	}

	// user-defined connection attributes
	for _, pair := range strings.Split(cfg.ConnectionAttributes, ",") {
		if key, val, ok := strings.Cut(pair, ":"); ok {
			put(key, val)
		}
	}

	return string(encoded)
}
// newConnector wraps cfg in a connector and encodes its connection
// attributes up front. cfg is stored as given, without being copied.
func newConnector(cfg *Config) *connector {
	c := new(connector)
	c.cfg = cfg
	c.encodedAttributes = encodeConnectionAttributes(cfg)
	return c
}
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
//
// The steps are: run the optional beforeConnect hook on a private clone of
// the configuration, dial the server (through the dial function registered
// for cfg.Net if there is one, otherwise through net.Dialer), start the
// context watcher, perform handshake and authentication, determine the
// maximum packet size and finally apply the DSN parameters via handleParams.
//
// ctx bounds the dial and is watched until Connect returns. On failure the
// returned driver.Conn is nil.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
	var err error

	// Invoke beforeConnect if present, with a copy of the configuration
	cfg := c.cfg
	if c.cfg.beforeConnect != nil {
		cfg = c.cfg.Clone()
		err = c.cfg.beforeConnect(ctx, cfg)
		if err != nil {
			return nil, err
		}
	}

	// New mysqlConn
	mc := &mysqlConn{
		maxAllowedPacket: maxPacketSize,
		maxWriteSize:     maxPacketSize - 1,
		closech:          make(chan struct{}),
		cfg:              cfg,
		connector:        c,
	}
	mc.parseTime = mc.cfg.ParseTime

	// Connect to Server
	dialsLock.RLock()
	dial, ok := dials[mc.cfg.Net]
	dialsLock.RUnlock()
	if ok {
		dctx := ctx
		if mc.cfg.Timeout > 0 {
			// Take the duration from the effective configuration (mc.cfg), the
			// same one the check above uses. beforeConnect may have changed
			// Timeout on its copy; reading c.cfg here could yield a zero
			// duration and therefore an already expired dial context.
			var cancel context.CancelFunc
			dctx, cancel = context.WithTimeout(ctx, mc.cfg.Timeout)
			defer cancel()
		}
		mc.netConn, err = dial(dctx, mc.cfg.Addr)
	} else {
		nd := net.Dialer{Timeout: mc.cfg.Timeout}
		mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
	}
	if err != nil {
		return nil, err
	}
	mc.rawConn = mc.netConn

	// Enable TCP Keepalives on TCP connections
	if tc, ok := mc.netConn.(*net.TCPConn); ok {
		if err := tc.SetKeepAlive(true); err != nil {
			// Best effort: a keepalive failure is logged, not returned.
			c.cfg.Logger.Print(err)
		}
	}

	// Call startWatcher for context support (From Go 1.8)
	mc.startWatcher()
	if err := mc.watchCancel(ctx); err != nil {
		mc.cleanup()
		return nil, err
	}
	defer mc.finish()

	mc.buf = newBuffer(mc.netConn)

	// Set I/O timeouts
	mc.buf.timeout = mc.cfg.ReadTimeout
	mc.writeTimeout = mc.cfg.WriteTimeout

	// Reading Handshake Initialization Packet
	authData, plugin, err := mc.readHandshakePacket()
	if err != nil {
		mc.cleanup()
		return nil, err
	}

	if plugin == "" {
		plugin = defaultAuthPlugin
	}

	// Send Client Authentication Packet
	authResp, err := mc.auth(authData, plugin)
	if err != nil {
		// try the default auth plugin, if using the requested plugin failed
		c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
		plugin = defaultAuthPlugin
		authResp, err = mc.auth(authData, plugin)
		if err != nil {
			mc.cleanup()
			return nil, err
		}
	}
	if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
		mc.cleanup()
		return nil, err
	}

	// Handle response to auth packet, switch methods if possible
	if err = mc.handleAuthResult(authData, plugin); err != nil {
		// Authentication failed and MySQL has already closed the connection
		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
		// Do not send COM_QUIT, just cleanup and return the error.
		mc.cleanup()
		return nil, err
	}

	if mc.cfg.MaxAllowedPacket > 0 {
		mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
	} else {
		// Get max allowed packet size
		maxap, err := mc.getSystemVar("max_allowed_packet")
		if err != nil {
			mc.Close()
			return nil, err
		}
		mc.maxAllowedPacket = stringToInt(maxap) - 1
	}
	if mc.maxAllowedPacket < maxPacketSize {
		mc.maxWriteSize = mc.maxAllowedPacket
	}

	// Handle DSN Params
	err = mc.handleParams()
	if err != nil {
		mc.Close()
		return nil, err
	}

	return mc, nil
}
// Driver implements driver.Connector interface.
// Driver returns a pointer to a fresh MySQLDriver value.
func (c *connector) Driver() driver.Driver {
	var d MySQLDriver
	return &d
}
================================================
FILE: gossh/mysql/const.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import "runtime"
// Driver-wide defaults, protocol limits and the names/values of the default
// connection attributes.
const (
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1 // 16 MiB minus one byte
timeFormat = "2006-01-02 15:04:05.999999" // Go time layout with up to six fractional digits
// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "Go-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
connAttrServerHost = "_server_host"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
//
// Single-byte markers; presumably the leading byte that identifies the kind
// of a server packet — confirm against the protocol documentation above.
const (
iOK byte = 0x00
iAuthMoreData byte = 0x01
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
//
// clientFlag is a bit set of capability flags.
type clientFlag uint32
// Each constant is a single bit (1 << iota), starting with
// clientLongPassword at bit 0. The position in this list determines the bit,
// so entries must not be reordered or removed.
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
clientPSMultiResults
clientPluginAuth
clientConnectAttrs
clientPluginAuthLenEncClientData
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
)
// Command bytes. Values are consecutive and start at comQuit = 1
// (iota + 1); the position in this list determines the value, so entries
// must not be reordered or removed.
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegisterSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
//
// fieldType is the one-byte column type code.
type fieldType byte
// Low range of column type codes: consecutive values starting at
// fieldTypeDecimal = 0. The position in this list determines the value.
const (
fieldTypeDecimal fieldType = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
// High range of column type codes: consecutive values starting at
// fieldTypeJSON = 0xf5 (iota + 0xf5) and ending at fieldTypeGeometry = 0xff.
// The position in this list determines the value.
const (
fieldTypeJSON fieldType = iota + 0xf5
fieldTypeNewDecimal
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
// fieldFlag is a bit set of per-column flags.
type fieldFlag uint16
// Each constant is a single bit (1 << iota), starting with flagNotNULL at
// bit 0; the sixteen entries fill the whole uint16. The position in this
// list determines the bit.
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
// http://dev.mysql.com/doc/internals/en/status-flags.html
//
// statusFlag is a bit set of status flags.
type statusFlag uint16
// Each constant is a single bit (1 << iota), starting with statusInTrans at
// bit 0. The position in this list determines the bit; statusReserved keeps
// bit 2 occupied so the following flags line up.
const (
statusInTrans statusFlag = 1 << iota
statusInAutocommit
statusReserved // Not in documentation
statusMoreResultsExists
statusNoGoodIndexUsed
statusNoIndexUsed
statusCursorExists
statusLastRowSent
statusDbDropped
statusNoBackslashEscapes
statusMetadataChanged
statusQueryWasSlow
statusPsOutParams
statusInTransReadonly
statusSessionStateChanged
)
// Values exchanged during caching_sha2_password authentication, going by the
// constant names; the code that uses them is not part of this section.
const (
cachingSha2PasswordRequestPublicKey = 2
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)
================================================
FILE: gossh/mysql/driver.go
================================================
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Package mysql provides a MySQL driver for Go's database/sql package.
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"net"
"sync"
)
// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
// The type has no fields; init registers a value of it with database/sql
// under the name held by driverName.
type MySQLDriver struct{}
// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial, which wraps
// them in a DialContextFunc that ignores the context.
//
// Deprecated: users should register a DialContextFunc instead
type DialFunc func(addr string) (net.Conn, error)
// DialContextFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDialContext.
// The function registered for a network is called with the connection's
// context and the configured address.
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
// Registry of custom dial functions.
var (
// dialsLock guards dials.
dialsLock sync.RWMutex
// dials maps a network name to its registered dial function. It stays nil
// until the first call to RegisterDialContext.
dials map[string]DialContextFunc
)
// RegisterDialContext registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// The current context for the connection and its address is passed to the dial function.
//
// Registering the same network again replaces the previous function.
func RegisterDialContext(net string, dial DialContextFunc) {
	dialsLock.Lock()
	defer dialsLock.Unlock()

	if dials != nil {
		dials[net] = dial
		return
	}
	// First registration: create the registry with its first entry.
	dials = map[string]DialContextFunc{net: dial}
}
// DeregisterDialContext removes the custom dial function registered with the given net.
// Removing a network that was never registered does nothing.
func DeregisterDialContext(net string) {
	dialsLock.Lock()
	defer dialsLock.Unlock()

	// delete is a no-op on a nil map, so an empty registry needs no special case.
	delete(dials, net)
}
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
//
// Deprecated: users should call RegisterDialContext instead
func RegisterDial(network string, dial DialFunc) {
	// Adapt the context-less function; the context is dropped.
	adapter := func(_ context.Context, addr string) (net.Conn, error) {
		return dial(addr)
	}
	RegisterDialContext(network, adapter)
}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formatted
//
// The DSN is parsed, wrapped in a connector and connected with
// context.Background().
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	parsed, parseErr := ParseDSN(dsn)
	if parseErr != nil {
		return nil, parseErr
	}
	return newConnector(parsed).Connect(context.Background())
}
// This variable can be replaced with -ldflags like below:
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
var driverName = "mysql"

// init registers the driver with database/sql under driverName. Nothing is
// registered when driverName is empty.
func init() {
	if driverName == "" {
		return
	}
	sql.Register(driverName, &MySQLDriver{})
}
// NewConnector returns new driver.Connector.
//
// The connector works on a normalized clone of cfg; the caller's Config is
// not modified.
func NewConnector(cfg *Config) (driver.Connector, error) {
	private := cfg.Clone()
	// normalize the contents of cfg so calls to NewConnector have the same
	// behavior as MySQLDriver.OpenConnector
	if normErr := private.normalize(); normErr != nil {
		return nil, normErr
	}
	return newConnector(private), nil
}
// OpenConnector implements driver.DriverContext.
// It parses dsn and returns a connector for the resulting configuration.
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
	parsed, parseErr := ParseDSN(dsn)
	if parseErr != nil {
		return nil, parseErr
	}
	conn := newConnector(parsed)
	return conn, nil
}
================================================
FILE: gossh/mysql/dsn.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// Errors reported for malformed DSNs. The first three are returned by
// ParseDSN; errInvalidDSNUnsafeCollation is returned by Config.normalize.
var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
// non boolean fields
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters (DSN keys that parseDSNParams does not recognize)
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed; when not positive, Connect reads max_allowed_packet from the server
ServerPubKey string // Server public key name
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
// boolean fields
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
CheckConnLiveness bool // Check connections for liveness before using them
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
// unexported fields. new options should go here
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
}
// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
//
// Option modifies a Config and may fail. Options are applied with
// Config.Apply, which stops at the first error.
type Option func(*Config) error
// NewConfig creates a new Config and sets default values.
//
// The defaults are: UTC location, defaultMaxAllowedPacket, the default
// logger, native passwords allowed and connection liveness checks enabled.
// Every other field is left at its zero value.
func NewConfig() *Config {
	return &Config{
		Loc:                  time.UTC,
		MaxAllowedPacket:     defaultMaxAllowedPacket,
		Logger:               defaultLogger,
		AllowNativePasswords: true,
		CheckConnLiveness:    true,
	}
}
// Apply applies the given options to the Config object.
// Options run in the order given; the first error stops the loop and is
// returned, leaving the effects of earlier options in place.
func (c *Config) Apply(opts ...Option) error {
	for i := range opts {
		if applyErr := opts[i](c); applyErr != nil {
			return applyErr
		}
	}
	return nil
}
// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
// The returned Option never fails.
func TimeTruncate(d time.Duration) Option {
	apply := func(cfg *Config) error {
		cfg.timeTruncate = d
		return nil
	}
	return apply
}
// BeforeConnect sets the function to be invoked before a connection is established.
// The returned Option never fails.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
	apply := func(cfg *Config) error {
		cfg.beforeConnect = fn
		return nil
	}
	return apply
}
// Clone returns a copy of cfg that can be modified without affecting cfg.
//
// Scalar fields are copied by value. The TLS configuration, the Params map
// and the server public key are duplicated so that the copy does not share
// them with cfg. Loc and Logger are copied as references.
func (cfg *Config) Clone() *Config {
	cp := *cfg
	if cp.TLS != nil {
		cp.TLS = cfg.TLS.Clone()
	}
	// Copy every non-nil map, including an empty one: testing len() > 0 here
	// would leave an empty map shared, so a later write to the clone's Params
	// would show up in the original as well.
	if cfg.Params != nil {
		cp.Params = make(map[string]string, len(cfg.Params))
		for k, v := range cfg.Params {
			cp.Params[k] = v
		}
	}
	if cfg.pubKey != nil {
		cp.pubKey = &rsa.PublicKey{
			N: new(big.Int).Set(cfg.pubKey.N),
			E: cfg.pubKey.E,
		}
	}
	return &cp
}
// normalize validates cfg and fills in defaults, modifying cfg in place.
//
// It rejects interpolateParams combined with an unsafe collation, defaults
// Net to "tcp", fills in or completes Addr, resolves TLSConfig into TLS when
// TLS is unset, derives the TLS server name from Addr when verification is
// on and no name is set, looks up the server public key by name and installs
// the default logger if none is set.
func (cfg *Config) normalize() error {
	if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] {
		return errInvalidDSNUnsafeCollation
	}

	// Set default network if empty
	if cfg.Net == "" {
		cfg.Net = "tcp"
	}

	// Set default address if empty, otherwise make sure a tcp address has a port
	switch {
	case cfg.Addr != "":
		if cfg.Net == "tcp" {
			cfg.Addr = ensureHavePort(cfg.Addr)
		}
	case cfg.Net == "tcp":
		cfg.Addr = "127.0.0.1:3306"
	case cfg.Net == "unix":
		cfg.Addr = "/tmp/mysql.sock"
	default:
		return errors.New("default addr for network '" + cfg.Net + "' unknown")
	}

	// An explicit TLS config wins over the TLSConfig name.
	if cfg.TLS == nil {
		switch name := cfg.TLSConfig; name {
		case "false", "":
			// don't set anything
		case "true":
			cfg.TLS = &tls.Config{}
		case "skip-verify":
			cfg.TLS = &tls.Config{InsecureSkipVerify: true}
		case "preferred":
			cfg.TLS = &tls.Config{InsecureSkipVerify: true}
			cfg.AllowFallbackToPlaintext = true
		default:
			if cfg.TLS = getTLSConfigClone(name); cfg.TLS == nil {
				return errors.New("invalid value / unknown config name: " + name)
			}
		}
	}

	if tlsCfg := cfg.TLS; tlsCfg != nil && tlsCfg.ServerName == "" && !tlsCfg.InsecureSkipVerify {
		// An address that does not split into host and port leaves ServerName empty.
		if host, _, splitErr := net.SplitHostPort(cfg.Addr); splitErr == nil {
			tlsCfg.ServerName = host
		}
	}

	if keyName := cfg.ServerPubKey; keyName != "" {
		if cfg.pubKey = getServerPubKey(keyName); cfg.pubKey == nil {
			return errors.New("invalid value / unknown server pub key name: " + keyName)
		}
	}

	if cfg.Logger == nil {
		cfg.Logger = defaultLogger
	}

	return nil
}
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
buf.Grow(1 + len(name) + 1 + len(value))
if !*hasParam {
*hasParam = true
buf.WriteByte('?')
} else {
buf.WriteByte('&')
}
buf.WriteString(name)
buf.WriteByte('=')
buf.WriteString(value)
}
// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
//
// Only settings that differ from the NewConfig defaults are written. User,
// Passwd, Addr and Collation are written verbatim; DBName is path-escaped
// and the loc, serverPubKey, tls, connectionAttributes and Params values are
// query-escaped. Params are written last, sorted by key.
//
// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config].
func (cfg *Config) FormatDSN() string {
	var buf bytes.Buffer

	// [username[:password]@]
	if len(cfg.User) > 0 {
		buf.WriteString(cfg.User)
		if len(cfg.Passwd) > 0 {
			buf.WriteByte(':')
			buf.WriteString(cfg.Passwd)
		}
		buf.WriteByte('@')
	}

	// [protocol[(address)]]
	if len(cfg.Net) > 0 {
		buf.WriteString(cfg.Net)
		if len(cfg.Addr) > 0 {
			buf.WriteByte('(')
			buf.WriteString(cfg.Addr)
			buf.WriteByte(')')
		}
	}

	// /dbname
	buf.WriteByte('/')
	buf.WriteString(url.PathEscape(cfg.DBName))

	// [?param1=value1&...&paramN=valueN]
	hasParam := false

	if cfg.AllowAllFiles {
		writeDSNParam(&buf, &hasParam, "allowAllFiles", "true")
	}

	if cfg.AllowCleartextPasswords {
		writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
	}

	if cfg.AllowFallbackToPlaintext {
		writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
	}

	if !cfg.AllowNativePasswords {
		writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
	}

	if cfg.AllowOldPasswords {
		writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
	}

	if !cfg.CheckConnLiveness {
		writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
	}

	if cfg.ClientFoundRows {
		writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
	}

	if col := cfg.Collation; col != "" {
		writeDSNParam(&buf, &hasParam, "collation", col)
	}

	if cfg.ColumnsWithAlias {
		writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
	}

	// parseDSNParams reads this key, so it has to be written for a
	// ParseDSN -> FormatDSN round trip to keep the attributes.
	if cfg.ConnectionAttributes != "" {
		writeDSNParam(&buf, &hasParam, "connectionAttributes", url.QueryEscape(cfg.ConnectionAttributes))
	}

	if cfg.InterpolateParams {
		writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
	}

	if cfg.Loc != time.UTC && cfg.Loc != nil {
		writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
	}

	if cfg.MultiStatements {
		writeDSNParam(&buf, &hasParam, "multiStatements", "true")
	}

	if cfg.ParseTime {
		writeDSNParam(&buf, &hasParam, "parseTime", "true")
	}

	if cfg.timeTruncate > 0 {
		writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
	}

	if cfg.ReadTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
	}

	if cfg.RejectReadOnly {
		writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
	}

	if len(cfg.ServerPubKey) > 0 {
		writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
	}

	if cfg.Timeout > 0 {
		writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
	}

	if len(cfg.TLSConfig) > 0 {
		writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
	}

	if cfg.WriteTimeout > 0 {
		writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
	}

	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
		writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
	}

	// other params, in sorted key order so the output is deterministic
	if cfg.Params != nil {
		params := make([]string, 0, len(cfg.Params))
		for param := range cfg.Params {
			params = append(params, param)
		}
		sort.Strings(params)
		for _, param := range params {
			writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
		}
	}

	return buf.String()
}
// ParseDSN parses the DSN string to a Config.
//
// The expected form is
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN].
// Parsing starts from NewConfig(), so settings absent from the DSN keep
// their defaults, and ends with cfg.normalize(). An empty dsn is accepted; a
// non-empty dsn without any '/' yields errInvalidDSNNoSlash.
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
cfg = NewConfig()
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
// (j ends at -1 when there is no '@', so dsn[j+1:] below starts at 0)
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
}
}
cfg.User = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.Addr = dsn[k+1 : i-1]
break
}
}
cfg.Net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
// j is the index of the '?', or len(dsn) when there is none.
dbname := dsn[i+1 : j]
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
}
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
if err = cfg.normalize(); err != nil {
return nil, err
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
//
// params is split on '&' and each piece on its first '='; pieces without
// '=' are ignored. Known keys are stored in the matching cfg field, unknown
// keys go into cfg.Params with their value query-unescaped. Parsing stops at
// the first invalid value and returns its error. The removed "strict" key
// panics.
func parseDSNParams(cfg *Config, params string) (err error) {
	// setBool stores the parsed flag in dst (also when parsing fails) and
	// reports values that are not booleans.
	setBool := func(dst *bool, raw string) error {
		var ok bool
		if *dst, ok = readBool(raw); !ok {
			return errors.New("invalid bool value: " + raw)
		}
		return nil
	}

	for _, pair := range strings.Split(params, "&") {
		key, value, hasValue := strings.Cut(pair, "=")
		if !hasValue {
			continue
		}

		// Cases either return directly or leave their error in err, which is
		// checked once after the switch.
		switch key {
		// Disable INFILE allowlist / enable all files
		case "allowAllFiles":
			err = setBool(&cfg.AllowAllFiles, value)

		// Use cleartext authentication mode (MySQL 5.5.10+)
		case "allowCleartextPasswords":
			err = setBool(&cfg.AllowCleartextPasswords, value)

		// Allow fallback to unencrypted connection if server does not support TLS
		case "allowFallbackToPlaintext":
			err = setBool(&cfg.AllowFallbackToPlaintext, value)

		// Use native password authentication
		case "allowNativePasswords":
			err = setBool(&cfg.AllowNativePasswords, value)

		// Use old authentication mode (pre MySQL 4.1)
		case "allowOldPasswords":
			err = setBool(&cfg.AllowOldPasswords, value)

		// Check connections for Liveness before using them
		case "checkConnLiveness":
			err = setBool(&cfg.CheckConnLiveness, value)

		// Switch "rowsAffected" mode
		case "clientFoundRows":
			err = setBool(&cfg.ClientFoundRows, value)

		// Collation
		case "collation":
			cfg.Collation = value

		case "columnsWithAlias":
			err = setBool(&cfg.ColumnsWithAlias, value)

		// Compression
		case "compress":
			return errors.New("compression not implemented yet")

		// Enable client side placeholder substitution
		case "interpolateParams":
			err = setBool(&cfg.InterpolateParams, value)

		// Time Location
		case "loc":
			if value, err = url.QueryUnescape(value); err == nil {
				cfg.Loc, err = time.LoadLocation(value)
			}

		// multiple statements in one query
		case "multiStatements":
			err = setBool(&cfg.MultiStatements, value)

		// time.Time parsing
		case "parseTime":
			err = setBool(&cfg.ParseTime, value)

		// time.Time truncation
		case "timeTruncate":
			if cfg.timeTruncate, err = time.ParseDuration(value); err != nil {
				return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
			}

		// I/O read Timeout
		case "readTimeout":
			cfg.ReadTimeout, err = time.ParseDuration(value)

		// Reject read-only connections
		case "rejectReadOnly":
			err = setBool(&cfg.RejectReadOnly, value)

		// Server public key
		case "serverPubKey":
			keyName, unescErr := url.QueryUnescape(value)
			if unescErr != nil {
				return fmt.Errorf("invalid value for server pub key name: %v", unescErr)
			}
			cfg.ServerPubKey = keyName

		// Strict mode
		case "strict":
			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

		// Dial Timeout
		case "timeout":
			cfg.Timeout, err = time.ParseDuration(value)

		// TLS-Encryption
		case "tls":
			if enabled, isBool := readBool(value); isBool {
				// "true" or "false"
				cfg.TLSConfig = strconv.FormatBool(enabled)
			} else if lower := strings.ToLower(value); lower == "skip-verify" || lower == "preferred" {
				cfg.TLSConfig = lower
			} else {
				// Anything else is the name of a registered TLS config.
				cfgName, unescErr := url.QueryUnescape(value)
				if unescErr != nil {
					return fmt.Errorf("invalid value for TLS config name: %v", unescErr)
				}
				cfg.TLSConfig = cfgName
			}

		// I/O write Timeout
		case "writeTimeout":
			cfg.WriteTimeout, err = time.ParseDuration(value)

		case "maxAllowedPacket":
			cfg.MaxAllowedPacket, err = strconv.Atoi(value)

		// Connection attributes
		case "connectionAttributes":
			attrs, unescErr := url.QueryUnescape(value)
			if unescErr != nil {
				return fmt.Errorf("invalid connectionAttributes value: %v", unescErr)
			}
			cfg.ConnectionAttributes = attrs

		default:
			// lazy init
			if cfg.Params == nil {
				cfg.Params = make(map[string]string)
			}
			cfg.Params[key], err = url.QueryUnescape(value)
		}

		if err != nil {
			return err
		}
	}

	return nil
}
// ensureHavePort returns addr unchanged when net.SplitHostPort accepts it as a
// host:port pair. Otherwise the default MySQL port 3306 is joined onto addr.
func ensureHavePort(addr string) string {
	_, _, splitErr := net.SplitHostPort(addr)
	if splitErr == nil {
		return addr
	}
	return net.JoinHostPort(addr, "3306")
}
================================================
FILE: gossh/mysql/errors.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"errors"
"fmt"
"log"
"os"
)
// Various errors the driver might return. Can change between driver versions.
// They are sentinel values created with errors.New.
var (
// ErrInvalidConn is returned by readPacket and writePacket once the
// connection can no longer be used (I/O failure or a malformed packet).
ErrInvalidConn = errors.New("invalid connection")
// ErrMalformPkt reports a packet that does not have the expected structure.
ErrMalformPkt = errors.New("malformed packet")
// ErrNoTLS is returned by readHandshakePacket when TLS is configured, the
// server does not advertise clientSSL and plaintext fallback is not allowed.
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
// ErrOldProtocol is returned by readHandshakePacket when the server does not
// advertise the clientProtocol41 capability.
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
// ErrPktSync and ErrPktSyncMul are returned by readPacket when the received
// sequence number differs from the expected one; ErrPktSyncMul is used when
// the received number is larger than expected.
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
// ErrPktTooLarge is returned by writePacket when the payload exceeds the
// connection's maxAllowedPacket.
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
ErrBusyBuffer = errors.New("busy buffer")
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)
// defaultLogger receives the driver's critical error messages. Until it is
// replaced through SetLogger it writes to os.Stderr with a "[mysql] " prefix.
var defaultLogger Logger = log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)

// Logger is the sink for critical error messages.
type Logger interface {
	Print(v ...any)
}

// NopLogger is a Logger implementation that discards every message.
type NopLogger struct{}

// Print implements Logger and intentionally does nothing.
func (nl *NopLogger) Print(_ ...any) {}

// SetLogger replaces the default logger used for critical errors.
// A nil logger is rejected with an error and the current logger is kept.
func SetLogger(logger Logger) error {
	if logger != nil {
		defaultLogger = logger
		return nil
	}
	return errors.New("logger is nil")
}
// MySQLError describes a single error reported by the MySQL server.
type MySQLError struct {
	Number   uint16
	SQLState [5]byte
	Message  string
}

// Error formats the error number and message. The SQL state is included in
// parentheses only when it is not all zero bytes.
func (me *MySQLError) Error() string {
	var unset [5]byte
	if me.SQLState == unset {
		return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
	}
	return fmt.Sprintf("Error %d (%s): %s", me.Number, me.SQLState, me.Message)
}

// Is reports whether err is a *MySQLError carrying the same error number.
// SQL state and message are not compared.
func (me *MySQLError) Is(err error) bool {
	other, ok := err.(*MySQLError)
	return ok && other.Number == me.Number
}
================================================
FILE: gossh/mysql/fields.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql"
"reflect"
)
// typeDatabaseName maps the field's type, flags and character set to the
// MySQL type name. Unknown field types yield the empty string.
func (mf *mysqlField) typeDatabaseName() string {
	isBinary := mf.charSet == binaryCollationID

	// integerName prefixes the base name when the unsigned flag is set.
	integerName := func(base string) string {
		if mf.flags&flagUnsigned != 0 {
			return "UNSIGNED " + base
		}
		return base
	}
	// byCharset picks between the binary and the textual variant of a type.
	byCharset := func(binaryName, textName string) string {
		if isBinary {
			return binaryName
		}
		return textName
	}

	switch mf.fieldType {
	// Integer types
	case fieldTypeTiny:
		return integerName("TINYINT")
	case fieldTypeShort:
		return integerName("SMALLINT")
	case fieldTypeInt24:
		return integerName("MEDIUMINT")
	case fieldTypeLong:
		return integerName("INT")
	case fieldTypeLongLong:
		return integerName("BIGINT")

	// Types whose name depends on the character set
	case fieldTypeTinyBLOB:
		return byCharset("TINYBLOB", "TINYTEXT")
	case fieldTypeBLOB:
		return byCharset("BLOB", "TEXT")
	case fieldTypeMediumBLOB:
		return byCharset("MEDIUMBLOB", "MEDIUMTEXT")
	case fieldTypeLongBLOB:
		return byCharset("LONGBLOB", "LONGTEXT")
	case fieldTypeVarChar, fieldTypeVarString:
		return byCharset("VARBINARY", "VARCHAR")
	case fieldTypeString:
		// The enum and set flags take precedence over the character set.
		switch {
		case mf.flags&flagEnum != 0:
			return "ENUM"
		case mf.flags&flagSet != 0:
			return "SET"
		}
		return byCharset("BINARY", "CHAR")

	// Types with a fixed name
	case fieldTypeBit:
		return "BIT"
	case fieldTypeDate, fieldTypeNewDate:
		return "DATE"
	case fieldTypeDateTime:
		return "DATETIME"
	case fieldTypeDecimal, fieldTypeNewDecimal:
		return "DECIMAL"
	case fieldTypeDouble:
		return "DOUBLE"
	case fieldTypeEnum:
		return "ENUM"
	case fieldTypeFloat:
		return "FLOAT"
	case fieldTypeGeometry:
		return "GEOMETRY"
	case fieldTypeJSON:
		return "JSON"
	case fieldTypeNULL:
		return "NULL"
	case fieldTypeSet:
		return "SET"
	case fieldTypeTime:
		return "TIME"
	case fieldTypeTimestamp:
		return "TIMESTAMP"
	case fieldTypeYear:
		return "YEAR"
	}
	return ""
}
// reflect.Type values returned by (*mysqlField).scanType. They are computed
// once at package initialisation.
var (
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
// new(any) has type *any, so this is the pointer-to-interface type.
scanTypeUnknown = reflect.TypeOf(new(any))
)
// mysqlField holds the metadata of one result set column as filled in by
// readColumns from a column definition packet.
type mysqlField struct {
// tableName is only populated by readColumns when cfg.ColumnsWithAlias is set.
tableName string
name string
length uint32
flags fieldFlag
fieldType fieldType
decimals byte
// charSet is compared against binaryCollationID to tell binary from text data.
charSet uint8
}
// scanType returns the Go type best suited to scan a value of this field.
// Nullable numeric and textual columns map to the sql.Null* wrappers, binary
// data maps to []byte and unknown field types map to scanTypeUnknown.
func (mf *mysqlField) scanType() reflect.Type {
	notNull := mf.flags&flagNotNULL != 0
	unsigned := mf.flags&flagUnsigned != 0

	// integerType selects the scan type of an integer column.
	integerType := func(signedType, unsignedType reflect.Type) reflect.Type {
		switch {
		case !notNull:
			return scanTypeNullInt
		case unsigned:
			return unsignedType
		default:
			return signedType
		}
	}
	// floatType selects the scan type of a floating point column.
	floatType := func(exact reflect.Type) reflect.Type {
		if notNull {
			return exact
		}
		return scanTypeNullFloat
	}
	// textType selects the scan type of a column that is read as a string.
	textType := func() reflect.Type {
		if notNull {
			return scanTypeString
		}
		return scanTypeNullString
	}

	switch mf.fieldType {
	case fieldTypeTiny:
		return integerType(scanTypeInt8, scanTypeUint8)
	case fieldTypeShort, fieldTypeYear:
		return integerType(scanTypeInt16, scanTypeUint16)
	case fieldTypeInt24, fieldTypeLong:
		return integerType(scanTypeInt32, scanTypeUint32)
	case fieldTypeLongLong:
		return integerType(scanTypeInt64, scanTypeUint64)

	case fieldTypeFloat:
		return floatType(scanTypeFloat32)
	case fieldTypeDouble:
		return floatType(scanTypeFloat64)

	case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
		fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
		// Binary collation means raw bytes; anything else is handled as text.
		if mf.charSet == binaryCollationID {
			return scanTypeBytes
		}
		return textType()

	case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
		fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
		return textType()

	case fieldTypeDate, fieldTypeNewDate,
		fieldTypeTimestamp, fieldTypeDateTime:
		// NullTime is always returned for more consistent behavior as it can
		// handle both cases of parseTime regardless if the field is nullable.
		return scanTypeNullTime
	}
	return scanTypeUnknown
}
================================================
FILE: gossh/mysql/infile.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"io"
"os"
"strings"
"sync"
)
// Registries consulted by handleInFileRequest. Both maps are created lazily
// by their Register* function and are guarded by the lock declared next to them.
var (
// fileRegister is the allowlist of local file paths.
fileRegister map[string]bool
fileRegisterLock sync.RWMutex
// readerRegister maps a handler name to a function producing an io.Reader.
readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex
)
// RegisterLocalFile puts filePath on the allowlist of files that may be sent
// to the server by "LOAD DATA LOCAL INFILE <filepath>". Surrounding double
// quotes are stripped from filePath before it is stored.
// Setting the DSN parameter 'allowAllFiles=true' allows every local file
// instead.
//
//	filePath := "/home/gopher/data.csv"
//	mysql.RegisterLocalFile(filePath)
//	_, err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
//	if err != nil {
//	...
func RegisterLocalFile(filePath string) {
	key := strings.Trim(filePath, `"`)

	fileRegisterLock.Lock()
	defer fileRegisterLock.Unlock()

	// The allowlist map is created on first use.
	if fileRegister == nil {
		fileRegister = make(map[string]bool)
	}
	fileRegister[key] = true
}
// DeregisterLocalFile takes filePath off the allowlist again. Surrounding
// double quotes are stripped from filePath before the lookup.
func DeregisterLocalFile(filePath string) {
	key := strings.Trim(filePath, `"`)

	fileRegisterLock.Lock()
	defer fileRegisterLock.Unlock()

	delete(fileRegister, key)
}
// RegisterReaderHandler stores handler under name. The handler is invoked to
// obtain an io.Reader whenever the server requests the file
// "Reader::<name>" for a "LOAD DATA LOCAL INFILE" statement.
// If the returned Reader also implements io.Closer, Close() is called once
// the request has been processed.
//
//	mysql.RegisterReaderHandler("data", func() io.Reader {
//		var csvReader io.Reader // Some Reader that returns CSV data
//		... // Open Reader here
//		return csvReader
//	})
//	_, err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
//	if err != nil {
//	...
func RegisterReaderHandler(name string, handler func() io.Reader) {
	readerRegisterLock.Lock()
	defer readerRegisterLock.Unlock()

	// The registry map is created on first use.
	if readerRegister == nil {
		readerRegister = make(map[string]func() io.Reader)
	}
	readerRegister[name] = handler
}
// DeregisterReaderHandler deletes the handler function registered under name.
// Removing a name that was never registered is a no-op.
func DeregisterReaderHandler(name string) {
	readerRegisterLock.Lock()
	defer readerRegisterLock.Unlock()

	delete(readerRegister, name)
}
// deferredClose closes closer and, when *err does not hold an error yet,
// stores the result of Close in *err. An error already present in *err is
// never overwritten.
func deferredClose(err *error, closer io.Closer) {
	if closeErr := closer.Close(); *err == nil {
		*err = closeErr
	}
}
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP

// handleInFileRequest answers a LOCAL INFILE request from the server.
//
// name is the file name sent by the server. If it starts with "Reader::" (or
// contains "/Reader::"), the remainder selects a handler registered through
// RegisterReaderHandler. Otherwise it is treated as a local file path, which
// must have been registered with RegisterLocalFile unless cfg.AllowAllFiles
// is set.
//
// The content is sent in packets of at most defaultPacketSize bytes (less if
// mc.maxWriteSize or the file size is smaller), followed by an empty packet
// that terminates the transfer. The terminating packet is sent even when the
// source could not be opened, after which the server's reply is read and
// discarded and the original error is returned. Write errors are returned
// immediately.
func (mc *okHandler) handleInFileRequest(name string) (err error) {
	var rdr io.Reader
	packetSize := defaultPacketSize
	if mc.maxWriteSize < packetSize {
		packetSize = mc.maxWriteSize
	}

	if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
		// The server might return an absolute path. See issue #355.
		name = name[idx+8:]

		readerRegisterLock.RLock()
		handler, inMap := readerRegister[name]
		readerRegisterLock.RUnlock()

		if inMap {
			rdr = handler()
			if rdr != nil {
				if cl, ok := rdr.(io.Closer); ok {
					defer deferredClose(&err, cl)
				}
			} else {
				err = fmt.Errorf("reader '%s' is <nil>", name)
			}
		} else {
			err = fmt.Errorf("reader '%s' is not registered", name)
		}
	} else { // File
		name = strings.Trim(name, `"`)

		fileRegisterLock.RLock()
		fr := fileRegister[name]
		fileRegisterLock.RUnlock()

		if mc.cfg.AllowAllFiles || fr {
			var file *os.File
			var fi os.FileInfo

			if file, err = os.Open(name); err == nil {
				defer deferredClose(&err, file)

				// get file size
				if fi, err = file.Stat(); err == nil {
					rdr = file
					if fileSize := int(fi.Size()); fileSize < packetSize {
						packetSize = fileSize
					}
				}
			}
		} else {
			err = fmt.Errorf("local file '%s' is not registered", name)
		}
	}

	// send content packets
	// The buffer is declared here and assigned (not redeclared) below so the
	// terminating packet can reuse it instead of allocating a second one.
	var data []byte

	// if packetSize == 0, the Reader contains no data
	if err == nil && packetSize > 0 {
		// The first 4 bytes are reserved for the packet header.
		data = make([]byte, 4+packetSize)
		var n int
		for err == nil {
			n, err = rdr.Read(data[4:])
			if n > 0 {
				if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
					return ioErr
				}
			}
		}
		if err == io.EOF {
			err = nil
		}
	}

	// send empty packet (termination)
	if data == nil {
		data = make([]byte, 4)
	}
	if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
		return ioErr
	}

	// read OK packet
	if err == nil {
		return mc.readResultOK()
	}

	// The transfer failed on our side: consume the server's reply, ignore it
	// and report the original error.
	mc.conn().readPacket()
	return err
}
================================================
FILE: gossh/mysql/nulltime.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql"
"database/sql/driver"
"fmt"
"time"
)
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
//	var nt NullTime
//	err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
//	...
//	if nt.Valid {
//		// use nt.Time
//	} else {
//		// NULL value
//	}
//
// # This NullTime implementation is not driver-specific
//
// Deprecated: NullTime doesn't honor the loc DSN parameter.
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
// Use sql.NullTime instead.
type NullTime sql.NullTime
// Scan implements the Scanner interface.
// Accepted values are nil (stored as an invalid, zero time), time.Time, and
// string / []byte holding a formatted time, which is parsed as UTC.
// Any other type makes Scan fail and marks the value as invalid.
func (nt *NullTime) Scan(value any) (err error) {
	switch v := value.(type) {
	case nil:
		nt.Time, nt.Valid = time.Time{}, false
	case time.Time:
		nt.Time, nt.Valid = v, true
	case []byte:
		nt.Time, err = parseDateTime(v, time.UTC)
		nt.Valid = err == nil
	case string:
		nt.Time, err = parseDateTime([]byte(v), time.UTC)
		nt.Valid = err == nil
	default:
		nt.Valid = false
		err = fmt.Errorf("can't convert %T to time.Time", value)
	}
	return err
}
// Value implements the driver Valuer interface.
// A valid NullTime yields its time.Time; an invalid one yields nil (NULL).
func (nt NullTime) Value() (driver.Value, error) {
	if nt.Valid {
		return nt.Time, nil
	}
	return nil, nil
}
================================================
FILE: gossh/mysql/packets.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"strconv"
"time"
)
// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
// readPacket reads one logical packet from mc.buf and returns its payload.
//
// Every physical packet starts with a 4 byte header: a 24 bit little-endian
// payload length followed by an 8 bit sequence number. The sequence number
// must equal mc.sequence, which is incremented after each header. A payload of
// exactly maxPacketSize bytes is continued in the next physical packet; the
// parts are concatenated and returned as one slice.
//
// On a read error the connection is closed and ErrInvalidConn is returned,
// unless mc.canceled holds an error, in which case that error is returned.
// A sequence mismatch closes the connection and yields ErrPktSync or
// ErrPktSyncMul.
func (mc *mysqlConn) readPacket() ([]byte, error) {
// prevData accumulates the payload of a packet that spans several physical packets.
var prevData []byte
for {
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}
// packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
// check packet sync [8 bit]
if data[3] != mc.sequence {
mc.Close()
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
}
return nil, ErrPktSync
}
mc.sequence++
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
mc.log(ErrMalformPkt)
mc.Close()
return nil, ErrInvalidConn
}
return prevData, nil
}
// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}
// return data if this was the last packet
if pktLen < maxPacketSize {
// zero allocations for non-split packets
if prevData == nil {
return data, nil
}
return append(prevData, data...), nil
}
// full-size packet: more data follows, keep what we have so far
prevData = append(prevData, data...)
}
}
// writePacket sends data to the server. The first 4 bytes of data are
// reserved for the packet header and are overwritten here; the payload starts
// at data[4].
//
// A payload larger than mc.maxAllowedPacket is rejected with ErrPktTooLarge
// before anything is written. Payloads of maxPacketSize bytes or more are
// split into several physical packets, each with its own header and sequence
// number (mc.sequence is incremented per packet written).
//
// On a write error, mc.canceled's error is returned if set. Otherwise
// errBadConnNoWrite is returned when nothing was written on the first
// iteration, and ErrInvalidConn (after mc.cleanup) in all other cases,
// including a short write.
func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4
if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge
}
for {
var size int
if pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
data[3] = mc.sequence
// Write packet
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
return err
}
}
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
// Advance by the payload just sent; the next header is written over
// the last 4 bytes of that (already transmitted) chunk.
data = data[size:]
continue
}
// Handle error
if err == nil { // n != len(data)
mc.cleanup()
mc.log(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
return errBadConnNoWrite
}
mc.cleanup()
mc.log(err)
}
return ErrInvalidConn
}
}
/******************************************************************************
* Initialization Process *
******************************************************************************/
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
//
// readHandshakePacket reads the server's initial handshake and returns a copy
// of the auth data (8 bytes, or 20 bytes when the packet carries the second
// part) together with the auth plugin name, which is empty for the short
// form. The lower 16 capability bits are stored in mc.flags.
//
// ErrInvalidConn from readPacket is rewritten to driver.ErrBadConn. The
// function fails with ErrOldProtocol when the server lacks clientProtocol41
// and with ErrNoTLS when TLS is configured but not offered by the server,
// unless cfg.AllowFallbackToPlaintext is set, in which case mc.cfg.TLS is
// cleared instead.
//
// NOTE(review): the offsets below are not bounds-checked against len(data);
// a truncated handshake packet would presumably panic — confirm.
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return
}
if data[0] == iERR {
return nil, "", mc.handleErrorPacket(data)
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
)
}
// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
// first part of the password cipher [8 bytes]
authData := data[pos : pos+8]
// (filler) always 0x00 [1 byte]
pos += 8 + 1
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.TLS = nil
} else {
return nil, "", ErrNoTLS
}
}
pos += 2
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
// capability flags (upper 2 bytes) [2 bytes]
// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
//
// The web documentation is ambiguous about the length. However,
// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
// the 13th byte is "\0 byte, terminating the second part of
// a scramble". So the second part of the password cipher is
// a NULL terminated string that's at least 13 bytes with the
// last byte being NULL.
//
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
authData = append(authData, data[pos:pos+12]...)
pos += 13
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// \NUL otherwise
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
plugin = string(data[pos : pos+end])
} else {
plugin = string(data[pos:])
}
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], authData)
return b[:], plugin, nil
}
// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], authData)
return b[:], plugin, nil
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
//
// writeHandshakeResponsePacket builds and sends the handshake response that
// carries the client capability flags, collation, user name, authResp, the
// optional database name, the auth plugin name and the encoded connection
// attributes.
//
// When mc.cfg.TLS is set, the first 32 payload bytes are sent on their own as
// an SSL request, the connection is upgraded with a TLS handshake, and the
// complete response is then sent over the TLS connection.
//
// It returns errBadConnNoWrite if no buffer can be taken, an error for an
// unknown collation name, or any error from the TLS handshake or writePacket.
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
mc.flags&clientLongFlag
if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows
}
// To enable TLS / SSL
if mc.cfg.TLS != nil {
clientFlags |= clientSSL
}
if mc.cfg.MultiStatements {
clientFlags |= clientMultiStatements
}
// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
clientFlags |= clientPluginAuthLenEncClientData
}
// Size estimate: flags (4) + max packet size (4) + collation (1) + filler (23)
// + user + NUL + auth data + 21 + 1 reserved for the plugin name and its NUL.
// The packet actually sent is data[:pos], so this only has to be large enough.
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
// encode length of the connection attributes
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
}
// ClientFlags [32 bit]
data[4] = byte(clientFlags)
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
// MaxPacketSize [32 bit] (none)
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
// Collation ID [1 byte]
cname := mc.cfg.Collation
if cname == "" {
cname = defaultCollation
}
var found bool
data[12], found = collations[cname]
if !found {
// Note possibility for false negatives:
// could be triggered although the collation is valid if the
// collations map does not contain entries the server supports.
return fmt.Errorf("unknown collation: %q", cname)
}
// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
}
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.TLS != nil {
// Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
// Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil {
return err
}
// From here on both writes and buffered reads go through the TLS connection.
mc.netConn = tlsConn
mc.buf.nc = tlsConn
}
// User [null terminated string]
if len(mc.cfg.User) > 0 {
pos += copy(data[pos:], mc.cfg.User)
}
data[pos] = 0x00
pos++
// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], authResp)
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
pos += copy(data[pos:], mc.cfg.DBName)
data[pos] = 0x00
pos++
}
// Auth plugin name [null terminated string]
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
// Connection Attributes
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
// Send Auth packet
return mc.writePacket(data[:pos])
}
// writeAuthSwitchPacket sends authData as the payload of a single packet.
// It returns errBadConnNoWrite when no buffer can be taken.
//
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
	const headerLen = 4 // reserved for the packet header written by writePacket

	packet, err := mc.buf.takeSmallBuffer(headerLen + len(authData))
	if err != nil {
		// cannot take the buffer. Something must be wrong with the connection
		mc.log(err)
		return errBadConnNoWrite
	}

	// Add the auth data [EOF]
	copy(packet[headerLen:], authData)

	return mc.writePacket(packet)
}
/******************************************************************************
* Command Packets *
******************************************************************************/
// writeCommandPacket starts a new command: it resets the packet sequence and
// sends a packet whose payload is the single command byte.
// It returns errBadConnNoWrite when no buffer can be taken.
func (mc *mysqlConn) writeCommandPacket(command byte) error {
	const headerLen = 4 // reserved for the packet header written by writePacket

	// Every command begins a fresh packet sequence.
	mc.sequence = 0

	packet, err := mc.buf.takeSmallBuffer(headerLen + 1)
	if err != nil {
		// cannot take the buffer. Something must be wrong with the connection
		mc.log(err)
		return errBadConnNoWrite
	}

	packet[headerLen] = command

	return mc.writePacket(packet)
}
// writeCommandPacketStr starts a new command: it resets the packet sequence
// and sends the command byte followed by arg.
// It returns errBadConnNoWrite when no buffer can be taken.
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
	const headerLen = 4 // reserved for the packet header written by writePacket

	// Every command begins a fresh packet sequence.
	mc.sequence = 0

	payloadLen := 1 + len(arg)
	packet, err := mc.buf.takeBuffer(payloadLen + headerLen)
	if err != nil {
		// cannot take the buffer. Something must be wrong with the connection
		mc.log(err)
		return errBadConnNoWrite
	}

	// Payload: command byte, then the argument.
	packet[headerLen] = command
	copy(packet[headerLen+1:], arg)

	return mc.writePacket(packet)
}
// writeCommandPacketUint32 starts a new command: it resets the packet
// sequence and sends the command byte followed by arg as a 32 bit
// little-endian integer.
// It returns errBadConnNoWrite when no buffer can be taken.
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
	const headerLen = 4 // reserved for the packet header written by writePacket

	// Every command begins a fresh packet sequence.
	mc.sequence = 0

	packet, err := mc.buf.takeSmallBuffer(headerLen + 1 + 4)
	if err != nil {
		// cannot take the buffer. Something must be wrong with the connection
		mc.log(err)
		return errBadConnNoWrite
	}

	packet[headerLen] = command

	// Argument [32 bit], least significant byte first.
	for i := 0; i < 4; i++ {
		packet[headerLen+1+i] = byte(arg >> uint(8*i))
	}

	return mc.writePacket(packet)
}
/******************************************************************************
* Result Packets *
******************************************************************************/
// readAuthResult reads the server's reply during authentication and returns
// (auth data, plugin name, error):
//
//   - OK packet: the packet is processed by handleOkPacket; no data, no plugin.
//   - more-data packet: the payload after the indicator byte is returned.
//   - EOF packet: a plugin switch request. A one byte packet selects
//     "mysql_old_password"; otherwise the NUL-terminated plugin name and the
//     auth data following it are returned.
//   - anything else is handled as an error packet.
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
	data, err := mc.readPacket()
	if err != nil {
		return nil, "", err
	}

	// packet indicator
	indicator := data[0]

	if indicator == iOK {
		// resultUnchanged, since auth happens before any queries or
		// commands have been executed.
		return nil, "", mc.resultUnchanged().handleOkPacket(data)
	}
	if indicator == iAuthMoreData {
		return data[1:], "", nil
	}
	if indicator != iEOF {
		// Error otherwise
		return nil, "", mc.handleErrorPacket(data)
	}

	if len(data) == 1 {
		// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
		return nil, "mysql_old_password", nil
	}

	nameEnd := bytes.IndexByte(data, 0x00)
	if nameEnd < 0 {
		return nil, "", ErrMalformPkt
	}
	pluginName := string(data[1:nameEnd])
	return data[nameEnd+1:], pluginName, nil
}
// readResultOK reads the next packet and expects it to be an OK packet, which
// is then processed by handleOkPacket. Any other packet is passed to
// handleErrorPacket and its result is returned as the error.
func (mc *okHandler) readResultOK() error {
	data, err := mc.conn().readPacket()
	switch {
	case err != nil:
		return err
	case data[0] == iOK:
		return mc.handleOkPacket(data)
	default:
		return mc.conn().handleErrorPacket(data)
	}
}
// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
//
// readResultSetHeaderPacket reads the first packet of a query response and
// returns the number of columns in the result set. OK, ERR and LOCAL INFILE
// packets are dispatched to their handlers and reported with a column count
// of 0.
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
	// Reserve a slot for this statement's result up front. handleOkPacket
	// overwrites both values; every other path leaves the zeroes in place.
	mc.result.affectedRows = append(mc.result.affectedRows, 0)
	mc.result.insertIds = append(mc.result.insertIds, 0)

	data, err := mc.conn().readPacket()
	if err != nil {
		return 0, err
	}

	switch data[0] {
	case iOK:
		return 0, mc.handleOkPacket(data)
	case iERR:
		return 0, mc.conn().handleErrorPacket(data)
	case iLocalInFile:
		return 0, mc.handleInFileRequest(string(data[1:]))
	}

	// column count
	// ignore remaining data in the packet. see #1478.
	columnCount, _, _ := readLengthEncodedInteger(data)
	return int(columnCount), nil
}
// Error Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
//
// handleErrorPacket converts an ERR packet into an error value. A packet that
// does not start with the ERR indicator yields ErrMalformPkt. For error
// numbers 1792 and 1290 with cfg.RejectReadOnly set, the connection is closed
// and driver.ErrBadConn is returned. In all other cases the result is a
// *MySQLError carrying the error number, the optional SQL state and the
// message.
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
	if data[0] != iERR {
		return ErrMalformPkt
	}

	// 0xff [1 byte]
	// Error Number [16 bit uint]
	errno := binary.LittleEndian.Uint16(data[1:3])

	switch errno {
	// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
	// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
	case 1792, 1290:
		if mc.cfg.RejectReadOnly {
			// Oops; we are connected to a read-only connection, and won't be able
			// to issue any write statements. Since RejectReadOnly is configured,
			// we throw away this connection hoping this one would have write
			// permission. This is specifically for a possible race condition
			// during failover (e.g. on AWS Aurora). See README.md for more.
			//
			// We explicitly close the connection before returning
			// driver.ErrBadConn to ensure that `database/sql` purges this
			// connection and initiates a new one for next statement next time.
			mc.Close()
			return driver.ErrBadConn
		}
	}

	myErr := &MySQLError{Number: errno}

	// SQL State [optional: # + 5bytes string]
	msgStart := 3
	if data[3] == 0x23 {
		copy(myErr.SQLState[:], data[4:4+5])
		msgStart = 9
	}

	// Error Message [string]
	myErr.Message = string(data[msgStart:])

	return myErr
}
// readStatus decodes the first two bytes of b as a little-endian 16 bit
// status flag set.
func readStatus(b []byte) statusFlag {
	low, high := statusFlag(b[0]), statusFlag(b[1])
	return low | high<<8
}
// resultUnchanged returns the connection as an *okHandler without touching
// mysqlConn.result. It is meant for code paths where the stored result does
// not have to be cleared first (e.g. during authentication, or while
// additional resultsets are being fetched).
func (mc *mysqlConn) resultUnchanged() *okHandler {
	handler := (*okHandler)(mc)
	return handler
}
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
//  1. first clear it using clearResult(), or
//  2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
//
// okHandler is a distinct named type over mysqlConn: it shares the same data
// but has its own method set; conn() converts back to *mysqlConn.
type okHandler mysqlConn
// conn converts the handler back to the *mysqlConn it was created from, which
// makes the methods of mysqlConn available again.
func (mc *okHandler) conn() *mysqlConn {
	underlying := (*mysqlConn)(mc)
	return underlying
}
// clearResult resets the connection's stored result (affectedRows and
// insertIds) to its zero value.
//
// It returns a handler that can process OK responses.
func (mc *mysqlConn) clearResult() *okHandler {
	var empty mysqlResult
	mc.result = empty
	return (*okHandler)(mc)
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
//
// handleOkPacket decodes the affected row count, the insert id and the server
// status from an OK packet. The two counters replace the last entry of
// mc.result.affectedRows and mc.result.insertIds when those slices are not
// empty; the status is stored in mc.status. It always returns nil.
func (mc *okHandler) handleOkPacket(data []byte) error {
	// 0x00 [1 byte]

	// Affected rows [Length Coded Binary]
	affectedRows, _, rowsLen := readLengthEncodedInteger(data[1:])

	// Insert id [Length Coded Binary]
	insertID, _, idLen := readLengthEncodedInteger(data[1+rowsLen:])

	// Update for the current statement result (only used by
	// readResultSetHeaderPacket).
	if last := len(mc.result.affectedRows) - 1; last >= 0 {
		mc.result.affectedRows[last] = int64(affectedRows)
	}
	if last := len(mc.result.insertIds) - 1; last >= 0 {
		mc.result.insertIds[last] = int64(insertID)
	}

	// server_status [2 bytes]
	statusPos := 1 + rowsLen + idLen
	mc.status = readStatus(data[statusPos : statusPos+2])

	// warning count [2 bytes] follows but is not read.
	return nil
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
//
// readColumns reads column definition packets until an EOF packet arrives and
// returns one mysqlField per column. count is the expected number of columns;
// if the EOF packet arrives after a different number of definitions, an error
// is returned. Errors from readPacket and from decoding the length-encoded
// strings are returned unchanged.
//
// NOTE(review): columns[i] is written before the count is checked, so more
// than count definitions from the server would presumably index past the
// slice — confirm.
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)
for i := 0; ; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if i == count {
return columns, nil
}
// len(columns) equals count here, as columns was allocated with that length.
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
}
// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
return nil, err
}
// Database [len coded string]
n, err := skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Table [len coded string]
// The table name is only kept when cfg.ColumnsWithAlias is set; otherwise it is skipped.
if mc.cfg.ColumnsWithAlias {
tableName, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
columns[i].tableName = string(tableName)
} else {
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
}
// Original table [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Name [len coded string]
name, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
columns[i].name = string(name)
pos += n
// Original name [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Filler [uint8]
pos++
// Charset [charset, collation uint8]
// Only the first of the two bytes is stored.
columns[i].charSet = data[pos]
pos += 2
// Length [uint32]
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4
// Field type [uint8]
columns[i].fieldType = fieldType(data[pos])
pos++
// Flags [uint16]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
pos += 2
// Decimals [uint8]
columns[i].decimals = data[pos]
//pos++
// Default value [len coded binary]
//if pos < len(data) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
}
// readRow reads the next text-protocol row packet and stores one value per
// column in dest.
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
//
// It returns io.EOF when the result set is already marked done or when an EOF
// packet is received, and the result of handleErrorPacket when an ERR packet
// is received. Otherwise every column is read as a length-encoded string and
// converted according to the column's field type; NULL columns become nil.
// The first read or conversion error aborts the row and is returned.
func (rows *textRows) readRow(dest []driver.Value) error {
// Keep a local copy: rows.mc may be set to nil below while mc is still needed.
mc := rows.mc
if rows.rs.done {
return io.EOF
}
data, err := mc.readPacket()
if err != nil {
return err
}
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
rows.rs.done = true
// Detach from the connection when no further result set follows.
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}
if data[0] == iERR {
rows.mc = nil
return mc.handleErrorPacket(data)
}
// RowSet Packet
var (
n int
isNull bool
pos int = 0
)
for i := range dest {
// Read bytes and convert to string
var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err != nil {
return err
}
if isNull {
dest[i] = nil
continue
}
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp,
fieldTypeDateTime,
fieldTypeDate,
fieldTypeNewDate:
// Date/time columns are parsed only when parseTime is enabled;
// otherwise the raw bytes are returned.
if mc.parseTime {
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
}
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
// Unsigned 64-bit columns yield uint64, signed ones int64.
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
// Parsed with 32-bit precision and stored as float32.
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
// Every other type is handed over as raw bytes.
dest[i] = buf
}
if err != nil {
return err
}
}
return nil
}
// readUntilEOF reads and discards packets until an EOF packet or an ERR
// packet is received.
//
// On an ERR packet it returns the result of handleErrorPacket. On an EOF
// packet it returns nil; when that packet is exactly 5 bytes long the server
// status it carries is stored in mc.status first. Read errors are returned
// as-is.
func (mc *mysqlConn) readUntilEOF() error {
	for {
		pkt, err := mc.readPacket()
		if err != nil {
			return err
		}

		marker := pkt[0]
		if marker == iERR {
			return mc.handleErrorPacket(pkt)
		}
		if marker != iEOF {
			// Neither terminator: drop the packet and keep reading.
			continue
		}

		if len(pkt) == 5 {
			mc.status = readStatus(pkt[3:])
		}
		return nil
	}
}
/******************************************************************************
* Prepared Statements *
******************************************************************************/
// readPrepareResultPacket reads the response to a prepare command.
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
//
// On success it stores the statement id and the parameter count on stmt and
// returns the column count. If the packet cannot be read, the read error is
// returned; if the packet does not start with iOK, the result of
// handleErrorPacket is returned. The column count is 0 in both error cases.
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
	data, err := stmt.mc.readPacket()
	if err != nil {
		return 0, err
	}

	// packet indicator [1 byte]
	if data[0] != iOK {
		return 0, stmt.mc.handleErrorPacket(data)
	}

	// statement id [4 bytes]
	stmt.id = binary.LittleEndian.Uint32(data[1:5])

	// column count [2 bytes]
	numColumns := binary.LittleEndian.Uint16(data[5:7])

	// param count [2 bytes]
	numParams := binary.LittleEndian.Uint16(data[7:9])
	stmt.paramCount = int(numParams)

	// The reserved byte and the warning count that follow are not read.
	return numColumns, nil
}
// writeCommandLongData sends arg as the value of statement parameter paramID
// using one or more comStmtSendLongData packets.
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
//
// Every packet payload holds at most maxAllowedPacket-1 bytes, including the
// 7-byte command header (command, statement id, parameter id). The first
// error reported by writePacket is returned; otherwise nil. Nothing is sent
// when arg is empty. The packet sequence is reset to 0 before each packet and
// once more before returning.
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
maxLen := stmt.mc.maxAllowedPacket - 1
pktLen := maxLen
// After the header (bytes 0-3) follows before the data:
// 1 byte command
// 4 bytes stmtID
// 2 bytes paramID
const dataOffset = 1 + 4 + 2
// Cannot use the write buffer since
// a) the buffer is too small
// b) it is in use
data := make([]byte, 4+1+4+2+len(arg))
copy(data[4+dataOffset:], arg)
// argLen is the number of argument bytes still to send; every iteration
// sends pktLen-dataOffset of them.
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
// Shrink the packet when the remainder fits into less than maxLen bytes.
if dataOffset+argLen < maxLen {
pktLen = dataOffset + argLen
}
stmt.mc.sequence = 0
// Add command byte [1 byte]
data[4] = comStmtSendLongData
// Add stmtID [32 bit]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit]
data[9] = byte(paramID)
data[10] = byte(paramID >> 8)
// Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen])
if err == nil {
// Slide the window forward by the number of argument bytes just sent:
// the unsent bytes then start right after the header area, and the
// header is rewritten over already-sent bytes on the next iteration.
data = data[pktLen-dataOffset:]
continue
}
return err
}
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil
}
// writeExecutePacket builds and sends the comStmtExecute packet for stmt with
// the given arguments.
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
//
// It returns an error when len(args) differs from stmt.paramCount, when an
// argument has a type not handled by the type switch below, when a time.Time
// cannot be formatted, when sending long data fails, or when the packet
// cannot be written. errBadConnNoWrite is returned when the connection buffer
// cannot be taken or stored.
//
// []byte and string arguments of at least longDataSize bytes are not embedded
// in the execute packet; they are sent beforehand with writeCommandLongData.
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
"argument count mismatch (got: %d; has: %d)",
len(args),
stmt.paramCount,
)
}
// 4 bytes packet header + command + statement id + flags + iteration count.
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
}
// Reset packet-sequence
mc.sequence = 0
var data []byte
var err error
if len(args) == 0 {
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
}
// command [1 byte]
data[4] = comStmtExecute
// statement_id [4 bytes]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 {
pos := minPktLen
// nullMask holds one bit per argument; typesLen covers the bound flag
// plus two type bytes per argument.
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
// which makes the required allocation size hard to guess.
tmp := make([]byte, pos+maskLen+typesLen)
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
// The buffer is reused, so the mask bits must be cleared explicitly.
for i := range nullMask {
nullMask[i] = 0
}
pos += maskLen
}
// newParameterBoundFlag 1 [1 byte]
data[pos] = 0x01
pos++
// type of each parameter [len(args)*2 bytes]
paramTypes := data[pos:]
pos += len(args) * 2
// value of each parameter [n bytes]
// paramValues starts empty right behind the type bytes, so values are
// written into data in place until the capacity is exceeded; from then on
// append moves them to a separate array (detected via valuesCap below).
paramValues := data[pos:pos]
valuesCap := cap(paramValues)
for i, arg := range args {
// build NULL-bitmap
if arg == nil {
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00
continue
}
// json.RawMessage is handled like a plain byte slice.
if v, ok := arg.(json.RawMessage); ok {
arg = []byte(v)
}
// cache types and values
switch v := arg.(type) {
case int64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00
// Write in place when 8 more bytes fit, otherwise let append grow.
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case uint64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x80 // type is unsigned
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64:
paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool:
paramTypes[i+i] = byte(fieldTypeTiny)
paramTypes[i+i+1] = 0x00
if v {
paramValues = append(paramValues, 0x01)
} else {
paramValues = append(paramValues, 0x00)
}
case []byte:
// Common case (non-nil value) first
if v != nil {
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
// Too large to embed: send it ahead as long data.
if err := stmt.writeCommandLongData(i, v); err != nil {
return err
}
}
continue
}
// Handle []byte(nil) as a NULL value
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = byte(fieldTypeNULL)
paramTypes[i+i+1] = 0x00
case string:
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
return err
}
}
case time.Time:
// Times are sent as strings; the zero time is sent as "0000-00-00".
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
var a [64]byte
var b = a[:0]
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return err
}
}
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(b)),
)
paramValues = append(paramValues, b...)
default:
return fmt.Errorf("cannot convert type: %T", arg)
}
}
// Check if param values exceeded the available buffer
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.log(err)
return errBadConnNoWrite
}
}
pos += len(paramValues)
data = data[:pos]
}
return mc.writePacket(data)
}
// discardResults consumes every result set still pending on the connection.
//
// While the server status has statusMoreResultsExists set, the next result
// set header is read with readResultSetHeaderPacket; when it announces
// columns, the column definitions and then the rows are skipped with
// readUntilEOF. The first error encountered is returned.
func (mc *okHandler) discardResults() error {
	for {
		if mc.status&statusMoreResultsExists == 0 {
			return nil
		}

		resLen, err := mc.readResultSetHeaderPacket()
		if err != nil {
			return err
		}
		if resLen <= 0 {
			// No column/row packets follow this header.
			continue
		}

		// First pass skips the columns, second pass skips the rows.
		for pass := 0; pass < 2; pass++ {
			if err := mc.conn().readUntilEOF(); err != nil {
				return err
			}
		}
	}
}
// readRow reads the next binary-protocol row packet and stores one value per
// column in dest.
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
//
// On an EOF packet it records the server status, marks the result set done
// and returns io.EOF. For any other packet whose first byte is not iOK it
// detaches from the connection and returns the result of handleErrorPacket.
// Otherwise every column is decoded according to its field type; columns
// flagged in the NULL bitmap become nil. An unknown field type, an illegal
// decimals value or a decoding failure aborts the row with an error.
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
if err != nil {
return err
}
// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}
// Keep the connection in a local before detaching it from rows.
mc := rows.mc
rows.mc = nil
// Error otherwise
return mc.handleErrorPacket(data)
}
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
pos := 1 + (len(dest)+7+2)>>3
nullMask := data[1:pos]
for i := range dest {
// Field is NULL
// (byte >> bit-pos) % 2 == 1
// The bit of column i sits at offset i+2 in the bitmap.
if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
dest[i] = nil
continue
}
// Convert to byte-coded string
switch rows.rs.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue
// Numeric Types
// Integer columns are widened to int64; signed columns are sign-extended
// through the matching fixed-size signed type first.
case fieldTypeTiny:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
}
pos++
continue
case fieldTypeShort, fieldTypeYear:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
}
pos += 2
continue
case fieldTypeInt24, fieldTypeLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
}
pos += 4
continue
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
// Values that do not fit into int64 are returned in string form.
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
} else {
dest[i] = int64(val)
}
} else {
dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
}
pos += 8
continue
case fieldTypeFloat:
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
pos += 4
continue
case fieldTypeDouble:
dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
pos += 8
continue
// Length coded Binary Strings
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err == nil {
if !isNull {
continue
} else {
dest[i] = nil
continue
}
}
return err
case
fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
// num is the payload length in bytes. pos is moved past the length prefix
// here and past the payload only after a successful conversion below.
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
switch {
case isNull:
dest[i] = nil
continue
case rows.rs.columns[i].fieldType == fieldTypeTime:
// database/sql does not support an equivalent to TIME, return a string
// dstlen is passed to formatBinaryTime: 8, plus 1+decimals when the
// column declares 1 to 6 decimals.
var dstlen uint8
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 8
case 1, 2, 3, 4, 5, 6:
dstlen = 8 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.rs.columns[i].decimals,
)
}
dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
// parseTime is disabled: format the value through formatBinaryDateTime
// instead. dstlen is 10 for dates, otherwise 19 plus 1+decimals when
// the column declares 1 to 6 decimals.
var dstlen uint8
if rows.rs.columns[i].fieldType == fieldTypeDate {
dstlen = 10
} else {
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 19
case 1, 2, 3, 4, 5, 6:
dstlen = 19 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.rs.columns[i].decimals,
)
}
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
}
if err == nil {
pos += int(num)
continue
} else {
return err
}
// Please report if this happens!
default:
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
}
}
return nil
}
================================================
FILE: gossh/mysql/result.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import "database/sql/driver"
// Result extends driver.Result with the per-statement values of a
// multi-statement execution, which driver.Result alone does not expose.
//
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
//	res, err := rawConn.Exec(...)
//	res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
// executed statement.
AllRowsAffected() []int64
// AllLastInsertIds returns a slice containing the last inserted ID for each
// executed statement.
AllLastInsertIds() []int64
}
// mysqlResult stores the affected-row counts and insert IDs of executed
// statements. Its methods provide both the last entry (LastInsertId,
// RowsAffected) and copies of the full slices.
type mysqlResult struct {
// One entry in both slices is created for every executed statement result.
affectedRows []int64
insertIds []int64
}
// LastInsertId returns the last entry of res.insertIds. The error is always
// nil. It panics when no insert ID has been recorded.
func (res *mysqlResult) LastInsertId() (int64, error) {
	ids := res.insertIds
	return ids[len(ids)-1], nil
}
// RowsAffected returns the last entry of res.affectedRows. The error is
// always nil. It panics when no affected-row count has been recorded.
func (res *mysqlResult) RowsAffected() (int64, error) {
	counts := res.affectedRows
	return counts[len(counts)-1], nil
}
// AllLastInsertIds returns a copy of the recorded insert IDs, so callers
// cannot modify the slice held by res. The result is never nil.
func (res *mysqlResult) AllLastInsertIds() []int64 {
	ids := []int64{}
	ids = append(ids, res.insertIds...)
	return ids
}
// AllRowsAffected returns a copy of the recorded affected-row counts, so
// callers cannot modify the slice held by res. The result is never nil.
func (res *mysqlResult) AllRowsAffected() []int64 {
	counts := []int64{}
	counts = append(counts, res.affectedRows...)
	return counts
}
================================================
FILE: gossh/mysql/rows.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
"math"
"reflect"
)
// resultSet holds the state of the result set currently being read.
type resultSet struct {
// columns holds the field descriptions of the result set.
columns []mysqlField
// columnNames caches the names built by Columns; nil until first requested.
columnNames []string
// done is set once the end of the result set has been reached.
done bool
}
// mysqlRows is the state shared by binaryRows and textRows.
type mysqlRows struct {
// mc is the connection the rows are read from. It is set to nil once the
// rows are detached from the connection (for example by Close).
mc *mysqlConn
// rs is the result set currently being read.
rs resultSet
// finish, when non-nil, is called once by Close and then cleared.
finish func()
}
// binaryRows embeds mysqlRows; its readRow decodes rows sent in the binary
// protocol.
type binaryRows struct {
mysqlRows
}
// textRows embeds mysqlRows; its readRow decodes rows sent in the text
// protocol.
type textRows struct {
mysqlRows
}
// Columns returns the names of the columns of the current result set.
//
// The names are built on the first call and cached in rows.rs.columnNames.
// When the rows are still attached to a connection configured with
// ColumnsWithAlias, a column that has a non-empty table name is reported as
// "table.column"; every other column is reported by its plain name.
func (rows *mysqlRows) Columns() []string {
	if cached := rows.rs.columnNames; cached != nil {
		return cached
	}

	withAlias := rows.mc != nil && rows.mc.cfg.ColumnsWithAlias
	names := make([]string, len(rows.rs.columns))
	for i := range names {
		col := &rows.rs.columns[i]
		if withAlias && len(col.tableName) > 0 {
			names[i] = col.tableName + "." + col.name
			continue
		}
		names[i] = col.name
	}

	rows.rs.columnNames = names
	return names
}
// ColumnTypeDatabaseTypeName returns the database type name of column i, as
// reported by the field's typeDatabaseName method.
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
	field := &rows.rs.columns[i]
	return field.typeDatabaseName()
}
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
// return int64(rows.rs.columns[i].length), true
// }
// ColumnTypeNullable reports whether column i may hold NULL, which is the
// case when its flags do not contain flagNotNULL. ok is always true.
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
	notNull := rows.rs.columns[i].flags&flagNotNULL != 0
	return !notNull, true
}
// ColumnTypePrecisionScale returns precision and scale of column i.
//
//   - Decimal types: the precision is the column length minus 1, or minus 2
//     when the column has decimals; the scale is the decimals value.
//   - Timestamp, datetime and time: both values are the decimals value.
//   - Float and double: the precision is math.MaxInt64; the scale is the
//     decimals value, or math.MaxInt64 when decimals is 0x1f.
//
// For every other field type it returns 0, 0, false.
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
	col := rows.rs.columns[i]
	scale := int64(col.decimals)

	switch col.fieldType {
	case fieldTypeDecimal, fieldTypeNewDecimal:
		precision := int64(col.length) - 1
		if scale > 0 {
			precision--
		}
		return precision, scale, true

	case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
		return scale, scale, true

	case fieldTypeFloat, fieldTypeDouble:
		if scale == 0x1f {
			scale = math.MaxInt64
		}
		return math.MaxInt64, scale, true
	}

	return 0, 0, false
}
// ColumnTypeScanType returns the Go type reported by the scanType method of
// column i.
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
	field := &rows.rs.columns[i]
	return field.scanType()
}
// Close releases the rows.
//
// The finish callback, if any, is invoked once and cleared. When the rows
// are still attached to a connection, the unread packets of the current
// result set and all following result sets are drained so the connection
// can be used again. A pending connection error is returned without
// draining. The rows are detached from the connection unless the connection
// reports an error or discarding the remaining result sets fails.
func (rows *mysqlRows) Close() (err error) {
	if rows.finish != nil {
		rows.finish()
		rows.finish = nil
	}

	conn := rows.mc
	if conn == nil {
		return nil
	}
	if connErr := conn.error(); connErr != nil {
		return connErr
	}

	// flip the buffer for this connection if we need to drain it.
	// note that for a successful query (i.e. one where rows.next()
	// has been called until it returns false), `rows.mc` will be nil
	// by the time the user calls `(*Rows).Close`, so we won't reach this
	// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
	conn.buf.flip()

	// Remove unread packets of the current result set from the stream.
	if !rows.rs.done {
		if err = conn.readUntilEOF(); err != nil {
			rows.mc = nil
			return err
		}
	}

	// Drop the remaining result sets; on failure rows.mc is left untouched.
	if err = conn.clearResult().discardResults(); err != nil {
		return err
	}

	rows.mc = nil
	return nil
}
// HasNextResultSet reports whether another result set follows: the rows must
// still be attached to a connection whose status has statusMoreResultsExists
// set.
func (rows *mysqlRows) HasNextResultSet() (b bool) {
	return rows.mc != nil && rows.mc.status&statusMoreResultsExists != 0
}
// nextResultSet advances to the header of the next result set and returns
// the value reported by readResultSetHeaderPacket.
//
// It returns io.EOF when the rows are detached from the connection or when
// no further result set exists (detaching the rows in the latter case).
// Unread packets of the current result set are skipped first. If reading the
// header fails, the result set is marked done and the more-results flag is
// cleared from the connection status.
func (rows *mysqlRows) nextResultSet() (int, error) {
	mc := rows.mc
	if mc == nil {
		return 0, io.EOF
	}
	if err := mc.error(); err != nil {
		return 0, err
	}

	// Remove unread packets from stream
	if !rows.rs.done {
		if err := mc.readUntilEOF(); err != nil {
			return 0, err
		}
		rows.rs.done = true
	}

	if !rows.HasNextResultSet() {
		rows.mc = nil
		return 0, io.EOF
	}
	rows.rs = resultSet{}

	// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
	// nextResultSet.
	resLen, err := mc.resultUnchanged().readResultSetHeaderPacket()
	if err != nil {
		// Clean up about multi-results flag
		rows.rs.done = true
		mc.status &^= statusMoreResultsExists
	}
	return resLen, err
}
// nextNotEmptyResultSet advances through the pending result sets until one
// reports a positive length, and returns that length. Every skipped result
// set is marked done. The first error from nextResultSet (io.EOF included)
// ends the search and is returned with a length of 0.
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
	resLen, err := rows.nextResultSet()
	for err == nil && resLen <= 0 {
		rows.rs.done = true
		resLen, err = rows.nextResultSet()
	}
	if err != nil {
		return 0, err
	}
	return resLen, nil
}
// NextResultSet moves to the next non-empty result set and reads its column
// definitions. It returns the error of nextNotEmptyResultSet (io.EOF when no
// result set is left) or of readColumns.
func (rows *binaryRows) NextResultSet() error {
	resLen, err := rows.nextNotEmptyResultSet()
	if err != nil {
		return err
	}

	cols, err := rows.mc.readColumns(resLen)
	rows.rs.columns = cols
	return err
}
// Next reads the next row into dest. It returns io.EOF when the rows are no
// longer attached to a connection, the connection's pending error if there
// is one, and otherwise the result of readRow.
func (rows *binaryRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}

	// Fetch next row from stream
	return rows.readRow(dest)
}
// NextResultSet moves to the next non-empty result set and reads its column
// definitions. It returns the error of nextNotEmptyResultSet (io.EOF when no
// result set is left) or of readColumns.
func (rows *textRows) NextResultSet() (err error) {
	var resLen int
	if resLen, err = rows.nextNotEmptyResultSet(); err != nil {
		return err
	}

	rows.rs.columns, err = rows.mc.readColumns(resLen)
	return err
}
// Next reads the next row into dest. It returns io.EOF when the rows are no
// longer attached to a connection, the connection's pending error if there
// is one, and otherwise the result of readRow.
func (rows *textRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}

	// Fetch next row from stream
	return rows.readRow(dest)
}
================================================
FILE: gossh/mysql/statement.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"encoding/json"
"fmt"
"io"
"reflect"
)
// mysqlStmt is a prepared statement bound to a connection.
type mysqlStmt struct {
// mc is the connection the statement belongs to; Close sets it to nil.
mc *mysqlConn
// id is the statement id sent with execute, long-data and close commands.
id uint32
// paramCount is the number of arguments the statement expects.
paramCount int
}
// Close sends comStmtClose for the statement and detaches it from the
// connection.
//
// It returns driver.ErrBadConn when the statement has already been closed or
// its connection is closed; otherwise it returns the error of sending the
// close command.
func (stmt *mysqlStmt) Close() error {
	mc := stmt.mc
	if mc == nil || mc.closed.Load() {
		// driver.Stmt.Close can be called more than once, thus this function
		// has to be idempotent.
		// See also Issue #450 and golang/go#16019.
		//errLog.Print(ErrInvalidConn)
		return driver.ErrBadConn
	}

	err := mc.writeCommandPacketUint32(comStmtClose, stmt.id)
	stmt.mc = nil
	return err
}
// NumInput returns the number of parameters the statement expects.
func (stmt *mysqlStmt) NumInput() int {
	count := stmt.paramCount
	return count
}
// ColumnConverter returns the package's converter for every column index;
// idx is ignored.
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
	var c converter
	return c
}
// CheckNamedValue replaces nv.Value with the result of running it through
// converter.ConvertValue and returns the conversion error. nv.Value is
// overwritten even when the conversion fails.
func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
	converted, err := converter{}.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
// Exec executes the prepared statement with args and returns a copy of the
// result recorded on the connection.
//
// It returns driver.ErrBadConn when the connection is closed, and the error
// processed by markBadConn when the execute packet cannot be written. Any
// result set produced by the statement is read and discarded, including
// result sets of following statements.
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
	if stmt.mc.closed.Load() {
		stmt.mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	// Send command
	if err := stmt.writeExecutePacket(args); err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc
	handleOk := mc.clearResult()

	// Read Result
	resLen, err := handleOk.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	if resLen > 0 {
		// First pass skips the columns, second pass skips the rows.
		for pass := 0; pass < 2; pass++ {
			if err := mc.readUntilEOF(); err != nil {
				return nil, err
			}
		}
	}

	if err := handleOk.discardResults(); err != nil {
		return nil, err
	}

	// Hand out a copy so later statements cannot alter the returned value's
	// slice headers.
	result := mc.result
	return &result, nil
}
// Query executes the prepared statement with args and returns the resulting
// rows.
//
// On failure it returns a literal nil driver.Rows together with the error.
// Forwarding the results of stmt.query directly would wrap a nil *binaryRows
// in a non-nil interface value (or hand out partially initialised rows), so
// the error case is handled explicitly.
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
	rows, err := stmt.query(args)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// query executes the prepared statement with args and returns binary rows
// positioned on the first non-empty result set.
//
// It returns driver.ErrBadConn when the connection is closed, and the error
// processed by markBadConn when the execute packet cannot be written. When
// the first result set header announces columns, the column definitions are
// read and the rows are attached to the connection; the rows are returned
// together with any error from reading the columns. Otherwise the rows try
// to advance to the next result set, where io.EOF is not treated as an
// error.
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
	if stmt.mc.closed.Load() {
		stmt.mc.log(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}

	// Send command
	if err := stmt.writeExecutePacket(args); err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc

	// Read Result
	resLen, err := mc.clearResult().readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	rows := new(binaryRows)
	if resLen > 0 {
		rows.mc = mc
		rows.rs.columns, err = mc.readColumns(resLen)
		return rows, err
	}

	rows.rs.done = true
	if nextErr := rows.NextResultSet(); nextErr != nil && nextErr != io.EOF {
		return nil, nextErr
	}
	return rows, nil
}
// jsonType is the reflect.Type of json.RawMessage. ConvertValue compares
// slice types against it so json.RawMessage values are returned unchanged.
var jsonType = reflect.TypeOf(json.RawMessage{})
// converter is the driver.ValueConverter used for statement arguments.
type converter struct{}

// ConvertValue mirrors the reference/default converter in database/sql/driver
// with _one_ exception. We support uint64 with their high bit and the default
// implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
//
// Values accepted by driver.IsValue are returned unchanged. driver.Valuer
// implementations are asked for their value, which must be a driver value or
// a uint64. Pointers are dereferenced (nil pointers become nil); integers,
// unsigned integers, floats, bools and strings are widened to int64, uint64,
// float64, bool and string; json.RawMessage is returned unchanged and other
// byte slices as []byte. Everything else yields an error.
func (c converter) ConvertValue(v any) (driver.Value, error) {
	if driver.IsValue(v) {
		return v, nil
	}

	if valuer, isValuer := v.(driver.Valuer); isValuer {
		sv, err := callValuerValue(valuer)
		switch {
		case err != nil:
			return nil, err
		case driver.IsValue(sv):
			return sv, nil
		}
		// A value returned from the Valuer interface can be "a type handled by
		// a database driver's NamedValueChecker interface" so we should accept
		// uint64 here as well.
		if u, isUint64 := sv.(uint64); isUint64 {
			return u, nil
		}
		return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
	}

	rv := reflect.ValueOf(v)
	kind := rv.Kind()
	switch kind {
	case reflect.Ptr:
		// indirect pointers
		if rv.IsNil() {
			return nil, nil
		}
		return c.ConvertValue(rv.Elem().Interface())

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int(), nil

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rv.Uint(), nil

	case reflect.Float32, reflect.Float64:
		return rv.Float(), nil

	case reflect.Bool:
		return rv.Bool(), nil

	case reflect.Slice:
		sliceType := rv.Type()
		if sliceType == jsonType {
			return v, nil
		}
		elemKind := sliceType.Elem().Kind()
		if elemKind == reflect.Uint8 {
			return rv.Bytes(), nil
		}
		return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, elemKind)

	case reflect.String:
		return rv.String(), nil
	}

	return nil, fmt.Errorf("unsupported type %T, a %s", v, kind)
}
// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This is an exact copy of the same-named unexported function from the
// database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)

	// Anything that is not a nil pointer can be called directly.
	if rv.Kind() != reflect.Ptr || !rv.IsNil() {
		return vr.Value()
	}
	// A nil pointer whose element type does not implement Valuer has a
	// hand-written pointer-receiver method; let that method deal with nil.
	if !rv.Type().Elem().Implements(valuerReflectType) {
		return vr.Value()
	}
	return nil, nil
}
================================================
FILE: gossh/mysql/transaction.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
// mysqlTx is a transaction bound to a connection.
type mysqlTx struct {
// mc is the connection the transaction runs on; Commit and Rollback set it
// to nil after executing their statement.
mc *mysqlConn
}
// Commit executes "COMMIT" on the transaction's connection and detaches the
// transaction from it. It returns ErrInvalidConn when the transaction is
// already finished or the connection is closed; otherwise it returns the
// error of executing the statement.
func (tx *mysqlTx) Commit() (err error) {
	mc := tx.mc
	if mc == nil || mc.closed.Load() {
		return ErrInvalidConn
	}

	err = mc.exec("COMMIT")
	tx.mc = nil
	return err
}
// Rollback executes "ROLLBACK" on the transaction's connection and detaches
// the transaction from it. It returns ErrInvalidConn when the transaction is
// already finished or the connection is closed; otherwise it returns the
// error of executing the statement.
func (tx *mysqlTx) Rollback() (err error) {
	mc := tx.mc
	if mc == nil || mc.closed.Load() {
		return ErrInvalidConn
	}

	err = mc.exec("ROLLBACK")
	tx.mc = nil
	return err
}
================================================
FILE: gossh/mysql/utils.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Registry for custom tls.Configs
var (
// tlsConfigLock guards tlsConfigRegistry.
tlsConfigLock sync.RWMutex
// tlsConfigRegistry maps registration keys to configs. It stays nil until
// the first successful RegisterTLSConfig call.
tlsConfigRegistry map[string]*tls.Config
)
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// Keys that readBool accepts as boolean values, as well as "skip-verify" and
// "preferred" in any letter case, are reserved and rejected with an error.
// Registering an existing key replaces the previous config.
//
// Note: The provided tls.Config is exclusively owned by the driver after
// registering it.
//
//	rootCertPool := x509.NewCertPool()
//	pem, err := os.ReadFile("/path/ca-cert.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
//		log.Fatal("Failed to append PEM.")
//	}
//	clientCert := make([]tls.Certificate, 0, 1)
//	certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//	clientCert = append(clientCert, certs)
//	mysql.RegisterTLSConfig("custom", &tls.Config{
//		RootCAs:      rootCertPool,
//		Certificates: clientCert,
//	})
//	db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
func RegisterTLSConfig(key string, config *tls.Config) error {
	_, isBool := readBool(key)
	lowered := strings.ToLower(key)
	if isBool || lowered == "skip-verify" || lowered == "preferred" {
		return fmt.Errorf("key '%s' is reserved", key)
	}

	tlsConfigLock.Lock()
	defer tlsConfigLock.Unlock()

	// The registry is created lazily on first use.
	if tlsConfigRegistry == nil {
		tlsConfigRegistry = make(map[string]*tls.Config)
	}
	tlsConfigRegistry[key] = config
	return nil
}
// DeregisterTLSConfig removes the tls.Config associated with key. It does
// nothing when the key is unknown or no config was ever registered.
func DeregisterTLSConfig(key string) {
	tlsConfigLock.Lock()
	defer tlsConfigLock.Unlock()

	if tlsConfigRegistry == nil {
		return
	}
	delete(tlsConfigRegistry, key)
}
// getTLSConfigClone returns a clone of the tls.Config registered under key,
// or nil when no config is registered for it.
func getTLSConfigClone(key string) (config *tls.Config) {
	tlsConfigLock.RLock()
	defer tlsConfigLock.RUnlock()

	registered, found := tlsConfigRegistry[key]
	if !found {
		return nil
	}
	return registered.Clone()
}
// readBool interprets input as a boolean.
//
// "1", "true", "TRUE" and "True" yield true; "0", "false", "FALSE" and
// "False" yield false. valid reports whether input was one of these
// spellings; for any other input both results are false.
func readBool(input string) (value bool, valid bool) {
	if input == "1" || input == "true" || input == "TRUE" || input == "True" {
		return true, true
	}
	if input == "0" || input == "false" || input == "FALSE" || input == "False" {
		return false, true
	}

	// Not a valid bool value
	return false, false
}
/******************************************************************************
* Time related utils *
******************************************************************************/
// parseDateTime parses a textual date or datetime in the layout
// "YYYY-MM-DD[ HH:MM:SS[.ffffff]]" (1 to 6 fraction digits) into a time.Time
// in loc.
//
// Accepted lengths are 10, 19 and 21 through 26 bytes. An all-zero value of
// matching length yields the zero time.Time. An error is returned for any
// other length, a non-digit where a digit is expected, or a wrong separator.
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
	const zeroValue = "0000-00-00 00:00:00.000000"
	var none time.Time

	switch len(b) {
	case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
	default:
		return none, fmt.Errorf("invalid time bytes: %s", b)
	}

	if string(b) == zeroValue[:len(b)] {
		return none, nil
	}

	// Date part: YYYY-MM-DD
	year, err := parseByteYear(b)
	if err != nil {
		return none, err
	}
	if b[4] != '-' {
		return none, fmt.Errorf("bad value for field: `%c`", b[4])
	}
	monthNum, err := parseByte2Digits(b[5], b[6])
	if err != nil {
		return none, err
	}
	month := time.Month(monthNum)
	if b[7] != '-' {
		return none, fmt.Errorf("bad value for field: `%c`", b[7])
	}
	day, err := parseByte2Digits(b[8], b[9])
	if err != nil {
		return none, err
	}
	if len(b) == 10 {
		return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
	}

	// Time part: " HH:MM:SS"
	if b[10] != ' ' {
		return none, fmt.Errorf("bad value for field: `%c`", b[10])
	}
	hour, err := parseByte2Digits(b[11], b[12])
	if err != nil {
		return none, err
	}
	if b[13] != ':' {
		return none, fmt.Errorf("bad value for field: `%c`", b[13])
	}
	minute, err := parseByte2Digits(b[14], b[15])
	if err != nil {
		return none, err
	}
	if b[16] != ':' {
		return none, fmt.Errorf("bad value for field: `%c`", b[16])
	}
	sec, err := parseByte2Digits(b[17], b[18])
	if err != nil {
		return none, err
	}
	if len(b) == 19 {
		return time.Date(year, month, day, hour, minute, sec, 0, loc), nil
	}

	// Fraction part: ".f" up to ".ffffff"
	if b[19] != '.' {
		return none, fmt.Errorf("bad value for field: `%c`", b[19])
	}
	nsec, err := parseByteNanoSec(b[20:])
	if err != nil {
		return none, err
	}
	return time.Date(year, month, day, hour, minute, sec, nsec, loc), nil
}
// parseByteYear decodes the first four bytes of b as a decimal year.
// It returns an error if any of those bytes is not an ASCII digit.
func parseByteYear(b []byte) (int, error) {
	year := 0
	for i := 0; i < 4; i++ {
		digit, err := bToi(b[i])
		if err != nil {
			return 0, err
		}
		// Shift the digits seen so far one decimal place to the left.
		year = year*10 + digit
	}
	return year, nil
}
// parseByte2Digits decodes the ASCII digits b1 (tens) and b2 (ones) into the
// number they spell. It returns an error if either byte is not a digit.
func parseByte2Digits(b1, b2 byte) (int, error) {
	value := 0
	for _, c := range [...]byte{b1, b2} {
		digit, err := bToi(c)
		if err != nil {
			return 0, err
		}
		value = value*10 + digit
	}
	return value, nil
}
// parseByteNanoSec converts the fractional-second digits in b into
// nanoseconds. The first digit counts tenths of a second, the sixth
// microseconds. It returns an error if any byte is not an ASCII digit.
func parseByteNanoSec(b []byte) (int, error) {
	micros := 0
	weight := 100000 // value of one unit of the current digit, in microseconds
	for _, c := range b {
		digit, err := bToi(c)
		if err != nil {
			return 0, err
		}
		micros += digit * weight
		weight /= 10
	}
	// The digits carry microsecond precision (6 digits) while the result is
	// in nanoseconds (9 digits), so scale by 10^3.
	return micros * 1000, nil
}
// bToi converts a single ASCII digit to its numeric value.
// Any byte outside '0'..'9' yields an error.
func bToi(b byte) (int, error) {
	if '0' <= b && b <= '9' {
		return int(b - '0'), nil
	}
	return 0, errors.New("not [0-9]")
}
// parseBinaryDateTime decodes a binary DATE/DATETIME payload of num bytes
// held in data into a time.Time located in loc.
//
// Supported payload lengths:
//
//	0  - zero value, returned as the zero time.Time
//	4  - year (little-endian uint16), month, day
//	7  - the above plus hour, minute, second
//	11 - the above plus microseconds (little-endian uint32)
//
// Any other length yields an error and a nil value.
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
	switch num {
	case 0:
		return time.Time{}, nil
	case 4, 7, 11:
		// decoded below
	default:
		return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
	}

	year := int(binary.LittleEndian.Uint16(data[:2]))
	month := time.Month(data[2])
	day := int(data[3])

	var hour, minute, second, nsec int
	if num >= 7 {
		hour, minute, second = int(data[4]), int(data[5]), int(data[6])
	}
	if num == 11 {
		// The payload stores microseconds; time.Date wants nanoseconds.
		nsec = int(binary.LittleEndian.Uint32(data[7:11])) * 1000
	}
	return time.Date(year, month, day, hour, minute, second, nsec, loc), nil
}
// appendDateTime appends the textual form of t to buf and returns the
// extended buffer.
//
// If timeTruncate is positive, t is first truncated to a multiple of it.
// The output is "YYYY-MM-DD" when the clock and nanosecond parts are all zero,
// "YYYY-MM-DD HH:MM:SS" when only the nanoseconds are zero, and otherwise the
// latter followed by '.' and up to nine fractional digits with trailing zeros
// removed. Years outside [1, 9999] are rejected with an error and buf is
// returned unchanged.
func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) {
	if timeTruncate > 0 {
		t = t.Truncate(timeTruncate)
	}
	year, month, day := t.Date()
	hour, min, sec := t.Clock()
	nsec := t.Nanosecond()
	if year < 1 || year > 9999 {
		return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
	}
	// Split the year into two two-digit halves so each can be looked up in
	// the digits10/digits01 tables.
	year100 := year / 100
	year1 := year % 100
	var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
	localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
	localBuf[4] = '-'
	localBuf[5], localBuf[6] = digits10[month], digits01[month]
	localBuf[7] = '-'
	localBuf[8], localBuf[9] = digits10[day], digits01[day]
	// Pure dates are emitted without a time part.
	if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
		return append(buf, localBuf[:10]...), nil
	}
	localBuf[10] = ' '
	localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
	localBuf[13] = ':'
	localBuf[14], localBuf[15] = digits10[min], digits01[min]
	localBuf[16] = ':'
	localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
	// Whole seconds are emitted without a fractional part.
	if nsec == 0 {
		return append(buf, localBuf[:19]...), nil
	}
	// Split the nine nanosecond digits d1..d9 into one single digit (d1) and
	// four two-digit groups (d2d3, d4d5, d6d7, d8d9) for table lookup.
	nsec100000000 := nsec / 100000000
	nsec1000000 := (nsec / 1000000) % 100
	nsec10000 := (nsec / 10000) % 100
	nsec100 := (nsec / 100) % 100
	nsec1 := nsec % 100
	localBuf[19] = '.'
	// milliseconds: d1 d2 d3
	localBuf[20], localBuf[21], localBuf[22] =
		digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
	// microseconds: d4 d5 d6
	localBuf[23], localBuf[24], localBuf[25] =
		digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
	// nanoseconds: d7 d8 d9
	localBuf[26], localBuf[27], localBuf[28] =
		digits01[nsec100], digits10[nsec1], digits01[nsec1]
	// trim trailing zeros; nsec != 0 here, so at least one fractional digit
	// is non-zero and the loop stops before reaching the '.'
	n := len(localBuf)
	for n > 0 && localBuf[n-1] == '0' {
		n--
	}
	return append(buf, localBuf[:n]...), nil
}
// zeroDateTime is used in formatBinaryDateTime and formatBinaryTime to avoid
// an allocation if the DATE, DATETIME or TIME has the zero value: callers
// return sub-slices of it directly.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")

// digits01 is a lookup table: for i in [0, 99], digits01[i] is the ASCII
// character of the ones digit of i.
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"

// digits10 is a lookup table: for i in [0, 99], digits10[i] is the ASCII
// character of the tens digit of i.
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
// appendMicrosecs appends '.' followed by the leading decimals digits of a
// six-digit microsecond value to dst and returns the extended slice.
//
// The value is read from the first four bytes of src as a little-endian
// uint32. An empty src stands for zero. When decimals is not positive, dst is
// returned unchanged. Values of decimals of six or more produce all six
// digits when src is non-empty.
func appendMicrosecs(dst, src []byte, decimals int) []byte {
	if decimals <= 0 {
		return dst
	}
	if len(src) == 0 {
		return append(dst, ".000000"[:decimals+1]...)
	}

	// Break the value into three two-digit groups for table lookup.
	microsecs := binary.LittleEndian.Uint32(src[:4])
	var pairs [3]byte
	pairs[0] = byte(microsecs / 10000)
	microsecs -= 10000 * uint32(pairs[0])
	pairs[1] = byte(microsecs / 100)
	microsecs -= 100 * uint32(pairs[1])
	pairs[2] = byte(microsecs)

	count := decimals
	if count > 6 {
		count = 6
	}

	// Only the digits that are actually emitted are looked up.
	var frac [7]byte
	frac[0] = '.'
	for i := 0; i < count; i++ {
		table := digits10 // even positions hold the tens digit of a group
		if i%2 == 1 {
			table = digits01 // odd positions hold the ones digit
		}
		frac[1+i] = table[pairs[i/2]]
	}
	return append(dst, frac[:1+count]...)
}
// formatBinaryDateTime renders a binary DATE/DATETIME payload as text of the
// requested length.
//
// length is the length of the textual zero value for the column: 10 for a
// date, 19 for a datetime, 21..26 for a datetime with 1..6 fractional digits.
// src must be empty (zero value) or 4, 7 or 11 bytes long; other
// combinations yield an error. For an empty src a prefix of zeroDateTime is
// returned without copying.
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
	// Zero value: hand out a slice of the shared constant. Note that this
	// happens before length is validated.
	if len(src) == 0 {
		return zeroDateTime[:length], nil
	}
	var dst []byte      // return value
	var p1, p2, p3 byte // current digit pair
	// Validate the requested text length.
	switch length {
	case 10, 19, 21, 22, 23, 24, 25, 26:
	default:
		t := "DATE"
		if length > 10 {
			t += "TIME"
		}
		return nil, fmt.Errorf("illegal %s length %d", t, length)
	}
	// Validate the payload length.
	switch len(src) {
	case 4, 7, 11:
	default:
		t := "DATE"
		if length > 10 {
			t += "TIME"
		}
		return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
	}
	dst = make([]byte, 0, length)
	// start with the date: the year is a little-endian uint16 that is split
	// into century (pt) and two-digit remainder (p1) for table lookup
	year := binary.LittleEndian.Uint16(src[:2])
	pt := year / 100
	p1 = byte(year - 100*uint16(pt))
	p2, p3 = src[2], src[3]
	dst = append(dst,
		digits10[pt], digits01[pt],
		digits10[p1], digits01[p1], '-',
		digits10[p2], digits01[p2], '-',
		digits10[p3], digits01[p3],
	)
	if length == 10 {
		return dst, nil
	}
	// Date-only payload but a datetime was requested: pad with the zero time.
	if len(src) == 4 {
		return append(dst, zeroDateTime[10:length]...), nil
	}
	dst = append(dst, ' ')
	p1 = src[4] // hour
	src = src[5:]
	// p1 is 2-digit hour, src is after hour
	p2, p3 = src[0], src[1]
	dst = append(dst,
		digits10[p1], digits01[p1], ':',
		digits10[p2], digits01[p2], ':',
		digits10[p3], digits01[p3],
	)
	// length-20 is the number of fractional digits (not positive for a plain
	// 19-byte datetime); src[2:] is empty when the payload has no
	// microseconds.
	return appendMicrosecs(dst, src[2:], int(length)-20), nil
}
// formatBinaryTime renders a binary TIME payload as text.
//
// length is the length of the textual zero value for the column: 8 for
// "HH:MM:SS", 10..15 with 1..6 fractional digits. A leading '-' and extra
// hour digits for values of 100 hours or more are added on top of that as
// needed. src must be empty (zero value) or 8 or 12 bytes long; for an empty
// src a slice of zeroDateTime is returned without copying.
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
	if len(src) == 0 {
		return zeroDateTime[11 : 11+length], nil
	}
	if length != 8 && (length < 10 || length > 15) {
		return nil, fmt.Errorf("illegal TIME length %d", length)
	}
	if n := len(src); n != 8 && n != 12 {
		return nil, fmt.Errorf("invalid TIME packet length %d", n)
	}

	// Two spare bytes leave room for a sign and a third hour digit.
	out := make([]byte, 0, length+2)
	if src[0] == 1 {
		out = append(out, '-')
	}

	// Days are folded into the hour count.
	days := binary.LittleEndian.Uint32(src[1:5])
	totalHours := int64(days)*24 + int64(src[5])
	if totalHours < 100 {
		out = append(out, digits10[totalHours], digits01[totalHours])
	} else {
		out = strconv.AppendInt(out, totalHours, 10)
	}

	minute, second := src[6], src[7]
	out = append(out, ':',
		digits10[minute], digits01[minute], ':',
		digits10[second], digits01[second],
	)
	// length-9 is the number of fractional digits; src[8:] is empty for the
	// 8-byte payload.
	return appendMicrosecs(out, src[8:], int(length)-9), nil
}
/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
// uint64ToBytes returns the eight bytes of n in little-endian order.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}
// uint64ToString returns the decimal ASCII representation of n.
func uint64ToString(n uint64) []byte {
	// 20 digits are enough for the largest uint64.
	var digits [20]byte
	pos := len(digits) - 1
	// Fill the buffer from the right, least significant digit first.
	for ; n >= 10; pos-- {
		digits[pos] = byte('0' + n%10)
		n /= 10
	}
	digits[pos] = byte('0' + n)
	return digits[pos:]
}
// stringToInt interprets b as the decimal representation of an unsigned
// integer and returns its value. The input is not validated: every byte is
// treated as a digit by subtracting '0' from it.
func stringToInt(b []byte) int {
	n := 0
	for _, c := range b {
		n = n*10 + int(c-0x30)
	}
	return n
}
// readLengthEncodedString reads a length-prefixed string from the start of b.
//
// It returns the string as a sub-slice of b (capped so that appending to it
// cannot overwrite the bytes that follow), whether the value is NULL, the
// number of bytes consumed, and io.EOF if the declared length exceeds the
// data available in b.
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
	size, isNull, pos := readLengthEncodedInteger(b)
	if size < 1 {
		// NULL or empty string: nothing follows the length prefix.
		return b[pos:pos], isNull, pos, nil
	}

	end := pos + int(size)
	if end > len(b) {
		return nil, false, end, io.EOF
	}
	return b[pos:end:end], false, end, nil
}
// skipLengthEncodedString measures a length-prefixed string at the start of b
// without extracting it. It returns the number of bytes the prefix and the
// string occupy, and io.EOF if that exceeds the data available in b.
func skipLengthEncodedString(b []byte) (int, error) {
	size, _, pos := readLengthEncodedInteger(b)
	if size < 1 {
		return pos, nil
	}

	end := pos + int(size)
	if end > len(b) {
		return end, io.EOF
	}
	return end, nil
}
// readLengthEncodedInteger decodes a length-encoded integer from the start
// of b. It returns the value, whether it denotes NULL, and the number of
// bytes consumed.
//
// The first byte selects the encoding:
//
//	0x00-0xfa  the value itself
//	0xfb       NULL
//	0xfc       the following 2 bytes, little-endian
//	0xfd       the following 3 bytes, little-endian
//	0xfe       the following 8 bytes, little-endian
//
// Any other first byte is returned as the value itself. An empty b is
// reported as NULL with one byte consumed (see issue #349).
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	if len(b) == 0 {
		return 0, true, 1
	}

	var width int // number of payload bytes after the marker
	switch b[0] {
	case 0xfb:
		return 0, true, 1
	case 0xfc:
		width = 2
	case 0xfd:
		width = 3
	case 0xfe:
		width = 8
	default:
		return uint64(b[0]), false, 1
	}

	var value uint64
	for i := 0; i < width; i++ {
		value |= uint64(b[1+i]) << (8 * uint(i))
	}
	return value, false, width + 1
}
// appendLengthEncodedInteger appends the length-encoded form of n to b and
// returns the extended slice.
//
// Values up to 250 take a single byte. Larger values are written as a marker
// byte (0xfc, 0xfd or 0xfe) followed by 2, 3 or 8 little-endian bytes.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	var marker byte
	var width int
	switch {
	case n <= 250:
		return append(b, byte(n))
	case n <= 0xffff:
		marker, width = 0xfc, 2
	case n <= 0xffffff:
		marker, width = 0xfd, 3
	default:
		marker, width = 0xfe, 8
	}

	// Assemble marker and payload first so they are appended in one step.
	var enc [9]byte
	enc[0] = marker
	for i := 0; i < width; i++ {
		enc[1+i] = byte(n >> (8 * uint(i)))
	}
	return append(b, enc[:1+width]...)
}
// appendLengthEncodedString appends s to b, preceded by the length of s as a
// length-encoded integer, and returns the extended slice.
func appendLengthEncodedString(b []byte, s string) []byte {
	return append(appendLengthEncodedInteger(b, uint64(len(s))), s...)
}
// reserveBuffer returns buf extended by appendSize bytes.
// When the capacity of buf suffices, the result shares its storage.
// Otherwise a new buffer of len(buf)*2+appendSize bytes is allocated and the
// existing contents are copied into it.
func reserveBuffer(buf []byte, appendSize int) []byte {
	want := len(buf) + appendSize
	if want <= cap(buf) {
		return buf[:want]
	}

	// Grow exponentially to amortize repeated reservations.
	grown := make([]byte, len(buf)*2+appendSize)
	copy(grown, buf)
	return grown[:want]
}
// escapeBytesBackslash appends v to buf with backslash escaping applied and
// returns the extended slice.
//
// NUL, newline, carriage return and 0x1a become \0, \n, \r and \Z. Single
// quote, double quote and backslash are prefixed with a backslash. All other
// bytes are copied unchanged.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
	pos := len(buf)
	// Worst case every byte is escaped and takes two bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range v {
		esc := c
		switch c {
		case '\x00':
			esc = '0'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\x1a':
			esc = 'Z'
		case '\'', '"', '\\':
			// escaped as themselves
		default:
			buf[pos] = c
			pos++
			continue
		}
		buf[pos], buf[pos+1] = '\\', esc
		pos += 2
	}

	return buf[:pos]
}
// escapeStringBackslash appends the string v to buf with backslash escaping
// applied and returns the extended slice.
//
// NUL, newline, carriage return and 0x1a become \0, \n, \r and \Z. Single
// quote, double quote and backslash are prefixed with a backslash. All other
// bytes are copied unchanged. The string is processed byte by byte.
func escapeStringBackslash(buf []byte, v string) []byte {
	pos := len(buf)
	// Worst case every byte is escaped and takes two bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range []byte(v) {
		esc := c
		switch c {
		case '\x00':
			esc = '0'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\x1a':
			esc = 'Z'
		case '\'', '"', '\\':
			// escaped as themselves
		default:
			buf[pos] = c
			pos++
			continue
		}
		buf[pos], buf[pos+1] = '\\', esc
		pos += 2
	}

	return buf[:pos]
}
// escapeBytesQuotes appends v to buf with every apostrophe doubled and
// returns the extended slice. All other bytes are copied unchanged.
// This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in effect on the
// server.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
func escapeBytesQuotes(buf, v []byte) []byte {
	pos := len(buf)
	// Worst case every byte is an apostrophe and takes two bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range v {
		buf[pos] = c
		pos++
		if c == '\'' {
			buf[pos] = '\''
			pos++
		}
	}

	return buf[:pos]
}
// escapeStringQuotes appends the string v to buf with every apostrophe
// doubled and returns the extended slice. All other bytes are copied
// unchanged. The string is processed byte by byte.
func escapeStringQuotes(buf []byte, v string) []byte {
	pos := len(buf)
	// Worst case every byte is an apostrophe and takes two bytes.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range []byte(v) {
		buf[pos] = c
		pos++
		if c == '\'' {
			buf[pos] = '\''
			pos++
		}
	}

	return buf[:pos]
}
/******************************************************************************
* Sync utils *
******************************************************************************/
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// It carries no data; its Lock and Unlock methods exist only so that the
// -copylocks checker of `go vet` reports copies of the enclosing struct.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}

// Unlock is a no-op used by -copylocks checker from `go vet`.
// noCopy should implement sync.Locker from Go 1.11
// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc
// https://github.com/golang/go/issues/26165
func (*noCopy) Unlock() {}
// atomicError holds an error value that is stored and loaded atomically.
// It must not be copied after first use.
type atomicError struct {
	_     noCopy
	value atomic.Value
}

// Set stores value, replacing whatever was stored before.
// The value must not be nil.
func (ae *atomicError) Set(value error) {
	ae.value.Store(value)
}

// Value returns the most recently stored error, or nil if none was stored.
func (ae *atomicError) Value() error {
	stored := ae.value.Load()
	if stored == nil {
		return nil
	}
	// Panics if the stored value does not implement the error interface.
	return stored.(error)
}
// namedValueToValue extracts the plain values from named, preserving order.
// Named parameters are not supported: if any element has a non-empty Name,
// an error is returned.
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, 0, len(named))
	for _, arg := range named {
		if arg.Name != "" {
			// TODO: support the use of Named Parameters #561
			return nil, errors.New("mysql: driver does not support the use of Named Parameters")
		}
		values = append(values, arg.Value)
	}
	return values, nil
}
// mapIsolationLevel translates a driver isolation level into its SQL
// keyword form. Levels other than repeatable read, read committed, read
// uncommitted and serializable yield an error.
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
	var name string
	switch sql.IsolationLevel(level) {
	case sql.LevelReadUncommitted:
		name = "READ UNCOMMITTED"
	case sql.LevelReadCommitted:
		name = "READ COMMITTED"
	case sql.LevelRepeatableRead:
		name = "REPEATABLE READ"
	case sql.LevelSerializable:
		name = "SERIALIZABLE"
	default:
		return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
	}
	return name, nil
}
================================================
FILE: gossh/pgsql/array.go
================================================
package pgsql
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/hex"
"fmt"
"reflect"
"strconv"
"strings"
)
// typeByteSlice is the reflect.Type of []byte.
var typeByteSlice = reflect.TypeOf([]byte{})

// typeDriverValuer is the reflect.Type of the driver.Valuer interface.
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// typeSQLScanner is the reflect.Type of the sql.Scanner interface.
var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
//
// Slices of bool, float64, float32, int64, int32, string and []byte, and
// pointers to such slices, are wrapped in the matching typed array. Anything
// else is wrapped in a GenericArray.
//
// For example:
//
//	db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
//
//	var x []sql.NullInt64
//	db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x))
//
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
// bound is not one (such as `[0:0]={1}') are not supported.
func Array(a any) interface {
	driver.Valuer
	sql.Scanner
} {
	// For slices passed by value the wrapper points at the switch-local copy
	// v; for pointers it points at the caller's slice.
	switch v := a.(type) {
	case []bool:
		return (*BoolArray)(&v)
	case *[]bool:
		return (*BoolArray)(v)

	case []float64:
		return (*Float64Array)(&v)
	case *[]float64:
		return (*Float64Array)(v)

	case []float32:
		return (*Float32Array)(&v)
	case *[]float32:
		return (*Float32Array)(v)

	case []int64:
		return (*Int64Array)(&v)
	case *[]int64:
		return (*Int64Array)(v)

	case []int32:
		return (*Int32Array)(&v)
	case *[]int32:
		return (*Int32Array)(v)

	case []string:
		return (*StringArray)(&v)
	case *[]string:
		return (*StringArray)(v)

	case [][]byte:
		return (*ByteaArray)(&v)
	case *[][]byte:
		return (*ByteaArray)(v)

	default:
		return GenericArray{a}
	}
}
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
// to override the array delimiter used by GenericArray.
// Elements that do not implement it are separated by a comma.
type ArrayDelimiter interface {
	// ArrayDelimiter returns the delimiter character(s) for this element's type.
	ArrayDelimiter() string
}
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
type BoolArray []bool

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *BoolArray) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
}

// scanBytes parses the text form of a boolean array, whose elements must be
// exactly "t" or "f", and stores the result in the receiver.
func (a *BoolArray) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(BoolArray, len(elems))
	for i, elem := range elems {
		switch {
		case len(elem) == 1 && elem[0] == 't':
			parsed[i] = true
		case len(elem) == 1 && elem[0] == 'f':
			parsed[i] = false
		default:
			return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, elem)
		}
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise the elements are
// rendered as "t" or "f", comma-separated, inside curly brackets.
func (a BoolArray) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// Every element contributes one lead byte (the opening bracket or a
	// comma) and one value byte; the closing bracket comes last.
	out := make([]byte, 0, 1+2*len(a))
	lead := byte('{')
	for _, v := range a {
		c := byte('f')
		if v {
			c = 't'
		}
		out = append(out, lead, c)
		lead = ','
	}
	return string(append(out, '}')), nil
}
// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type.
type ByteaArray [][]byte

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *ByteaArray) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to ByteaArray", src)
}

// scanBytes parses the text form of a bytea array, decoding each element
// with parseBytea, and stores the result in the receiver.
func (a *ByteaArray) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "ByteaArray")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(ByteaArray, len(elems))
	for i, elem := range elems {
		decoded, err := parseBytea(elem)
		if err != nil {
			return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
		}
		parsed[i] = decoded
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface. It uses the "hex" format which
// is only supported on PostgreSQL 9.0 or newer.
//
// A nil array yields nil, an empty one "{}", and otherwise every element is
// rendered as a double-quoted `\\x`-prefixed hex string.
func (a ByteaArray) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// Per element: one lead byte (opening bracket or comma), the four bytes
	// of `"\\x`, the hex digits and the closing quote. The closing bracket
	// accounts for the remaining byte.
	size := 1 + 6*len(a)
	for _, elem := range a {
		size += hex.EncodedLen(len(elem))
	}

	out := make([]byte, 0, size)
	for i, elem := range a {
		lead := byte(',')
		if i == 0 {
			lead = '{'
		}
		out = append(out, lead)
		out = append(out, `"\\x`...)

		// Extend within the preallocated capacity and encode in place.
		start := len(out)
		out = out[:start+hex.EncodedLen(len(elem))]
		hex.Encode(out[start:], elem)

		out = append(out, '"')
	}
	return string(append(out, '}')), nil
}
// Float64Array represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float64Array []float64

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *Float64Array) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
}

// scanBytes parses the text form of a float array with 64-bit precision and
// stores the result in the receiver.
func (a *Float64Array) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(Float64Array, len(elems))
	for i, elem := range elems {
		f, err := strconv.ParseFloat(string(elem), 64)
		if err != nil {
			return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
		}
		parsed[i] = f
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise the elements are
// rendered in decimal notation, comma-separated, inside curly brackets.
func (a Float64Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, at least one byte per
	// value and one delimiter between neighbours.
	out := append(make([]byte, 0, 1+2*len(a)), '{')
	for i, f := range a {
		if i > 0 {
			out = append(out, ',')
		}
		out = strconv.AppendFloat(out, f, 'f', -1, 64)
	}
	return string(append(out, '}')), nil
}
// Float32Array represents a one-dimensional array of single-precision
// floating-point values (the PostgreSQL real type).
type Float32Array []float32

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *Float32Array) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to Float32Array", src)
}

// scanBytes parses the text form of a float array with 32-bit precision and
// stores the result in the receiver.
func (a *Float32Array) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "Float32Array")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(Float32Array, len(elems))
	for i, elem := range elems {
		f, err := strconv.ParseFloat(string(elem), 32)
		if err != nil {
			return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
		}
		parsed[i] = float32(f)
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise the elements are
// rendered in decimal notation, comma-separated, inside curly brackets.
func (a Float32Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, at least one byte per
	// value and one delimiter between neighbours.
	out := append(make([]byte, 0, 1+2*len(a)), '{')
	for i, f := range a {
		if i > 0 {
			out = append(out, ',')
		}
		out = strconv.AppendFloat(out, float64(f), 'f', -1, 32)
	}
	return string(append(out, '}')), nil
}
// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
// an array or slice of any dimension.
type GenericArray struct{ A any }

// evaluateDestination determines how elements of type rt are scanned.
// It returns rt itself, a function that assigns a raw element to a value of
// that type, and the element delimiter.
//
// Only element types whose pointer implements sql.Scanner can be assigned;
// for any other type the returned function always fails. The delimiter is a
// comma unless the zero value of rt implements ArrayDelimiter.
func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
	// TODO calculate the assign function for other types
	// TODO repeat this section on the element type of arrays or slices (multidimensional)
	var assign func([]byte, reflect.Value) error
	if reflect.PtrTo(rt).Implements(typeSQLScanner) {
		// dest is always addressable because it is an element of a slice.
		assign = func(src []byte, dest reflect.Value) error {
			scanner := dest.Addr().Interface().(sql.Scanner)
			if src == nil {
				// Pass an untyped nil rather than a nil []byte.
				return scanner.Scan(nil)
			}
			return scanner.Scan(src)
		}
	} else {
		assign = func([]byte, reflect.Value) error {
			return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
		}
	}

	del := ","
	if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
		del = ad.ArrayDelimiter()
	}
	return rt, assign, del
}
// Scan implements the sql.Scanner interface.
//
// The wrapped value must be a non-nil pointer to an array or slice. The
// source may be the text form of an array as []byte or string. A nil source
// is accepted only for slice destinations, which are reset to their zero
// value.
func (a GenericArray) Scan(src any) error {
	ptr := reflect.ValueOf(a.A)
	if ptr.Kind() != reflect.Ptr {
		return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
	}
	if ptr.IsNil() {
		return fmt.Errorf("pq: destination %T is nil", a.A)
	}

	dest := ptr.Elem()
	kind := dest.Kind()
	if kind != reflect.Slice && kind != reflect.Array {
		return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
	}

	switch raw := src.(type) {
	case string:
		return a.scanBytes([]byte(raw), dest)
	case []byte:
		return a.scanBytes(raw, dest)
	case nil:
		if kind == reflect.Slice {
			dest.Set(reflect.Zero(dest.Type()))
			return nil
		}
	}

	return fmt.Errorf("pq: cannot convert %T to %s", src, dest.Type())
}
// scanBytes parses the text form of an array in src and stores its elements
// in dv, which must be a settable array or slice value.
//
// Only zero- and one-dimensional sources are supported. When dv is a Go
// array its length must equal the number of parsed elements.
func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
	// Resolve the element type, its assignment function and its delimiter.
	dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
	dims, elems, err := parseArray(src, []byte(del))
	if err != nil {
		return err
	}
	// TODO allow multidimensional
	if len(dims) > 1 {
		return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
			strings.Replace(fmt.Sprint(dims), " ", "][", -1))
	}
	// Treat a zero-dimensional array like an array with a single dimension of zero.
	if len(dims) == 0 {
		dims = append(dims, 0)
	}
	// Check each source dimension against the destination type; fixed-size
	// Go arrays must match the dimension exactly.
	for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
		switch rt.Kind() {
		case reflect.Slice:
		case reflect.Array:
			if rt.Len() != dims[i] {
				return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
					strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
			}
		default:
			// TODO handle multidimensional
		}
	}
	// Scan into a scratch slice first so dv is left untouched if any
	// element fails to parse.
	values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
	for i, e := range elems {
		if err := assign(e, values.Index(i)); err != nil {
			return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
		}
	}
	// TODO handle multidimensional
	switch dv.Kind() {
	case reflect.Slice:
		dv.Set(values.Slice(0, dims[0]))
	case reflect.Array:
		// Arrays cannot be replaced wholesale; copy element by element.
		for i := 0; i < dims[0]; i++ {
			dv.Index(i).Set(values.Index(i))
		}
	}
	return nil
}
// Value implements the driver.Valuer interface.
//
// A nil wrapped value or nil slice yields nil, an empty array or slice
// yields "{}", and a wrapped value that is neither an array nor a slice
// yields an error.
func (a GenericArray) Value() (driver.Value, error) {
	if a.A == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(a.A)
	kind := rv.Kind()
	if kind != reflect.Slice && kind != reflect.Array {
		return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
	}
	if kind == reflect.Slice && rv.IsNil() {
		return nil, nil
	}

	n := rv.Len()
	if n == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, at least one byte per
	// value and one delimiter between neighbours.
	buf, _, err := appendArray(make([]byte, 0, 1+2*n), rv, n)
	return string(buf), err
}
// Int64Array represents a one-dimensional array of the PostgreSQL integer types.
type Int64Array []int64

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *Int64Array) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
}

// scanBytes parses the text form of an array of base-10 64-bit integers and
// stores the result in the receiver.
func (a *Int64Array) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(Int64Array, len(elems))
	for i, elem := range elems {
		n, err := strconv.ParseInt(string(elem), 10, 64)
		if err != nil {
			return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
		}
		parsed[i] = n
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise the elements are
// rendered in base 10, comma-separated, inside curly brackets.
func (a Int64Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, at least one byte per
	// value and one delimiter between neighbours.
	out := append(make([]byte, 0, 1+2*len(a)), '{')
	for i, n := range a {
		if i > 0 {
			out = append(out, ',')
		}
		out = strconv.AppendInt(out, n, 10)
	}
	return string(append(out, '}')), nil
}
// Int32Array represents a one-dimensional array of the PostgreSQL integer types.
type Int32Array []int32

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *Int32Array) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to Int32Array", src)
}

// scanBytes parses the text form of an array of base-10 32-bit integers and
// stores the result in the receiver.
func (a *Int32Array) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "Int32Array")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(Int32Array, len(elems))
	for i, elem := range elems {
		n, err := strconv.ParseInt(string(elem), 10, 32)
		if err != nil {
			return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
		}
		parsed[i] = int32(n)
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise the elements are
// rendered in base 10, comma-separated, inside curly brackets.
func (a Int32Array) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, at least one byte per
	// value and one delimiter between neighbours.
	out := append(make([]byte, 0, 1+2*len(a)), '{')
	for i, n := range a {
		if i > 0 {
			out = append(out, ',')
		}
		out = strconv.AppendInt(out, int64(n), 10)
	}
	return string(append(out, '}')), nil
}
// StringArray represents a one-dimensional array of the PostgreSQL character types.
type StringArray []string

// Scan implements the sql.Scanner interface.
// It accepts the text form of an array as []byte or string. A nil source
// sets the receiver to nil.
func (a *StringArray) Scan(src any) error {
	switch raw := src.(type) {
	case nil:
		*a = nil
		return nil
	case string:
		return a.scanBytes([]byte(raw))
	case []byte:
		return a.scanBytes(raw)
	}
	return fmt.Errorf("pq: cannot convert %T to StringArray", src)
}

// scanBytes parses the text form of a string array and stores the result in
// the receiver. Nil elements cannot be represented as a Go string and are
// rejected with an error.
func (a *StringArray) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "StringArray")
	if err != nil {
		return err
	}

	// An empty array scanned into a non-nil receiver keeps its storage.
	if *a != nil && len(elems) == 0 {
		*a = (*a)[:0]
		return nil
	}

	parsed := make(StringArray, len(elems))
	for i, elem := range elems {
		if elem == nil {
			return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
		}
		parsed[i] = string(elem)
	}
	*a = parsed
	return nil
}

// Value implements the driver.Valuer interface.
// A nil array yields nil, an empty one "{}", and otherwise every element is
// double-quoted and escaped, comma-separated, inside curly brackets.
func (a StringArray) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	if len(a) == 0 {
		return "{}", nil
	}

	// The capacity is a lower bound: two brackets, two quotes per element
	// and one delimiter between neighbours.
	out := append(make([]byte, 0, 1+3*len(a)), '{')
	for i, s := range a {
		if i > 0 {
			out = append(out, ',')
		}
		out = appendArrayQuotedBytes(out, []byte(s))
	}
	return string(append(out, '}')), nil
}
// appendArray appends the first n elements of rv to the buffer, wrapped in
// curly brackets, returning the extended buffer and the delimiter used
// between elements.
//
// Element 0 is always accessed, so rv must be a non-empty array or slice and
// n must be positive; otherwise the reflect calls panic. If an element
// fails to convert, the partial buffer and the error are returned.
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
	b = append(b, '{')

	// The delimiter is reported by each element and applies before the next
	// one.
	b, del, err := appendArrayElement(b, rv.Index(0))
	for i := 1; err == nil && i < n; i++ {
		b = append(b, del...)
		b, del, err = appendArrayElement(b, rv.Index(i))
	}
	if err != nil {
		return b, del, err
	}

	return append(b, '}'), del, nil
}
// appendArrayElement appends rv to the buffer, returning the extended buffer
// and the delimiter to use before the next element.
//
// Arrays and slices are appended as nested arrays, except []byte and types
// implementing driver.Valuer, which are treated as single values. An empty
// nested array or slice appends nothing and reports an empty delimiter.
//
// Any other value is converted using driver.DefaultParameterConverter. A nil
// result is written as NULL, a []byte or string result is double-quoted, and
// everything else is passed to appendValue.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
	switch rv.Kind() {
	case reflect.Array, reflect.Slice:
		if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
			if n := rv.Len(); n > 0 {
				return appendArray(b, rv, n)
			}
			return b, "", nil
		}
	}

	del := ","
	elem := rv.Interface()
	if ad, ok := elem.(ArrayDelimiter); ok {
		del = ad.ArrayDelimiter()
	}

	converted, err := driver.DefaultParameterConverter.ConvertValue(elem)
	if err != nil {
		return b, del, err
	}

	switch v := converted.(type) {
	case nil:
		return append(b, "NULL"...), del, nil
	case string:
		return appendArrayQuotedBytes(b, []byte(v)), del, nil
	case []byte:
		return appendArrayQuotedBytes(b, v), del, nil
	default:
		b, err = appendValue(b, converted)
		return b, del, err
	}
}
// appendArrayQuotedBytes appends v to b as a double-quoted array element and
// returns the extended slice. Every double quote and backslash inside v is
// preceded by a backslash.
func appendArrayQuotedBytes(b, v []byte) []byte {
	const special = `"\`

	b = append(b, '"')
	rest := v
	for i := bytes.IndexAny(rest, special); i >= 0; i = bytes.IndexAny(rest, special) {
		// Copy the clean run, then the escaped special character.
		b = append(b, rest[:i]...)
		b = append(b, '\\', rest[i])
		rest = rest[i+1:]
	}
	b = append(b, rest...)
	return append(b, '"')
}
// appendValue appends the result of encode(nil, v, 0) to b and returns the
// extended slice. The returned error is always nil.
func appendValue(b []byte, v driver.Value) ([]byte, error) {
	encoded := encode(nil, v, 0)
	return append(b, encoded...), nil
}
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// src is the array text and del is the element delimiter. On success dims
// holds one entry per nesting level and elems holds every element in the order
// it appears in src, flattened across all dimensions. An unquoted NULL element
// is returned as a nil slice; a quoted element is returned with its backslash
// escapes removed. An empty array yields nil dims and an empty, non-nil elems.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
var depth, i int
if len(src) < 1 || src[0] != '{' {
return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
}
// Open: consume the run of leading '{' characters. Their count is both the
// current nesting depth and the number of dimensions. A '}' found in this run
// means the array is empty.
Open:
for i < len(src) {
switch src[i] {
case '{':
depth++
i++
case '}':
elems = make([][]byte, 0)
goto Close
default:
break Open
}
}
// i equals the number of opening braces consumed above.
dims = make([]int, i)
// Element: parse exactly one element (or descend into nested sub-arrays),
// then fall through to the delimiter/closing-brace handling below.
Element:
for i < len(src) {
switch src[i] {
case '{':
// A brace deeper than the declared number of dimensions is left for the
// loop below, which reports it as an unexpected character.
if depth == len(dims) {
break Element
}
depth++
// Restart the element count for the sub-array being opened.
dims[depth-1] = 0
i++
case '"':
// Quoted element: a backslash makes the following byte literal.
var elem = []byte{}
var escape bool
for i++; i < len(src); i++ {
if escape {
elem = append(elem, src[i])
escape = false
} else {
switch src[i] {
default:
elem = append(elem, src[i])
case '\\':
escape = true
case '"':
elems = append(elems, elem)
i++
break Element
}
}
}
default:
// Unquoted element: runs up to the next delimiter or closing brace.
for start := i; i < len(src); i++ {
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
elem := src[start:i]
if len(elem) == 0 {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
if bytes.Equal(elem, []byte("NULL")) {
elem = nil
}
elems = append(elems, elem)
break Element
}
}
}
}
// After an element: a delimiter counts the element and jumps back to parse
// the next one; a '}' counts it and closes the current dimension.
for i < len(src) {
if bytes.HasPrefix(src[i:], del) && depth > 0 {
dims[depth-1]++
i += len(del)
goto Element
} else if src[i] == '}' && depth > 0 {
dims[depth-1]++
depth--
i++
} else {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
// Close: only closing braces that match still-open dimensions may remain.
Close:
for i < len(src) {
if src[i] == '}' && depth > 0 {
depth--
i++
} else {
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
}
}
if depth > 0 {
err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
}
// The total element count must be divisible by the size of every dimension.
if err == nil {
for _, d := range dims {
if (len(elems) % d) != 0 {
err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
}
}
}
return
}
// scanLinearArray parses src with parseArray and returns its elements. It
// fails when parsing fails or when the array has more than one dimension; typ
// names the destination type in that error message.
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
	dims, parsed, parseErr := parseArray(src, del)
	if parseErr != nil {
		return nil, parseErr
	}

	if len(dims) > 1 {
		// fmt.Sprint renders dims as "[a b]"; turn that into "[a][b]".
		shape := strings.Replace(fmt.Sprint(dims), " ", "][", -1)
		return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", shape, typ)
	}

	return parsed, nil
}
================================================
FILE: gossh/pgsql/buf.go
================================================
package pgsql
import (
"bytes"
"encoding/binary"
"gossh/pgsql/oid"
)
// readBuf is a cursor over the unread bytes of a received message. Every
// accessor decodes a value from the front and advances the buffer past it.
type readBuf []byte

// int32 reads a big-endian signed 32-bit integer.
func (b *readBuf) int32() (n int) {
	raw := *b
	n = int(int32(binary.BigEndian.Uint32(raw)))
	*b = raw[4:]
	return n
}

// oid reads a big-endian 32-bit object identifier.
func (b *readBuf) oid() (n oid.Oid) {
	raw := *b
	n = oid.Oid(binary.BigEndian.Uint32(raw))
	*b = raw[4:]
	return n
}

// N.B: this is actually an unsigned 16-bit integer, unlike int32
func (b *readBuf) int16() (n int) {
	raw := *b
	n = int(binary.BigEndian.Uint16(raw))
	*b = raw[2:]
	return n
}

// string reads a NUL-terminated string and consumes the terminator. A
// missing terminator is reported through errorf.
func (b *readBuf) string() string {
	end := bytes.IndexByte(*b, 0)
	if end < 0 {
		errorf("invalid message format; expected string terminator")
	}
	text := (*b)[:end]
	*b = (*b)[end+1:]
	return string(text)
}

// next returns the first n bytes and advances past them.
func (b *readBuf) next(n int) (v []byte) {
	raw := *b
	v, *b = raw[:n], raw[n:]
	return v
}

// byte reads a single byte.
func (b *readBuf) byte() byte {
	one := b.next(1)
	return one[0]
}
// writeBuf accumulates one or more outgoing messages. pos is the offset in
// buf of the 4-byte length field of the message currently being built.
type writeBuf struct {
	buf []byte
	pos int
}

// int32 appends n as a big-endian 32-bit integer.
func (b *writeBuf) int32(n int) {
	u := uint32(n)
	b.buf = append(b.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u))
}

// int16 appends n as a big-endian 16-bit integer.
func (b *writeBuf) int16(n int) {
	u := uint16(n)
	b.buf = append(b.buf, byte(u>>8), byte(u))
}

// string appends s followed by a NUL terminator.
func (b *writeBuf) string(s string) {
	b.buf = append(b.buf, s...)
	b.buf = append(b.buf, 0)
}

// byte appends a single byte.
func (b *writeBuf) byte(c byte) {
	b.buf = append(b.buf, c)
}

// bytes appends v verbatim.
func (b *writeBuf) bytes(v []byte) {
	b.buf = append(b.buf, v...)
}

// wrap writes the length of the current message (everything from pos to the
// end of buf) into its length field and returns the whole buffer.
func (b *writeBuf) wrap() []byte {
	body := b.buf[b.pos:]
	binary.BigEndian.PutUint32(body, uint32(len(body)))
	return b.buf
}

// next finalizes the length of the current message and starts a new one of
// type c with a zeroed length field.
func (b *writeBuf) next(c byte) {
	b.wrap()
	b.pos = len(b.buf) + 1
	b.buf = append(b.buf, c, 0, 0, 0, 0)
}
================================================
FILE: gossh/pgsql/conn.go
================================================
package pgsql
import (
"bufio"
"bytes"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"unicode"
"gossh/pgsql/oid"
"gossh/pgsql/scram"
)
// Common error types. The exported Err* values can be compared against by
// callers; the unexported ones are used internally by this package.
var (
ErrNotSupported = errors.New("pq: Unsupported command")
// ErrInFailedTransaction is returned by Commit when the transaction is
// already in a failed state and has been rolled back instead.
ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
// ErrSSLNotSupported is returned when the server does not answer the SSL
// request with 'S'.
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
// errUnexpectedReady is reported by simpleExec when ReadyForQuery arrives
// before any result or error was seen.
errUnexpectedReady = errors.New("unexpected ReadyForQuery")
// errNoRowsAffected and errNoLastInsertID are returned by the noRows result.
errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
// Compile time validation that our types implement the expected interfaces
var (
_ driver.Driver = Driver{}
)
// Driver is the Postgres database driver. It carries no state; every call to
// Open creates an independent connection.
type Driver struct{}
// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func (d Driver) Open(name string) (driver.Conn, error) {
return Open(name)
}
// init registers the driver with database/sql under the name "postgres".
func init() {
sql.Register("postgres", &Driver{})
}
// parameterStatus holds session parameters tracked for a connection.
type parameterStatus struct {
// server version in the same format as server_version_num, or 0 if
// unavailable
serverVersion int
// the current location based on the TimeZone value of the session, if
// available
currentLocation *time.Location
}
// transactionStatus is the single-byte transaction state stored on a
// connection.
type transactionStatus byte

const (
	txnStatusIdle                transactionStatus = 'I'
	txnStatusIdleInTransaction   transactionStatus = 'T'
	txnStatusInFailedTransaction transactionStatus = 'E'
)

// String returns a human-readable description of the status. An unknown
// value is reported through errorf.
func (s transactionStatus) String() string {
	if s == txnStatusIdle {
		return "idle"
	}
	if s == txnStatusIdleInTransaction {
		return "idle in transaction"
	}
	if s == txnStatusInFailedTransaction {
		return "in a failed transaction"
	}

	errorf("unknown transactionStatus %d", s)
	panic("not reached")
}
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
//
// Dial is used when no connect_timeout is configured, and DialTimeout when one
// is, unless the dialer also implements DialerContext.
type Dialer interface {
Dial(network, address string) (net.Conn, error)
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
// DialerContext is the context-aware dialer interface. When a Dialer also
// implements it, DialContext is preferred over Dial and DialTimeout.
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// defaultDialer implements Dialer and DialerContext on top of a net.Dialer.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on the named network without a timeout.
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
	return d.d.Dial(network, address)
}

// DialTimeout connects to address, giving up once timeout has elapsed.
func (d defaultDialer) DialTimeout(
	network, address string, timeout time.Duration,
) (net.Conn, error) {
	timeoutCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	return d.d.DialContext(timeoutCtx, network, address)
}

// DialContext connects to address, honouring cancellation of ctx.
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return d.d.DialContext(ctx, network, address)
}
// conn is a single connection to a Postgres server.
type conn struct {
// c is the underlying network connection; buf is a buffered reader over it.
c net.Conn
buf *bufio.Reader
// namei is a counter that gname increments to produce statement names.
namei int
// scratch is a reusable buffer used for outgoing messages (see writeBuf) and
// for message headers and small payloads when receiving (see recvMessage).
scratch [512]byte
// txnStatus is the connection's current transaction status.
txnStatus transactionStatus
// txnFinish, if set, is invoked by closeTxn when Commit or Rollback returns.
txnFinish func()
// Save connection arguments to use during CancelRequest.
dialer Dialer
opts values
// Cancellation key data for use with CancelRequest messages.
processID int
secretKey int
parameterStatus parameterStatus
// saveMessageType and saveMessageBuffer hold a message stored by saveMessage;
// recvMessage hands it back on its next call. A zero type means none is saved.
saveMessageType byte
saveMessageBuffer []byte
// If an error is set, this connection is bad and all public-facing
// functions should return the appropriate error by calling get()
// (ErrBadConn) or getForNext().
err syncErr
// If set, this connection should never use the binary format when
// receiving query results from prepared statements. Only provided for
// debugging.
disablePreparedBinaryResult bool
// Whether to always send []byte parameters over as binary. Enables single
// round-trip mode for non-prepared Query calls.
binaryParameters bool
// If true this connection is in the middle of a COPY
inCopy bool
// If not nil, notices will be synchronously sent here
noticeHandler func(*Error)
// If not nil, notifications will be synchronously sent here
notificationHandler func(*Notification)
// GSSAPI context
gss GSS
}
// syncErr is a mutex-guarded slot holding the first fatal error recorded on
// a connection.
type syncErr struct {
	err error
	sync.Mutex
}

// get reports driver.ErrBadConn once any error has been recorded, and nil
// otherwise.
func (e *syncErr) get() error {
	e.Lock()
	defer e.Unlock()

	if e.err == nil {
		return nil
	}
	return driver.ErrBadConn
}

// getForNext returns the recorded error itself. Currently only used by
// rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	recorded := e.err
	e.Unlock()
	return recorded
}

// set records err unless an error has already been recorded. It panics when
// err is nil.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}

	e.Lock()
	defer e.Unlock()

	if e.err != nil {
		return
	}
	e.err = err
}
// handleDriverSettings applies the driver-side boolean options found in the
// parsed connection string o. Each option accepts only "yes" or "no"; any
// other value is an error. Options that are absent are left untouched.
func (cn *conn) handleDriverSettings(o values) (err error) {
	parseBool := func(key string, dst *bool) error {
		raw, present := o[key]
		if !present {
			return nil
		}
		switch raw {
		case "yes":
			*dst = true
		case "no":
			*dst = false
		default:
			return fmt.Errorf("unrecognized value %q for %s", raw, key)
		}
		return nil
	}

	if err = parseBool("disable_prepared_binary_result", &cn.disablePreparedBinaryResult); err == nil {
		err = parseBool("binary_parameters", &cn.binaryParameters)
	}
	return err
}
// handlePgpass fills in o["password"] from the password file when the
// connection string did not supply one.
//
// The file is taken from $PGPASSFILE, falling back to ".pgpass" in the user's
// home directory. The lookup is best-effort: a missing or unreadable file, or
// one that grants any permission to group or others, is silently ignored.
func (cn *conn) handlePgpass(o values) {
	// if a password was supplied, do not process .pgpass
	if _, ok := o["password"]; ok {
		return
	}
	filename := os.Getenv("PGPASSFILE")
	if filename == "" {
		// XXX this code doesn't work on Windows where the default filename is
		// XXX %APPDATA%\postgresql\pgpass.conf
		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
		userHome := os.Getenv("HOME")
		if userHome == "" {
			u, err := user.Current()
			if err != nil {
				return
			}
			userHome = u.HomeDir
		}
		filename = filepath.Join(userHome, ".pgpass")
	}
	fileinfo, err := os.Stat(filename)
	if err != nil {
		return
	}
	// Reject the file when group or others have any access to it. The mask is
	// octal 0077; the previous hexadecimal 0x77 tested owner-execute and
	// skipped group-execute.
	if fileinfo.Mode().Perm()&0o077 != 0 {
		// XXX should warn about incorrect .pgpass permissions as psql does
		return
	}
	file, err := os.Open(filename)
	if err != nil {
		return
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	// From: https://github.com/tg/pgpass/blob/master/reader.go
	// Stop at the first line that matches this connection.
	for scanner.Scan() {
		if scanText(scanner.Text(), o) {
			break
		}
	}
}
// getFields splits a password-file line into its colon-separated fields.
// A backslash makes the following character literal, so "\:" and "\\" yield
// ":" and "\" inside a field. The result always has at least one field.
func getFields(s string) []string {
	fields := make([]string, 0, 5)
	current := make([]rune, 0, len(s))
	escaped := false

	for _, r := range s {
		if !escaped {
			if r == '\\' {
				escaped = true
				continue
			}
			if r == ':' {
				fields = append(fields, string(current))
				current = current[:0]
				continue
			}
		}
		current = append(current, r)
		escaped = false
	}

	return append(fields, string(current))
}
// scanText checks one password-file line against the connection options o.
// When the line's host, port, database and user fields all match (a "*"
// matches anything), it stores the line's password in o["password"] and
// returns true. Empty lines, comment lines starting with '#', and lines that
// do not have exactly five fields never match.
func scanText(line string, o values) bool {
	hostname := o["host"]
	ntw, _ := network(o)
	port, db, username := o["port"], o["dbname"], o["user"]

	if line == "" || line[0] == '#' {
		return false
	}
	fields := getFields(line)
	if len(fields) != 5 {
		return false
	}

	matches := func(pattern, actual string) bool {
		return pattern == "*" || pattern == actual
	}
	// "localhost" also covers an unset host and unix-socket connections.
	hostOK := matches(fields[0], hostname) ||
		(fields[0] == "localhost" && (hostname == "" || ntw == "unix"))
	if !hostOK || !matches(fields[1], port) || !matches(fields[2], db) || !matches(fields[3], username) {
		return false
	}

	o["password"] = fields[4]
	return true
}
// writeBuf starts a new outgoing message of type b inside the connection's
// scratch buffer: the type byte followed by a 4-byte length placeholder.
func (cn *conn) writeBuf(b byte) *writeBuf {
	cn.scratch[0] = b
	w := new(writeBuf)
	w.buf = cn.scratch[:5]
	w.pos = 1
	return w
}
// Open opens a new connection to the database. dsn is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
//
// It is equivalent to DialOpen with the package's default dialer.
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
}
// DialOpen opens a new connection to the database using a dialer.
//
// dsn is a connection string handed to NewConnector; d replaces the
// connector's dialer. On failure the returned driver.Conn is a true nil
// interface rather than an interface wrapping a nil *conn.
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
	c, err := NewConnector(dsn)
	if err != nil {
		return nil, err
	}
	c.Dialer(d)

	cn, err := c.open(context.Background())
	if err != nil {
		// Returning cn directly would wrap a (possibly nil) *conn in a non-nil
		// interface value.
		return nil, err
	}
	return cn, nil
}
// open establishes a connection using the connector's options: it applies the
// driver-side settings, consults the password file, dials the server, performs
// the optional SSL upgrade and runs the startup sequence.
//
// The network connection is closed on every failure path, including a failure
// to clear the connect_timeout deadline after startup.
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
	// Handle any panics during connection initialization. Note that we
	// specifically do *not* want to use errRecover(), as that would turn any
	// connection errors into ErrBadConns, hiding the real error message from
	// the user.
	defer errRecoverNoErrBadConn(&err)

	// Create a new values map (copy). This makes it so maps in different
	// connections do not reference the same underlying data structure, so it
	// is safe for multiple connections to concurrently write to their opts.
	o := make(values)
	for k, v := range c.opts {
		o[k] = v
	}

	cn = &conn{
		opts:   o,
		dialer: c.dialer,
	}
	err = cn.handleDriverSettings(o)
	if err != nil {
		return nil, err
	}
	cn.handlePgpass(o)

	cn.c, err = dial(ctx, c.dialer, o)
	if err != nil {
		// dial may hand back a connection together with an error; do not leak it.
		if cn.c != nil {
			cn.c.Close()
		}
		return nil, err
	}

	err = cn.ssl(o)
	if err != nil {
		if cn.c != nil {
			cn.c.Close()
		}
		return nil, err
	}

	// cn.startup panics on error. Make sure we don't leak cn.c. The connection
	// is captured in a local because the named result cn is nil by the time
	// deferred functions run on an error return.
	netConn := cn.c
	established := false
	defer func() {
		if !established {
			netConn.Close()
		}
	}()

	cn.buf = bufio.NewReader(cn.c)
	cn.startup(o)

	// reset the deadline, in case one was set (see dial)
	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
		if err = cn.c.SetDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}
	established = true
	return cn, nil
}
// dial opens the network connection described by o using d.
//
// When o["connect_timeout"] is set to a non-zero number of seconds, the dial
// itself is bounded by that duration and the same deadline is installed on the
// returned connection so that it also covers the initial handshake. A dialer
// that implements DialerContext is always called through DialContext.
//
// On error the returned connection is nil; a connection whose deadline cannot
// be set is closed before returning.
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
	network, address := network(o)

	// Zero or not specified means wait indefinitely.
	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
		seconds, err := strconv.ParseInt(timeout, 10, 0)
		if err != nil {
			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
		}
		duration := time.Duration(seconds) * time.Second

		// connect_timeout should apply to the entire connection establishment
		// procedure, so we both use a timeout for the TCP connection
		// establishment and set a deadline for doing the initial handshake.
		// The deadline is then reset after startup() is done.
		deadline := time.Now().Add(duration)
		var conn net.Conn
		if dctx, ok := d.(DialerContext); ok {
			dialCtx, cancel := context.WithTimeout(ctx, duration)
			defer cancel()
			conn, err = dctx.DialContext(dialCtx, network, address)
		} else {
			conn, err = d.DialTimeout(network, address, duration)
		}
		if err != nil {
			return nil, err
		}
		if err = conn.SetDeadline(deadline); err != nil {
			// The caller drops the connection on error, so close it here.
			conn.Close()
			return nil, err
		}
		return conn, nil
	}

	if dctx, ok := d.(DialerContext); ok {
		return dctx.DialContext(ctx, network, address)
	}
	return d.Dial(network, address)
}
// network returns the network type and address to dial for the options o.
// A host starting with "/" is treated as a directory holding a unix socket
// named ".s.PGSQL.<port>"; anything else is dialed over TCP as host:port.
func network(o values) (string, string) {
	host, port := o["host"], o["port"]
	if !strings.HasPrefix(host, "/") {
		return "tcp", net.JoinHostPort(host, port)
	}
	return "unix", path.Join(host, ".s.PGSQL."+port)
}

// values holds connection options keyed by option name.
type values map[string]string
// scanner implements a tokenizer for libpq-style option strings. It walks the
// input one rune at a time.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a new scanner positioned at the start of the option
// string s.
func newScanner(s string) *scanner {
	return &scanner{s: []rune(s)}
}

// Next returns the next rune and advances past it.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
	if s.i < len(s.s) {
		r := s.s[s.i]
		s.i++
		return r, true
	}
	return 0, false
}

// SkipSpaces consumes runes until it finds one that is not whitespace and
// returns it.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
	for {
		r, ok := s.Next()
		if !ok || !unicode.IsSpace(r) {
			return r, ok
		}
	}
}
// parseOpts parses the options from name and adds them to the values.
//
// name is a whitespace-separated list of key=value pairs. Whitespace is
// allowed around the "=". A value is either unquoted (ending at the next
// whitespace) or enclosed in single quotes; in both forms a backslash makes
// the following character literal. A key followed by "=" at the very end of
// the string gets an empty value.
//
// An error is returned when a key is not followed by "=", when an unquoted
// value ends in a backslash, or when a quoted value is not terminated.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// Skip any whitespace if we're not at the = yet
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =
		if r != '=' || !ok {
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =
		if r, ok = s.SkipSpaces(); !ok {
			// If we reach the end here, the last value is just an empty string as per libpq.
			o[string(keyRunes)] = ""
			break
		}

		if r != '\'' {
			// Unquoted value: runs until the next whitespace or end of input.
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`missing character after backslash`)
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
			// Quoted value: runs until the closing single quote.
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return fmt.Errorf(`unterminated quoted string literal in connection string`)
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					r, _ = s.Next()
					fallthrough
				default:
					valRunes = append(valRunes, r)
				}
			}
		}

		o[string(keyRunes)] = string(valRunes)
	}

	return nil
}
// isInTransaction reports whether the connection is inside a transaction
// block, whether healthy or failed.
func (cn *conn) isInTransaction() bool {
	switch cn.txnStatus {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	default:
		return false
	}
}
// checkIsInTransaction verifies that the connection's transaction state
// matches intxn. On a mismatch it marks the connection bad and reports the
// unexpected status through errorf.
func (cn *conn) checkIsInTransaction(intxn bool) {
	if cn.isInTransaction() == intxn {
		return
	}
	cn.err.set(driver.ErrBadConn)
	errorf("unexpected transaction status %v", cn.txnStatus)
}
// Begin starts a transaction by delegating to begin with an empty mode
// string. The returned driver.Tx is the connection itself.
func (cn *conn) Begin() (_ driver.Tx, err error) {
return cn.begin("")
}
// begin starts a transaction by executing "BEGIN" followed by mode. The
// connection must not already be in a transaction. It is marked bad when the
// server answers with an unexpected command tag or when the transaction status
// afterwards is not "idle in transaction". On success the connection itself is
// returned as the driver.Tx.
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
	if badConn := cn.err.get(); badConn != nil {
		return nil, badConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(false)
	_, tag, err := cn.simpleExec("BEGIN" + mode)
	switch {
	case err != nil:
		return nil, err
	case tag != "BEGIN":
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected command tag %s", tag)
	case cn.txnStatus != txnStatusIdleInTransaction:
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
	}
	return cn, nil
}
// closeTxn invokes the connection's txnFinish callback, if one is set.
func (cn *conn) closeTxn() {
	finish := cn.txnFinish
	if finish == nil {
		return
	}
	finish()
}
// Commit commits the current transaction.
//
// closeTxn always runs when Commit returns. If the transaction is already in
// a failed state it is rolled back and ErrInFailedTransaction is returned. The
// connection is marked bad when COMMIT fails while still inside a transaction
// or when the server answers with an unexpected command tag.
func (cn *conn) Commit() (err error) {
defer cn.closeTxn()
if err := cn.err.get(); err != nil {
return err
}
defer cn.errRecover(&err)
cn.checkIsInTransaction(true)
// We don't want the client to think that everything is okay if it tries
// to commit a failed transaction. However, no matter what we return,
// database/sql will release this connection back into the free connection
// pool so we have to abort the current transaction here. Note that you
// would get the same behaviour if you issued a COMMIT in a failed
// transaction, so it's also the least surprising thing to do here.
if cn.txnStatus == txnStatusInFailedTransaction {
if err := cn.rollback(); err != nil {
return err
}
return ErrInFailedTransaction
}
_, commandTag, err := cn.simpleExec("COMMIT")
if err != nil {
// A failed COMMIT that leaves the transaction open makes the connection
// unusable.
if cn.isInTransaction() {
cn.err.set(driver.ErrBadConn)
}
return err
}
if commandTag != "COMMIT" {
cn.err.set(driver.ErrBadConn)
return fmt.Errorf("unexpected command tag %s", commandTag)
}
cn.checkIsInTransaction(false)
return nil
}
// Rollback aborts the current transaction by delegating to rollback.
// closeTxn always runs when Rollback returns, and panics raised while talking
// to the server are handled by errRecover.
func (cn *conn) Rollback() (err error) {
defer cn.closeTxn()
if err := cn.err.get(); err != nil {
return err
}
defer cn.errRecover(&err)
return cn.rollback()
}
// rollback executes ROLLBACK on a connection that must currently be inside a
// transaction.
//
// The connection is marked bad when ROLLBACK fails while the transaction is
// still open, and also when the server answers with an unexpected command tag
// (matching what begin and Commit do in that situation). On success the
// connection must no longer be in a transaction.
func (cn *conn) rollback() (err error) {
	cn.checkIsInTransaction(true)
	_, commandTag, err := cn.simpleExec("ROLLBACK")
	if err != nil {
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if commandTag != "ROLLBACK" {
		// The session state is unknown after an unexpected reply; do not let
		// the connection be reused.
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
// gname returns a fresh name for this connection: the decimal value of a
// counter that is incremented on every call.
func (cn *conn) gname() string {
	cn.namei++
	next := int64(cn.namei)
	return strconv.FormatInt(next, 10)
}
// simpleExec sends q as a simple query ('Q' message) and reads responses
// until ReadyForQuery ('Z') arrives. It returns the result and command tag
// parsed from the last 'C' message, or the error parsed from an 'E' message.
// Row descriptions and data rows are discarded. If ReadyForQuery arrives
// without any result or error, errUnexpectedReady is returned. Send and
// receive failures surface as panics from cn.send and cn.recv1.
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
for {
t, r := cn.recv1()
switch t {
case 'C':
res, commandTag = cn.parseComplete(r.string())
case 'Z':
cn.processReadyForQuery(r)
if res == nil && err == nil {
err = errUnexpectedReady
}
// done
return
case 'E':
// Remember the error but keep reading until ReadyForQuery.
err = parseError(r)
case 'I':
// Response to an empty query string.
res = emptyRows
case 'T', 'D':
// ignore any results
default:
cn.err.set(driver.ErrBadConn)
errorf("unknown response for simple query: %q", t)
}
}
}
// simpleQuery sends q as a simple query ('Q' message) and returns a rows
// object for reading its results.
//
// It returns as soon as the first data row arrives (saving that row for the
// next receive), when a 'C' message arrives for a query that had a row
// description, or when ReadyForQuery ('Z') arrives. An 'E' message discards
// any result and is returned as err once ReadyForQuery is seen. Unexpected
// message sequences mark the connection bad.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
defer cn.errRecover(&err)
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
for {
t, r := cn.recv1()
switch t {
case 'C', 'I':
// We allow queries which don't return any results through Query as
// well as Exec. We still have to give database/sql a rows object
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.err.set(driver.ErrBadConn)
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
res = &rows{
cn: cn,
}
}
// Set the result and tag to the last command complete if there wasn't a
// query already run. Although queries usually return from here and cede
// control to Next, a query with zero results does not.
if t == 'C' {
res.result, res.tag = cn.parseComplete(r.string())
if res.colNames != nil {
return
}
}
res.done = true
case 'Z':
cn.processReadyForQuery(r)
// done
return
case 'E':
res = nil
err = parseError(r)
case 'D':
// A data row is only valid after a row description created res.
if res == nil {
cn.err.set(driver.ErrBadConn)
errorf("unexpected DataRow in simple query execution")
}
// the query didn't fail; kick off to Next
cn.saveMessage(t, r)
return
case 'T':
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.rowsHeader = parsePortalRowDescribe(r)
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
default:
cn.err.set(driver.ErrBadConn)
errorf("unknown response for simple query: %q", t)
}
}
}
// noRows is the driver.Result used for an empty statement; it has neither a
// last insert ID nor a rows-affected count.
type noRows struct{}
// emptyRows is the shared noRows value returned by simpleExec.
var emptyRows noRows
var _ driver.Result = noRows{}
// LastInsertId always fails with errNoLastInsertID.
func (noRows) LastInsertId() (int64, error) {
return 0, errNoLastInsertID
}
// RowsAffected always fails with errNoRowsAffected.
func (noRows) RowsAffected() (int64, error) {
return 0, errNoRowsAffected
}
// decideColumnFormats chooses the result format (text or binary) for each
// column of a prepared statement. colTyps describes the result columns, one
// entry per column; forceText requests text for every column.
//
// It returns the per-column formats and the encoded format-code block to
// send: a count followed by one 16-bit code per column, or one of the shared
// all-text / all-binary blocks when every column uses the same format. With no
// columns the returned format slice is nil.
func decideColumnFormats(
	colTyps []fieldDesc, forceText bool,
) (colFmts []format, colFmtData []byte) {
	if len(colTyps) == 0 {
		return nil, colFmtDataAllText
	}

	colFmts = make([]format, len(colTyps))
	if forceText {
		return colFmts, colFmtDataAllText
	}

	var sawBinary, sawText bool
	for idx := range colTyps {
		switch colTyps[idx].OID {
		// This is the list of types to use binary mode for when receiving them
		// through a prepared statement. If a type appears in this list, it
		// must also be implemented in binaryDecode in encode.go.
		case oid.T_bytea, oid.T_int8, oid.T_int4, oid.T_int2, oid.T_uuid:
			colFmts[idx] = formatBinary
			sawBinary = true
		default:
			sawText = true
		}
	}

	switch {
	case !sawText:
		return colFmts, colFmtDataAllBinary
	case !sawBinary:
		return colFmts, colFmtDataAllText
	}

	// Mixed formats: emit the column count followed by one code per column.
	colFmtData = make([]byte, 2+len(colFmts)*2)
	binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
	offset := 2
	for _, f := range colFmts {
		binary.BigEndian.PutUint16(colFmtData[offset:], uint16(f))
		offset += 2
	}
	return colFmts, colFmtData
}
// prepareTo prepares query q on the server under the statement name stmtName
// and returns the resulting stmt, filled in with its parameter types, column
// names and types, and the chosen result formats.
//
// The 'P', 'D' and 'S' messages are built in one buffer and sent with a single
// write. Failures surface as panics from the send and read helpers.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
st := &stmt{cn: cn, name: stmtName}
// 'P': statement name, query text, and zero pre-specified parameter types.
b := cn.writeBuf('P')
b.string(st.name)
b.string(q)
b.int16(0)
// 'D' with 'S': describe the statement just named.
b.next('D')
b.byte('S')
b.string(st.name)
// 'S': no payload.
b.next('S')
cn.send(b)
cn.readParseResponse()
st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
cn.readReadyForQuery()
return st
}
// Prepare prepares the query q. A query whose first four characters are
// "COPY" (in any letter case) is handed to prepareCopyIn, and on success the
// connection is flagged as being in a COPY. Every other query is prepared
// under a freshly generated statement name.
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
	if badConn := cn.err.get(); badConn != nil {
		return nil, badConn
	}
	defer cn.errRecover(&err)

	isCopy := len(q) >= 4 && strings.EqualFold(q[:4], "COPY")
	if !isCopy {
		return cn.prepareTo(q, cn.gname()), nil
	}

	copyStmt, copyErr := cn.prepareCopyIn(q)
	if copyErr == nil {
		cn.inCopy = true
	}
	return copyStmt, copyErr
}
// Close sends an 'X' message to the server and closes the underlying network
// connection. The network connection is closed even if sending fails; the
// send error takes precedence over the close error in the returned value.
func (cn *conn) Close() (err error) {
// Skip cn.bad return here because we always want to close a connection.
defer cn.errRecover(&err)
// Ensure that cn.c.Close is always run. Since error handling is done with
// panics and cn.errRecover, the Close must be in a defer.
defer func() {
cerr := cn.c.Close()
if err == nil {
err = cerr
}
}()
// Don't go through send(); ListenerConn relies on us not scribbling on the
// scratch buffer of this connection.
return cn.sendSimpleMessage('X')
}
// Query implements the "Queryer" interface by delegating to query.
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return cn.query(query, args)
}
// query runs a query and returns its rows.
//
// It refuses to run on a bad connection or while a COPY is in progress.
// Without arguments the simple-query path is used. With arguments it either
// sends everything in one round trip (when binaryParameters is enabled) or
// prepares the unnamed statement and executes it.
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
if err := cn.err.get(); err != nil {
return nil, err
}
if cn.inCopy {
return nil, errCopyInProgress
}
defer cn.errRecover(&err)
// Check to see if we can use the "simpleQuery" interface, which is
// *much* faster than going through prepare/exec
if len(args) == 0 {
return cn.simpleQuery(query)
}
if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)
cn.readParseResponse()
cn.readBindResponse()
rows := &rows{cn: cn}
rows.rowsHeader = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
}
// An empty statement name selects the unnamed statement.
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
rowsHeader: st.rowsHeader,
}, nil
}
// Exec implements the optional "Execer" interface for one-shot queries.
//
// Without arguments the simple-query path is used. With arguments it either
// sends everything in one round trip (when binaryParameters is enabled) or
// prepares the unnamed statement and executes it. An error from the prepared
// statement is raised as a panic and turned back into an error by errRecover.
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
if err := cn.err.get(); err != nil {
return nil, err
}
defer cn.errRecover(&err)
// Check to see if we can use the "simpleExec" interface, which is
// *much* faster than going through prepare/exec
if len(args) == 0 {
// ignore commandTag, our caller doesn't care
r, _, err := cn.simpleExec(query)
return r, err
}
if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)
cn.readParseResponse()
cn.readBindResponse()
cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
res, _, err = cn.readExecuteResponse("Execute")
return res, err
}
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st := cn.prepareTo(query, "")
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
}
// safeRetryError wraps a write error that occurred before any byte of the
// message was written (see conn.send).
type safeRetryError struct {
	Err error
}

// Error returns the message of the wrapped error.
func (se *safeRetryError) Error() string {
	wrapped := se.Err
	return wrapped.Error()
}
// send finalizes the message in m and writes it to the server. A write error
// is raised as a panic; when nothing at all was written, the error is first
// wrapped in a safeRetryError.
func (cn *conn) send(m *writeBuf) {
	written, err := cn.c.Write(m.wrap())
	if err == nil {
		return
	}
	if written == 0 {
		err = &safeRetryError{Err: err}
	}
	panic(err)
}
// sendStartupPacket finalizes m and writes it without its leading type byte.
func (cn *conn) sendStartupPacket(m *writeBuf) error {
	packet := m.wrap()
	_, err := cn.c.Write(packet[1:])
	return err
}
// sendSimpleMessage writes a payload-free message of type typ to the server:
// the type byte followed by the length 4 (the size of the length field
// itself). It does not touch the connection's scratch buffer.
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
	msg := []byte{typ, 0, 0, 0, 4}
	_, err = cn.c.Write(msg)
	return err
}
// saveMessage memorizes a message and its buffer in the conn struct.
// recvMessage will then return these values on the next call to it. This
// method is useful in cases where you have to see what the next message is
// going to be (e.g. to see whether it's an error or not) but you can't handle
// the message yourself.
//
// Only one message can be saved at a time; saving a second one before the
// first has been consumed marks the connection bad.
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
if cn.saveMessageType != 0 {
cn.err.set(driver.ErrBadConn)
errorf("unexpected saveMessageType %d", cn.saveMessageType)
}
cn.saveMessageType = typ
cn.saveMessageBuffer = *buf
}
// recvMessage receives any message from the backend, or returns an error if
// a problem occurred while reading the message.
//
// A message previously stored with saveMessage is returned first. Otherwise
// the 5-byte header (type and length) is read, followed by the payload, and
// *r is set to the payload. Payloads that fit are read into the connection's
// scratch buffer, so *r is only valid until the next use of that buffer.
//
// A declared length smaller than the 4 bytes of the length field itself is
// rejected and marks the connection bad.
func (cn *conn) recvMessage(r *readBuf) (byte, error) {
	// workaround for a QueryRow bug, see exec
	if cn.saveMessageType != 0 {
		t := cn.saveMessageType
		*r = cn.saveMessageBuffer
		cn.saveMessageType = 0
		cn.saveMessageBuffer = nil
		return t, nil
	}

	x := cn.scratch[:5]
	_, err := io.ReadFull(cn.buf, x)
	if err != nil {
		return 0, err
	}

	// read the type and length of the message that follows
	t := x[0]
	n := int(binary.BigEndian.Uint32(x[1:])) - 4
	if n < 0 {
		// The length includes its own 4 bytes, so anything smaller is a
		// malformed header. Slicing with a negative size would panic, and the
		// stream position is unknown from here on.
		cn.err.set(driver.ErrBadConn)
		return 0, fmt.Errorf("pq: invalid message length %d", n+4)
	}
	var y []byte
	if n <= len(cn.scratch) {
		y = cn.scratch[:n]
	} else {
		y = make([]byte, n)
	}
	_, err = io.ReadFull(cn.buf, y)
	if err != nil {
		return 0, err
	}
	*r = y
	return t, nil
}
// recv returns the next message from the backend that is not a notice ('N')
// or a notification ('A'); those are passed to the connection's handlers,
// when set, and otherwise dropped. A read failure or an error message ('E')
// from the server is raised as a panic. This function should generally be
// used only during the startup sequence.
func (cn *conn) recv() (t byte, r *readBuf) {
	for {
		r = &readBuf{}

		var err error
		if t, err = cn.recvMessage(r); err != nil {
			panic(err)
		}

		switch t {
		case 'A':
			if handle := cn.notificationHandler; handle != nil {
				handle(recvNotification(r))
			}
		case 'N':
			if handle := cn.noticeHandler; handle != nil {
				handle(parseError(r))
			}
		case 'E':
			panic(parseError(r))
		default:
			return t, r
		}
	}
}
// recv1Buf behaves like recv1 but reads into the caller-supplied buffer r,
// which avoids an allocation. Notifications ('A'), notices ('N') and
// parameter-status messages ('S') are processed here and never returned; the
// type of the first other message is returned. A read failure is raised as a
// panic.
func (cn *conn) recv1Buf(r *readBuf) byte {
	for {
		t, err := cn.recvMessage(r)
		if err != nil {
			panic(err)
		}

		switch t {
		case 'S':
			cn.processParameterStatus(r)
		case 'N':
			if handle := cn.noticeHandler; handle != nil {
				handle(parseError(r))
			}
		case 'A':
			if handle := cn.notificationHandler; handle != nil {
				handle(recvNotification(r))
			}
		default:
			return t
		}
	}
}
// recv1 receives a message from the backend into a freshly allocated buffer,
// panicking if an error occurs while attempting to read it. Asynchronous
// messages are handled inside recv1Buf and never returned; an ErrorResponse
// is returned to the caller like any other message.
func (cn *conn) recv1() (t byte, r *readBuf) {
	buf := &readBuf{}
	return cn.recv1Buf(buf), buf
}
// ssl negotiates an encrypted connection when the options o ask for one.
// If the package-level ssl function yields no upgrade function there is
// nothing to do. Otherwise a request packet is sent, and the server must
// answer with a single 'S' byte before the connection is replaced by the
// upgraded one; any other answer yields ErrSSLNotSupported.
func (cn *conn) ssl(o values) error {
	upgrade, err := ssl(o)
	switch {
	case err != nil:
		return err
	case upgrade == nil:
		// Nothing to do
		return nil
	}

	req := cn.writeBuf(0)
	req.int32(80877103)
	if err = cn.sendStartupPacket(req); err != nil {
		return err
	}

	// The reply is a single byte read straight from the raw connection.
	reply := cn.scratch[:1]
	if _, err = io.ReadFull(cn.c, reply); err != nil {
		return err
	}
	if reply[0] != 'S' {
		return ErrSSLNotSupported
	}

	cn.c, err = upgrade(cn.c)
	return err
}
// isDriverSetting reports whether key is an option that only configures the
// driver itself and therefore must not be sent to the server in the
// connection startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case
		// where and how to connect
		"host", "port",
		"password",
		"connect_timeout",
		// SSL
		"sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni",
		// Kerberos
		"krbsrvname", "krbspn",
		// driver behaviour
		"fallback_application_name",
		"disable_prepared_binary_result",
		"binary_parameters":
		return true
	}
	return false
}
// startup sends the startup packet built from the options o and processes the
// server's responses until ReadyForQuery ('Z') arrives. Authentication
// requests ('R') are handled by auth. Any failure is raised as a panic.
func (cn *conn) startup(o values) {
w := cn.writeBuf(0)
// 196608 is 3<<16 — presumably protocol version 3.0; confirm against the
// PostgreSQL protocol documentation.
w.int32(196608)
// Send the backend the name of the database we want to connect to, and the
// user we want to connect as. Additionally, we send over any run-time
// parameters potentially included in the connection string. If the server
// doesn't recognize any of them, it will reply with an error.
for k, v := range o {
if isDriverSetting(k) {
// skip options which can't be run-time parameters
continue
}
// The protocol requires us to supply the database name as "database"
// instead of "dbname".
if k == "dbname" {
k = "database"
}
w.string(k)
w.string(v)
}
// An empty string terminates the parameter list.
w.string("")
if err := cn.sendStartupPacket(w); err != nil {
panic(err)
}
for {
t, r := cn.recv()
switch t {
case 'K':
cn.processBackendKeyData(r)
case 'S':
cn.processParameterStatus(r)
case 'R':
cn.auth(r, o)
case 'Z':
cn.processReadyForQuery(r)
return
default:
errorf("unknown response for startup: %q", t)
}
}
}
// auth handles one authentication request from the server. r holds the
// request payload, which starts with a 32-bit code selecting the method, and
// o supplies the credentials ("user", "password", and the Kerberos options).
//
// Handled codes: 0 (done), 3 (password sent as-is), 5 (MD5-hashed password
// with a 4-byte salt), 7 and 8 (GSSAPI start and continue) and 10
// (SCRAM-SHA-256). Any failure or unknown code is reported through errorf.
func (cn *conn) auth(r *readBuf, o values) {
switch code := r.int32(); code {
case 0:
// OK
case 3:
// Send the password unmodified and expect a code 0 response.
w := cn.writeBuf('p')
w.string(o["password"])
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 5:
// The next 4 bytes of the request are the salt:
// "md5" + md5(md5(password + user) + salt).
s := string(r.next(4))
w := cn.writeBuf('p')
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 7: // GSSAPI, startup
if newGss == nil {
errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
}
cli, err := newGss()
if err != nil {
errorf("kerberos error: %s", err.Error())
}
var token []byte
if spn, ok := o["krbspn"]; ok {
// Use the supplied SPN if provided..
token, err = cli.GetInitTokenFromSpn(spn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if val, ok := o["krbsrvname"]; ok {
service = val
}
token, err = cli.GetInitToken(o["host"], service)
}
if err != nil {
errorf("failed to get Kerberos ticket: %q", err)
}
w := cn.writeBuf('p')
w.bytes(token)
cn.send(w)
// Store for GSSAPI continue message
cn.gss = cli
case 8: // GSSAPI continue
if cn.gss == nil {
errorf("GSSAPI protocol error")
}
// The rest of the request payload is the server's token.
b := []byte(*r)
done, tokOut, err := cn.gss.Continue(b)
if err == nil && !done {
w := cn.writeBuf('p')
w.bytes(tokOut)
cn.send(w)
}
// Errors fall through and read the more detailed message
// from the server..
case 10:
// SCRAM-SHA-256: three client steps, with server responses carrying
// codes 11 and 12 in between.
sc := scram.NewClient(sha256.New, o["user"], o["password"])
sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut := sc.Out()
// First message: mechanism name, length of the client data, client data.
w := cn.writeBuf('p')
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 11 {
errorf("unexpected authentication response: %q", t)
}
// Feed the remainder of the server response to the SCRAM client.
nextStep := r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut = sc.Out()
w = cn.writeBuf('p')
w.bytes(scOut)
cn.send(w)
t, r = cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 12 {
errorf("unexpected authentication response: %q", t)
}
nextStep = r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
default:
errorf("unknown authentication response: %d", code)
}
}
// format is a wire format code for a parameter or result column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

var (
	// One result-column format code with the value 1 (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}
	// No result-column format codes (i.e. all text).
	colFmtDataAllText = []byte{0, 0}
)
// stmt is a prepared statement bound to a connection.
type stmt struct {
// cn is the connection the statement was prepared on.
cn *conn
// name is the statement name sent to the server in Bind and Close messages.
name string
// rowsHeader describes the result columns; it is copied into every rows
// value produced by query.
rowsHeader
// colFmtData is written verbatim after the parameter values in Bind; see
// colFmtDataAllBinary and colFmtDataAllText for the two layouts used.
colFmtData []byte
// paramTyps holds one type OID per statement parameter; its length is the
// number of inputs the statement requires.
paramTyps []oid.Oid
// closed is set once the server has acknowledged the Close message.
closed bool
}
// Close releases the prepared statement on the server. Closing an already
// closed statement is a no-op. If the connection has a recorded error, that
// error is returned without touching the network. An unexpected reply marks
// the connection as bad and is reported through errorf, which the deferred
// errRecover converts into the returned error.
func (st *stmt) Close() (err error) {
	if st.closed {
		return nil
	}
	cn := st.cn
	if connErr := cn.err.get(); connErr != nil {
		return connErr
	}
	defer cn.errRecover(&err)

	// Close ('C') of a statement ('S') with the given name, followed by Sync.
	closeMsg := cn.writeBuf('C')
	closeMsg.byte('S')
	closeMsg.string(st.name)
	cn.send(closeMsg)
	cn.send(cn.writeBuf('S'))

	// The server must acknowledge with '3' before the statement counts as
	// closed.
	if ack, _ := cn.recv1(); ack != '3' {
		cn.err.set(driver.ErrBadConn)
		errorf("unexpected close response: %q", ack)
	}
	st.closed = true

	msgType, msg := cn.recv1()
	if msgType != 'Z' {
		cn.err.set(driver.ErrBadConn)
		errorf("expected ready for query, but got: %q", msgType)
	}
	cn.processReadyForQuery(msg)
	return nil
}
// Query executes the prepared statement with the given arguments and returns
// the resulting rows.
//
// On failure the returned driver.Rows is a literal nil. Returning the result
// of st.query directly would wrap its nil *rows in a non-nil interface value,
// so callers comparing the rows against nil would be misled.
func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
	rs, err := st.query(v)
	if err != nil {
		return nil, err
	}
	return rs, nil
}
// query binds and executes the statement with v and returns a rows value
// that reads the results from the statement's connection. A previously
// recorded connection error is returned immediately; panics raised while
// executing are converted into err by the deferred errRecover.
func (st *stmt) query(v []driver.Value) (r *rows, err error) {
	cn := st.cn
	if connErr := cn.err.get(); connErr != nil {
		return nil, connErr
	}
	defer cn.errRecover(&err)

	st.exec(v)

	result := &rows{
		cn:         st.cn,
		rowsHeader: st.rowsHeader,
	}
	return result, nil
}
// Exec binds and executes the statement with v, discards any returned rows
// and reports the command result. A previously recorded connection error is
// returned immediately; panics raised while executing are converted into err
// by the deferred errRecover.
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
	cn := st.cn
	if connErr := cn.err.get(); connErr != nil {
		return nil, connErr
	}
	defer cn.errRecover(&err)

	st.exec(v)
	res, _, err = cn.readExecuteResponse("simple query")
	return res, err
}
// exec sends Bind, Execute and Sync for the statement in a single write and
// then waits for the bind acknowledgement and the first execution message.
// The argument count must match the statement's parameter count and stay
// below 65536; violations are reported through errorf.
func (st *stmt) exec(v []driver.Value) {
	argc := len(v)
	if argc >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
	}
	if want := len(st.paramTyps); argc != want {
		errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
	}

	cn := st.cn
	msg := cn.writeBuf('B')
	msg.byte(0) // unnamed portal
	msg.string(st.name)

	if !cn.binaryParameters {
		// No parameter format codes, then each value in text encoding.
		msg.int16(0)
		msg.int16(len(v))
		for i := range v {
			if v[i] == nil {
				// A length of -1 denotes NULL.
				msg.int32(-1)
				continue
			}
			encoded := encode(&cn.parameterStatus, v[i], st.paramTyps[i])
			msg.int32(len(encoded))
			msg.bytes(encoded)
		}
	} else {
		cn.sendBinaryParameters(msg, v)
	}
	msg.bytes(st.colFmtData)

	// Execute the unnamed portal with no row limit, then Sync.
	msg.next('E')
	msg.byte(0)
	msg.int32(0)
	msg.next('S')
	cn.send(msg)

	cn.readBindResponse()
	cn.postExecuteWorkaround()
}
// NumInput reports how many parameters the prepared statement takes, which
// is the number of parameter types recorded for it.
func (st *stmt) NumInput() int {
	count := len(st.paramTyps)
	return count
}
// parseComplete splits the "command tag" of a CommandComplete message into
// the number of affected rows (zero when the tag carries none) and the bare
// command name, e.g. "ALTER TABLE". A tag that cannot be parsed marks the
// connection as bad and is reported through errorf.
func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
	// Commands whose tag is "<NAME> <rows>". INSERT is handled separately.
	prefixes := [...]string{
		"SELECT ",
		"UPDATE ",
		"DELETE ",
		"FETCH ",
		"MOVE ",
		"COPY ",
	}

	var (
		countText string
		hasCount  bool
	)
	for _, prefix := range prefixes {
		if !strings.HasPrefix(commandTag, prefix) {
			continue
		}
		countText = commandTag[len(prefix):]
		commandTag = prefix[:len(prefix)-1]
		hasCount = true
		break
	}

	// An INSERT tag is "INSERT <oid> <rows>". The oid is only meaningful for
	// single-row inserts into tables with (deprecated) oids, so it is ignored
	// and only the trailing row count is kept.
	if !hasCount && strings.HasPrefix(commandTag, "INSERT ") {
		fields := strings.Split(commandTag, " ")
		if len(fields) != 3 {
			cn.err.set(driver.ErrBadConn)
			errorf("unexpected INSERT command tag %s", commandTag)
		}
		countText = fields[len(fields)-1]
		commandTag = "INSERT"
		hasCount = true
	}

	// No row count attached to this tag: return it untouched.
	if !hasCount {
		return driver.RowsAffected(0), commandTag
	}

	n, err := strconv.ParseInt(countText, 10, 64)
	if err != nil {
		cn.err.set(driver.ErrBadConn)
		errorf("could not parse commandTag: %s", err)
	}
	return driver.RowsAffected(n), commandTag
}
// rowsHeader describes the columns of a result set. The three slices are
// indexed by column position.
type rowsHeader struct {
// colNames holds the column names.
colNames []string
// colTyps holds the type description (OID, length, modifier) per column.
colTyps []fieldDesc
// colFmts holds the wire format code per column.
colFmts []format
}
// rows iterates over the results of a query on a connection.
type rows struct {
// cn is the connection the results are read from.
cn *conn
// finish, when non-nil, is run (deferred) by Close.
finish func()
// rowsHeader describes the columns of the current result set.
rowsHeader
// done is set once ReadyForQuery has been received; Next then returns
// io.EOF without reading from the connection.
done bool
// rb is the buffer each incoming message is read into by Next.
rb readBuf
// result and tag are filled from the last CommandComplete message seen.
result driver.Result
tag string
// next holds the header of a following result set announced by a row
// description message, until NextResultSet installs it.
next *rowsHeader
}
// Close drains all remaining messages of the query so the connection can be
// reused, then runs the finish callback if one is set. It returns the first
// error other than io.EOF reported by Next.
func (rs *rows) Close() error {
	if rs.finish != nil {
		defer rs.finish()
	}
	// The connection's error state is not checked here because Next does it.
	for {
		err := rs.Next(nil)
		if err == nil {
			continue
		}
		if err != io.EOF {
			return err
		}
		// Next reports io.EOF both for 'Z' (ready for query) and for 'T' (row
		// description of a further result set). Only 'Z' sets done, so keep
		// reading until that happens.
		if rs.done {
			return nil
		}
	}
}
// Columns returns the column names of the current result set.
func (rs *rows) Columns() []string {
	names := rs.colNames
	return names
}
// Result returns the result recorded for the query, or emptyRows when none
// has been recorded.
func (rs *rows) Result() driver.Result {
	if rs.result != nil {
		return rs.result
	}
	return emptyRows
}
// Tag returns the command tag recorded for the query; it is empty until a
// CommandComplete message has been processed.
func (rs *rows) Tag() string {
	tag := rs.tag
	return tag
}
// Next reads messages from the connection until it can deliver one data row
// into dest or the current result set ends.
//
// It returns nil after filling dest from a DataRow, io.EOF when the result
// set is finished (either because ReadyForQuery arrived or because another
// result set follows), or the server error that was reported for the query.
// Unexpected messages are reported through errorf and converted into err by
// the deferred errRecover.
func (rs *rows) Next(dest []driver.Value) (err error) {
if rs.done {
return io.EOF
}
conn := rs.cn
if err := conn.err.getForNext(); err != nil {
return err
}
defer conn.errRecover(&err)
for {
t := conn.recv1Buf(&rs.rb)
switch t {
case 'E':
// Remember the server error; it is returned once 'Z' arrives.
err = parseError(&rs.rb)
case 'C', 'I':
// Only 'C' carries a command tag to record; 'I' is skipped.
if t == 'C' {
rs.result, rs.tag = conn.parseComplete(rs.rb.string())
}
continue
case 'Z':
conn.processReadyForQuery(&rs.rb)
rs.done = true
if err != nil {
return err
}
return io.EOF
case 'D':
// n is the number of column values in this row.
n := rs.rb.int16()
if err != nil {
conn.err.set(driver.ErrBadConn)
errorf("unexpected DataRow after error %s", err)
}
// Never read more values than the row contains.
if n < len(dest) {
dest = dest[:n]
}
for i := range dest {
l := rs.rb.int32()
// A length of -1 denotes NULL.
if l == -1 {
dest[i] = nil
continue
}
dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
}
return
case 'T':
// A further result set follows; keep its header for NextResultSet.
next := parsePortalRowDescribe(&rs.rb)
rs.next = &next
return io.EOF
default:
errorf("unexpected message after execute: %q", t)
}
}
}
// HasNextResultSet reports whether a further result set header is pending
// and the query has not yet reached ReadyForQuery.
func (rs *rows) HasNextResultSet() bool {
	if rs.done {
		return false
	}
	return rs.next != nil
}
// NextResultSet installs the pending result set header as the current one.
// It returns io.EOF when no further result set has been announced.
func (rs *rows) NextResultSet() error {
	pending := rs.next
	if pending == nil {
		return io.EOF
	}
	rs.next = nil
	rs.rowsHeader = *pending
	return nil
}
// QuoteIdentifier returns name wrapped in double quotes so it can be used as
// an identifier (for instance a table or column name) inside an SQL
// statement:
//
//	tblname := "my_table"
//	data := "my_data"
//	quoted := pq.QuoteIdentifier(tblname)
//	_, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Double quotes inside name are doubled. A quoted identifier is case
// sensitive when used in a query. If name contains a zero byte, everything
// from that byte onwards is dropped.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.ReplaceAll(name, `"`, `""`)
	return `"` + escaped + `"`
}
// BufferQuoteIdentifier does the same job as QuoteIdentifier but appends the
// quoted identifier to buffer instead of returning a new string.
func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	buffer.WriteByte('"')
	buffer.WriteString(strings.ReplaceAll(name, `"`, `""`))
	buffer.WriteByte('"')
}
// QuoteLiteral returns literal quoted for use as a string constant inside an
// SQL statement. It is meant for DDL and other statements that do not accept
// parameters:
//
//	exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//	_, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Single quotes are doubled. If the literal contains a backslash, every
// backslash is doubled as well and the result uses PostgreSQL's C-style
// escape syntax, i.e. it is prefixed with " E" (including the leading space).
//
// This mirrors the algorithm of libpq's "PQEscapeStringInternal" in
// src/interfaces/libpq/fe-exec.c:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
func QuoteLiteral(literal string) string {
	escaped := strings.ReplaceAll(literal, `'`, `''`)
	if !strings.Contains(escaped, `\`) {
		// No backslashes: plain single-quoted literal.
		return `'` + escaped + `'`
	}
	// Backslashes present: double them and switch to the E'' form, with a
	// space in front of the "E" as libpq does.
	return ` E'` + strings.ReplaceAll(escaped, `\`, `\\`) + `'`
}
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	digest := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", digest[:])
}
// sendBinaryParameters appends the parameter section of a Bind message to b:
// the parameter format codes, the parameter count, and then each value
// prefixed by its length (-1 for NULL). Arguments of type []byte are flagged
// with format code 1; if no argument is a []byte, no format codes are
// written at all.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	// The format array is only allocated once the first []byte argument is
	// found, so it stays nil (length 0) when every argument uses format 0.
	var formats []int
	for i := range args {
		if _, isBytes := args[i].([]byte); !isBytes {
			continue
		}
		if formats == nil {
			formats = make([]int, len(args))
		}
		formats[i] = 1
	}

	b.int16(len(formats))
	for _, code := range formats {
		b.int16(code)
	}

	b.int16(len(args))
	for _, arg := range args {
		if arg == nil {
			b.int32(-1)
			continue
		}
		datum := binaryEncode(&cn.parameterStatus, arg)
		b.int32(len(datum))
		b.bytes(datum)
	}
}
// sendBinaryModeQuery sends query with args as one batch of extended-protocol
// messages: Parse, Bind, Describe (portal), Execute and Sync, all using the
// unnamed statement and portal and requesting text-format results. More than
// 65535 arguments are rejected through errorf.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if len(args) >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
	}

	// Parse
	msg := cn.writeBuf('P')
	msg.byte(0) // unnamed statement
	msg.string(query)
	msg.int16(0)

	// Bind
	msg.next('B')
	msg.int16(0) // unnamed portal and statement
	cn.sendBinaryParameters(msg, args)
	msg.bytes(colFmtDataAllText)

	// Describe
	msg.next('D')
	msg.byte('P')
	msg.byte(0) // unnamed portal

	// Execute
	msg.next('E')
	msg.byte(0)
	msg.int32(0)

	// Sync
	msg.next('S')
	cn.send(msg)
}
// processParameterStatus records the run-time parameters the driver cares
// about from a ParameterStatus message. "server_version" is stored as
// major*10000 + minor*100 when it parses as "<int>.<int>"; "TimeZone" is
// stored as a *time.Location, or nil when the zone cannot be loaded. All
// other parameters are ignored.
func (cn *conn) processParameterStatus(r *readBuf) {
	switch name := r.string(); name {
	case "server_version":
		var major, minor int
		if _, err := fmt.Sscanf(r.string(), "%d.%d", &major, &minor); err == nil {
			cn.parameterStatus.serverVersion = major*10000 + minor*100
		}
	case "TimeZone":
		loc, err := time.LoadLocation(r.string())
		if err != nil {
			loc = nil
		}
		cn.parameterStatus.currentLocation = loc
	}
}
// processReadyForQuery stores the transaction status byte carried by a
// ReadyForQuery message.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := r.byte()
	cn.txnStatus = transactionStatus(status)
}
// readReadyForQuery reads one message and requires it to be ReadyForQuery
// ('Z'). Any other message marks the connection as bad and is reported
// through errorf.
func (cn *conn) readReadyForQuery() {
	msgType, msg := cn.recv1()
	if msgType == 'Z' {
		cn.processReadyForQuery(msg)
		return
	}
	cn.err.set(driver.ErrBadConn)
	errorf("unexpected message %q; expected ReadyForQuery", msgType)
}
// processBackendKeyData stores the process ID and secret key from a
// BackendKeyData message; cancel sends both back in a cancel request.
func (cn *conn) processBackendKeyData(r *readBuf) {
	pid := r.int32()
	key := r.int32()
	cn.processID, cn.secretKey = pid, key
}
// readParseResponse waits for the reply to a Parse message. '1' means
// success. On a server error the connection is drained up to ReadyForQuery
// and the error is raised as a panic. Anything else marks the connection as
// bad and is reported through errorf.
func (cn *conn) readParseResponse() {
	msgType, msg := cn.recv1()
	if msgType == '1' {
		return
	}
	if msgType == 'E' {
		serverErr := parseError(msg)
		cn.readReadyForQuery()
		panic(serverErr)
	}
	cn.err.set(driver.ErrBadConn)
	errorf("unexpected Parse response %q", msgType)
}
// readStatementDescribeResponse reads the replies to a statement Describe:
// a parameter description ('t') followed by either "no data" ('n') or a row
// description ('T'). It returns the parameter type OIDs and, when present,
// the column names and type descriptions. A server error is raised as a
// panic after draining to ReadyForQuery; any other message marks the
// connection as bad and is reported through errorf.
func (cn *conn) readStatementDescribeResponse() (
	paramTyps []oid.Oid,
	colNames []string,
	colTyps []fieldDesc,
) {
	for {
		msgType, msg := cn.recv1()
		switch msgType {
		case 't':
			paramTyps = make([]oid.Oid, msg.int16())
			for i := 0; i < len(paramTyps); i++ {
				paramTyps[i] = msg.oid()
			}
		case 'T':
			colNames, colTyps = parseStatementRowDescribe(msg)
			return paramTyps, colNames, colTyps
		case 'n':
			return paramTyps, nil, nil
		case 'E':
			serverErr := parseError(msg)
			cn.readReadyForQuery()
			panic(serverErr)
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unexpected Describe statement response %q", msgType)
		}
	}
}
// readPortalDescribeResponse reads the reply to a portal Describe and returns
// the described columns, or an empty header when the portal returns no data
// ('n'). A server error is raised as a panic after draining to
// ReadyForQuery; any other message marks the connection as bad and is
// reported through errorf.
func (cn *conn) readPortalDescribeResponse() rowsHeader {
	switch msgType, msg := cn.recv1(); msgType {
	case 'n':
		return rowsHeader{}
	case 'T':
		return parsePortalRowDescribe(msg)
	case 'E':
		serverErr := parseError(msg)
		cn.readReadyForQuery()
		panic(serverErr)
	default:
		cn.err.set(driver.ErrBadConn)
		errorf("unexpected Describe response %q", msgType)
	}
	panic("not reached")
}
// readBindResponse waits for the reply to a Bind message. '2' means success.
// On a server error the connection is drained up to ReadyForQuery and the
// error is raised as a panic. Anything else marks the connection as bad and
// is reported through errorf.
func (cn *conn) readBindResponse() {
	msgType, msg := cn.recv1()
	if msgType == '2' {
		return
	}
	if msgType == 'E' {
		serverErr := parseError(msg)
		cn.readReadyForQuery()
		panic(serverErr)
	}
	cn.err.set(driver.ErrBadConn)
	errorf("unexpected Bind response %q", msgType)
}
// postExecuteWorkaround peeks at the first message that follows an Execute.
//
// sql.DB.QueryRow in Go 1.2 and earlier ignores errors from rows.Next, which
// hides failures that happen while the query runs. To surface them in the
// common case, one more message is read here. If it is an error, the
// connection is drained to ReadyForQuery and the error is raised as a panic
// for the caller. If it is a CommandComplete, DataRow or EmptyQueryResponse,
// the query is (probably) fine, so the message is saved on the conn and recv1
// hands it out again to rows.Next or rows.Close.
func (cn *conn) postExecuteWorkaround() {
	for {
		msgType, msg := cn.recv1()
		switch msgType {
		case 'C', 'D', 'I':
			// Not a failure, but not ours to process either.
			cn.saveMessage(msgType, msg)
			return
		case 'E':
			serverErr := parseError(msg)
			cn.readReadyForQuery()
			panic(serverErr)
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unexpected message during extended query execution: %q", msgType)
		}
	}
}
// readExecuteResponse reads messages until ReadyForQuery and returns the
// result and command tag of the last CommandComplete seen, or the server
// error reported for the command. Row data is discarded, so this is only
// for Exec(), since we ignore the returned data.
//
// protocolState is only used in the message for an unknown response type.
// If ReadyForQuery arrives with neither a result nor an error, err is
// errUnexpectedReady.
func (cn *conn) readExecuteResponse(
protocolState string,
) (res driver.Result, commandTag string, err error) {
for {
t, r := cn.recv1()
switch t {
case 'C':
// A completion after an error has been recorded is a protocol violation.
if err != nil {
cn.err.set(driver.ErrBadConn)
errorf("unexpected CommandComplete after error %s", err)
}
res, commandTag = cn.parseComplete(r.string())
case 'Z':
cn.processReadyForQuery(r)
if res == nil && err == nil {
err = errUnexpectedReady
}
return res, commandTag, err
case 'E':
// Remember the error and keep reading until ReadyForQuery.
err = parseError(r)
case 'T', 'D', 'I':
if err != nil {
cn.err.set(driver.ErrBadConn)
errorf("unexpected %q after error %s", t, err)
}
// 'I' yields emptyRows as the result.
if t == 'I' {
res = emptyRows
}
// ignore any results
default:
cn.err.set(driver.ErrBadConn)
errorf("unknown %s response: %q", protocolState, t)
}
}
}
// parseStatementRowDescribe decodes a row description received for a
// statement Describe and returns the column names and type descriptions in
// column order. The per-column format code is skipped: it is not known when
// describing a statement and is always 0.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
	count := r.int16()
	colNames = make([]string, count)
	colTyps = make([]fieldDesc, count)
	for i := 0; i < len(colNames); i++ {
		colNames[i] = r.string()
		r.next(6) // six bytes that are not used here
		desc := &colTyps[i]
		desc.OID = r.oid()
		desc.Len = r.int16()
		desc.Mod = r.int32()
		r.next(2) // format code, always 0 for a statement
	}
	return colNames, colTyps
}
// parsePortalRowDescribe decodes a row description received for a portal and
// returns the column names, type descriptions and format codes as a
// rowsHeader.
func parsePortalRowDescribe(r *readBuf) rowsHeader {
	count := r.int16()
	hdr := rowsHeader{
		colNames: make([]string, count),
		colFmts:  make([]format, count),
		colTyps:  make([]fieldDesc, count),
	}
	for i := 0; i < len(hdr.colNames); i++ {
		hdr.colNames[i] = r.string()
		r.next(6) // six bytes that are not used here
		desc := &hdr.colTyps[i]
		desc.OID = r.oid()
		desc.Len = r.int16()
		desc.Mod = r.int32()
		hdr.colFmts[i] = format(r.int16())
	}
	return hdr
}
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output: a list of "NAME=value" strings. Entries
// that contain no "=" are not valid environment assignments and are skipped
// (previously they caused an index-out-of-range panic). Only the first "="
// separates name and value, so values may themselves contain "=".
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// The returned map is keyed by connection option name (for example "host"
// for PGHOST). Well-defined libpq variables that this driver does not
// support cause a panic; they should be unset prior to execution. Variables
// not known to libpq are ignored.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	for _, entry := range env {
		parts := strings.SplitN(entry, "=", 2)
		if len(parts) != 2 {
			// No "=" at all: nothing to assign.
			continue
		}
		name, value := parts[0], parts[1]

		// The order of the supported settings is the same as is seen in the
		// PostgreSQL 9.1 manual.
		var key string
		switch name {
		case "PGHOST":
			key = "host"
		case "PGPORT":
			key = "port"
		case "PGDATABASE":
			key = "dbname"
		case "PGUSER":
			key = "user"
		case "PGPASSWORD":
			key = "password"
		case "PGOPTIONS":
			key = "options"
		case "PGAPPNAME":
			key = "application_name"
		case "PGSSLMODE":
			key = "sslmode"
		case "PGSSLCERT":
			key = "sslcert"
		case "PGSSLKEY":
			key = "sslkey"
		case "PGSSLROOTCERT":
			key = "sslrootcert"
		case "PGSSLSNI":
			key = "sslsni"
		case "PGCONNECT_TIMEOUT":
			key = "connect_timeout"
		case "PGCLIENTENCODING":
			key = "client_encoding"
		case "PGDATESTYLE":
			key = "datestyle"
		case "PGTZ":
			key = "timezone"
		case "PGGEQO":
			key = "geqo"
		case "PGHOSTADDR",
			"PGSERVICE", "PGSERVICEFILE", "PGREALM",
			"PGREQUIRESSL", "PGSSLCRL",
			"PGREQUIREPEER",
			"PGKRBSRVNAME", "PGGSSLIB",
			"PGSYSCONFDIR", "PGLOCALEDIR":
			// Known to libpq but not supported by this driver.
			panic(fmt.Sprintf("setting %v not supported", name))
		default:
			// Not a libpq variable.
			continue
		}
		out[key] = value
	}

	return out
}
// isUTF8 reports whether name is some spelling of "UTF-8". Like Postgres, it
// is lenient: case is ignored and every character that is not an ASCII
// letter or digit is dropped before comparing, and "unicode" is accepted as
// well.
func isUTF8(name string) bool {
	switch strings.Map(alnumLowerASCII, name) {
	case "utf8", "unicode":
		return true
	}
	return false
}
// alnumLowerASCII is a strings.Map mapping function: ASCII upper-case letters
// are lowered, ASCII lower-case letters and digits are kept, and every other
// rune is mapped to -1 so that strings.Map drops it.
func alnumLowerASCII(ch rune) rune {
	switch {
	case 'A' <= ch && ch <= 'Z':
		return ch + ('a' - 'A')
	case 'a' <= ch && ch <= 'z', '0' <= ch && ch <= '9':
		return ch
	default:
		return -1 // discard
	}
}
// Compile-time interface checks. The database/sql/driver package says:
// All Conn implementations should implement the following interfaces:
// Pinger, SessionResetter, and Validator.
var (
	_ driver.Pinger          = (*conn)(nil)
	_ driver.SessionResetter = (*conn)(nil)
)
// ResetSession implements driver.SessionResetter by returning the error
// recorded on the connection, if any.
//
// This makes sure bad connections are reported. From database/sql/driver:
// if a connection is never returned to the connection pool but immediately
// reused, then ResetSession is called prior to reuse but IsValid is not
// called.
func (cn *conn) ResetSession(ctx context.Context) error {
	err := cn.err.get()
	return err
}
// IsValid reports whether the connection has no recorded error.
func (cn *conn) IsValid() bool {
	if cn.err.get() != nil {
		return false
	}
	return true
}
================================================
FILE: gossh/pgsql/conn_go115.go
================================================
//go:build go1.15
// +build go1.15
package pgsql
import "database/sql/driver"
// Compile-time check that *conn implements driver.Validator. It lives in its
// own file because of the go1.15 build constraint at the top of this file.
var _ driver.Validator = &conn{}
================================================
FILE: gossh/pgsql/conn_go18.go
================================================
package pgsql
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"io/ioutil"
"time"
)
// watchCancelDialContextTimeout bounds the separate connection that is made
// to deliver a cancel request after a query's context has been cancelled.
const watchCancelDialContextTimeout = 10 * time.Second
// QueryContext implements driver.QueryerContext. The query is watched for
// cancellation of ctx; on success the watcher's finish function is handed to
// the returned rows, on failure it is run before returning the error.
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	plain := make([]driver.Value, 0, len(args))
	for _, named := range args {
		plain = append(plain, named.Value)
	}

	finish := cn.watchCancel(ctx)
	rs, err := cn.query(query, plain)
	if err == nil {
		rs.finish = finish
		return rs, nil
	}
	if finish != nil {
		finish()
	}
	return nil, err
}
// ExecContext implements driver.ExecerContext. The statement runs through
// Exec while ctx is watched for cancellation.
func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	plain := make([]driver.Value, 0, len(args))
	for _, named := range args {
		plain = append(plain, named.Value)
	}

	finish := cn.watchCancel(ctx)
	if finish != nil {
		defer finish()
	}
	return cn.Exec(query, plain)
}
// PrepareContext implements driver.ConnPrepareContext. The statement is
// prepared through Prepare while ctx is watched for cancellation.
func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	finish := cn.watchCancel(ctx)
	if finish != nil {
		defer finish()
	}
	return cn.Prepare(query)
}
// BeginTx implements driver.ConnBeginTx. The requested isolation level and
// read-only flag are turned into the transaction mode clause passed to
// begin; sql.LevelDefault leaves the isolation level to the server. An
// isolation level outside the handled set yields an error. Once the
// transaction has started, a cancellation watcher for ctx is stored in
// cn.txnFinish.
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	var isolation string
	switch level := sql.IsolationLevel(opts.Isolation); level {
	case sql.LevelDefault:
		// Keep the server's default.
	case sql.LevelReadUncommitted:
		isolation = " ISOLATION LEVEL READ UNCOMMITTED"
	case sql.LevelReadCommitted:
		isolation = " ISOLATION LEVEL READ COMMITTED"
	case sql.LevelRepeatableRead:
		isolation = " ISOLATION LEVEL REPEATABLE READ"
	case sql.LevelSerializable:
		isolation = " ISOLATION LEVEL SERIALIZABLE"
	default:
		return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
	}

	access := " READ WRITE"
	if opts.ReadOnly {
		access = " READ ONLY"
	}

	tx, err := cn.begin(isolation + access)
	if err != nil {
		return nil, err
	}
	cn.txnFinish = cn.watchCancel(ctx)
	return tx, nil
}
// Ping implements driver.Pinger by running the empty statement ";" while ctx
// is watched for cancellation. Any failure is reported as driver.ErrBadConn,
// as the Pinger contract asks
// (https://golang.org/pkg/database/sql/driver/#Pinger).
func (cn *conn) Ping(ctx context.Context) error {
	finish := cn.watchCancel(ctx)
	if finish != nil {
		defer finish()
	}

	rs, err := cn.simpleQuery(";")
	if err != nil {
		return driver.ErrBadConn
	}
	rs.Close()
	return nil
}
// watchCancel starts a goroutine that sends a cancel request to the server
// if ctx is done before the returned finish function is called. It returns
// nil when ctx can never be done (ctx.Done() is nil); otherwise the caller
// must call the returned function exactly when the watched operation ends.
//
// The goroutine and the finish function coordinate through finished, a
// channel with room for one token. Whoever manages to put the token in
// first wins:
//   - If the goroutine wins (ctx was cancelled first), it marks the
//     connection with ctx.Err() and issues the cancel request. The finish
//     function later finds the token, records ctx.Err() again and closes
//     the connection.
//   - If the finish function wins, its token either wakes the goroutine,
//     which then exits, or stays in the buffer so that a later cancellation
//     fails its non-blocking send and exits without cancelling anything.
func (cn *conn) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{}, 1)
go func() {
select {
case <-done:
// Non-blocking send: fails only if finish already stored its token.
select {
case finished <- struct{}{}:
default:
// We raced with the finish func, let the next query handle this with the
// context.
return
}
// Set the connection state to bad so it does not get reused.
cn.err.set(ctx.Err())
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
defer cancel()
// Best effort: the result of the cancel request is deliberately ignored.
_ = cn.cancel(ctxCancel)
case <-finished:
}
}()
return func() {
select {
case <-finished:
// The goroutine stored its token first: the operation was cancelled.
cn.err.set(ctx.Err())
cn.Close()
case finished <- struct{}{}:
}
}
}
return nil
}
// cancel opens a separate connection to the server and sends a cancel
// request for this connection's backend, identified by its process ID and
// secret key. It then reads from the new connection until EOF to make sure
// the server received the request, and returns any error from dialing, TLS
// setup, sending or reading.
func (cn *conn) cancel(ctx context.Context) error {
	// Work on a copy of the options. cancel runs on the goroutine started by
	// watchCancel, and the connection set up here must not write to the map
	// used by cn, which could cause a concurrent map write panic.
	optsCopy := make(values, len(cn.opts))
	for key, val := range cn.opts {
		optsCopy[key] = val
	}

	netConn, err := dial(ctx, cn.dialer, optsCopy)
	if err != nil {
		return err
	}
	defer netConn.Close()

	canceller := conn{c: netConn}
	if err = canceller.ssl(optsCopy); err != nil {
		return err
	}

	req := canceller.writeBuf(0)
	req.int32(80877102) // cancel request code
	req.int32(cn.processID)
	req.int32(cn.secretKey)
	if sendErr := canceller.sendStartupPacket(req); sendErr != nil {
		return sendErr
	}

	// Read until EOF to ensure that the server received the cancel.
	_, copyErr := io.Copy(ioutil.Discard, netConn)
	return copyErr
}
// QueryContext implements driver.StmtQueryContext. The query is watched for
// cancellation of ctx; on success the watcher's finish function is handed to
// the returned rows, on failure it is run before returning the error.
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	plain := make([]driver.Value, 0, len(args))
	for _, named := range args {
		plain = append(plain, named.Value)
	}

	finish := st.watchCancel(ctx)
	rs, err := st.query(plain)
	if err == nil {
		rs.finish = finish
		return rs, nil
	}
	if finish != nil {
		finish()
	}
	return nil, err
}
// ExecContext implements driver.StmtExecContext. The statement runs through
// Exec while ctx is watched for cancellation.
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	plain := make([]driver.Value, 0, len(args))
	for _, named := range args {
		plain = append(plain, named.Value)
	}

	finish := st.watchCancel(ctx)
	if finish != nil {
		defer finish()
	}
	return st.Exec(plain)
}
// watchCancel is implemented on stmt in order to not mark the parent conn as bad.
//
// It starts a goroutine that sends a cancel request if ctx is done before the
// returned finish function is called, and returns nil when ctx can never be
// done. Unlike the conn version, finished is unbuffered and neither the
// goroutine nor the finish function records an error on the connection or
// closes it.
//
// NOTE(review): after a cancellation the goroutine blocks on its send to
// finished until the finish function receives it, so the finish function has
// to be called for the goroutine to exit.
func (st *stmt) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{})
go func() {
select {
case <-done:
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
defer cancel()
// Best effort: the result of the cancel request is deliberately ignored.
_ = st.cancel(ctxCancel)
// Tell the finish function that the cancel request has been attempted.
finished <- struct{}{}
case <-finished:
}
}()
return func() {
// Either collect the goroutine's notification or tell it to stop
// watching, whichever side is ready.
select {
case <-finished:
case finished <- struct{}{}:
}
}
}
return nil
}
// cancel asks the server to cancel whatever is running on the statement's
// connection by delegating to the connection's cancel method.
func (st *stmt) cancel(ctx context.Context) error {
	cn := st.cn
	return cn.cancel(ctx)
}
================================================
FILE: gossh/pgsql/connector.go
================================================
package pgsql
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"os"
"strings"
)
// Connector represents a fixed configuration for the pq driver with a given
// name. Connector satisfies the database/sql/driver Connector interface and
// can be used to create any number of DB Conn's via the database/sql OpenDB
// function.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type Connector struct {
// opts holds the connection options assembled by NewConnector from the
// defaults, the environment and the DSN.
opts values
// dialer is used to open connections; NewConnector sets it to
// defaultDialer{} and the Dialer method replaces it.
dialer Dialer
}
// Connect returns a connection to the database using the fixed configuration
// of this Connector. ctx is passed on to c.open, which establishes the
// connection.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
}
// Dialer replaces the dialer this Connector uses to open connections.
func (c *Connector) Dialer(dialer Dialer) {
c.dialer = dialer
}
// Driver returns the driver behind this Connector. Each call returns a new,
// zero-valued *Driver.
func (c *Connector) Driver() driver.Driver {
	d := new(Driver)
	return d
}
// NewConnector builds a Connector with a fixed configuration derived from
// dsn. The Connector can create any number of equivalent connections and is
// meant to be passed to database/sql.OpenDB.
//
// dsn may be a "postgres://" or "postgresql://" URL or a key/value connection
// string. An error is returned when the DSN cannot be parsed, when
// client_encoding is set to something other than UTF-8, when datestyle is set
// to something other than "ISO, MDY", or when no user is configured and the
// current operating system user cannot be determined.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(dsn string) (*Connector, error) {
	// Settings are layered, later layers overriding earlier ones:
	//
	// * Very low precedence defaults applied in every situation
	// * Environment variables
	// * Explicitly passed connection information
	opts := values{
		"host": "localhost",
		"port": "5432",
		// N.B.: Extra float digits should be set to 3, but that breaks
		// Postgres 8.4 and older, where the max is 2.
		"extra_float_digits": "2",
	}

	for key, val := range parseEnviron(os.Environ()) {
		opts[key] = val
	}

	if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
		converted, err := ParseURL(dsn)
		if err != nil {
			return nil, err
		}
		dsn = converted
	}
	if err := parseOpts(dsn, opts); err != nil {
		return nil, err
	}

	// The fallback application name only applies when no application name
	// has been set.
	if _, hasApp := opts["application_name"]; !hasApp {
		if fallback, hasFallback := opts["fallback_application_name"]; hasFallback {
			opts["application_name"] = fallback
		}
	}

	// Only UTF-8 is supported as client_encoding. Setting it to UTF-8
	// explicitly has always been allowed, so that keeps working. "options"
	// could set client_encoding too, but instead of parsing it the encoding
	// is always sent as its own run-time parameter, which should override
	// whatever "options" says.
	if enc, set := opts["client_encoding"]; set && !isUTF8(enc) {
		return nil, errors.New("client_encoding must be absent or 'UTF8'")
	}
	opts["client_encoding"] = "UTF8"

	// DateStyle gets the same treatment.
	const wantDateStyle = "ISO, MDY"
	switch datestyle, set := opts["datestyle"]; {
	case !set:
		opts["datestyle"] = wantDateStyle
	case datestyle != wantDateStyle:
		return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", wantDateStyle, datestyle)
	}

	// As a last resort, connect as the current operating system user.
	if _, hasUser := opts["user"]; !hasUser {
		name, err := userCurrent()
		if err != nil {
			return nil, err
		}
		opts["user"] = name
	}

	// SSL is not necessary or supported over UNIX domain sockets
	if kind, _ := network(opts); kind == "unix" {
		opts["sslmode"] = "disable"
	}

	return &Connector{opts: opts, dialer: defaultDialer{}}, nil
}
================================================
FILE: gossh/pgsql/copy.go
================================================
package pgsql
import (
"bytes"
"context"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"sync"
)
// Errors reported by the COPY support in this file.
var (
// errCopyInClosed is returned by Exec on a copyin statement that is closed.
errCopyInClosed = errors.New("pq: copyin statement has already been closed")
// errBinaryCopyNotSupported is used when the server answers a COPY with a
// non-zero (non-text) overall format.
errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
// errCopyToNotSupported is used when the server answers with an 'H'
// message, i.e. the statement copies data out rather than in.
errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
// errCopyNotSupportedOutsideTxn is returned by prepareCopyIn when the
// connection is not inside a transaction.
errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
errCopyInProgress = errors.New("pq: COPY in progress")
)
// CopyIn returns a "COPY ... FROM STDIN" statement for the given table and
// columns, with every identifier quoted. The statement can be prepared with
// Tx.Prepare(). The target table should be visible in search_path.
func CopyIn(table string, columns ...string) string {
	var stmt bytes.Buffer
	stmt.WriteString("COPY ")
	BufferQuoteIdentifier(table, &stmt)
	stmt.WriteString(" (")
	makeStmt(&stmt, columns...)
	return stmt.String()
}
// makeStmt appends the tail shared by CopyIn and CopyInSchema to buffer: the
// quoted column names separated by ", ", followed by ") FROM STDIN".
func makeStmt(buffer *bytes.Buffer, columns ...string) {
	for i := range columns {
		if i > 0 {
			buffer.WriteString(", ")
		}
		BufferQuoteIdentifier(columns[i], buffer)
	}
	buffer.WriteString(") FROM STDIN")
}
// CopyInSchema returns a "COPY ... FROM STDIN" statement for the given
// schema-qualified table and columns, with every identifier quoted. The
// statement can be prepared with Tx.Prepare().
func CopyInSchema(schema, table string, columns ...string) string {
	var stmt bytes.Buffer
	stmt.WriteString("COPY ")
	BufferQuoteIdentifier(schema, &stmt)
	stmt.WriteByte('.')
	BufferQuoteIdentifier(table, &stmt)
	stmt.WriteString(" (")
	makeStmt(&stmt, columns...)
	return stmt.String()
}
// copyin is the statement returned by prepareCopyIn for a COPY ... FROM
// STDIN operation.
type copyin struct {
// cn is the connection the COPY runs on.
cn *conn
// buffer accumulates outgoing data. prepareCopyIn allocates it with
// capacity ciBufferSize and seeds it with the 5-byte message header
// ('d' plus four length bytes).
buffer []byte
rowData chan []byte
// done receives a value from resploop when the response loop ends.
done chan bool
// closed reports that the statement has been closed; Exec then fails
// with errCopyInClosed.
closed bool
// mu guards the first error reported for the COPY and its result; both
// are written by resploop through setError and setResult.
mu struct {
sync.Mutex
err error
driver.Result
}
}
const (
	// ciBufferSize is the capacity of the copyin data buffer.
	ciBufferSize = 64 * 1024
	// ciBufferFlushSize is the fill level at which the buffer is flushed,
	// chosen below the capacity so the buffer is sent before it fills up and
	// needs reallocation.
	ciBufferFlushSize = 63 * 1024
)
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
if !cn.isInTransaction() {
return nil, errCopyNotSupportedOutsideTxn
}
ci := ©in{
cn: cn,
buffer: make([]byte, 0, ciBufferSize),
rowData: make(chan []byte),
done: make(chan bool, 1),
}
// add CopyData identifier + 4 bytes for message length
ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
awaitCopyInResponse:
for {
t, r := cn.recv1()
switch t {
case 'G':
if r.byte() != 0 {
err = errBinaryCopyNotSupported
break awaitCopyInResponse
}
go ci.resploop()
return ci, nil
case 'H':
err = errCopyToNotSupported
break awaitCopyInResponse
case 'E':
err = parseError(r)
case 'Z':
if err == nil {
ci.setBad(driver.ErrBadConn)
errorf("unexpected ReadyForQuery in response to COPY")
}
cn.processReadyForQuery(r)
return nil, err
default:
ci.setBad(driver.ErrBadConn)
errorf("unknown response for copy query: %q", t)
}
}
// something went wrong, abort COPY before we return
b = cn.writeBuf('f')
b.string(err.Error())
cn.send(b)
for {
t, r := cn.recv1()
switch t {
case 'c', 'C', 'E':
case 'Z':
// correctly aborted, we're done
cn.processReadyForQuery(r)
return nil, err
default:
ci.setBad(driver.ErrBadConn)
errorf("unknown response for CopyFail: %q", t)
}
}
}
// flush fills in the length field of the message held in buf and writes the
// whole message to the connection. buf must start with the one-byte message
// identifier followed by four bytes reserved for the length; the length
// counts everything except the identifier. A write error causes a panic.
func (ci *copyin) flush(buf []byte) {
	msgLen := uint32(len(buf) - 1)
	binary.BigEndian.PutUint32(buf[1:], msgLen)

	if _, err := ci.cn.c.Write(buf); err != nil {
		panic(err)
	}
}
// resploop reads server messages for the lifetime of the COPY and records
// their effect on ci. It runs until ReadyForQuery ('Z') arrives, a read
// fails, or an unknown message type is seen; in every case it signals
// ci.done exactly once before returning.
func (ci *copyin) resploop() {
	for {
		var msg readBuf
		typ, recvErr := ci.cn.recvMessage(&msg)
		if recvErr != nil {
			ci.setBad(driver.ErrBadConn)
			ci.setError(recvErr)
			ci.done <- true
			return
		}

		switch typ {
		case 'Z':
			ci.cn.processReadyForQuery(&msg)
			ci.done <- true
			return
		case 'E':
			// Only the first error is kept (see setError).
			ci.setError(parseError(&msg))
		case 'N':
			if handler := ci.cn.noticeHandler; handler != nil {
				handler(parseError(&msg))
			}
		case 'C':
			// complete
			res, _ := ci.cn.parseComplete(msg.string())
			ci.setResult(res)
		default:
			ci.setBad(driver.ErrBadConn)
			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", typ))
			ci.done <- true
			return
		}
	}
}
// setBad stores err as the connection-level error of the underlying
// connection (ci.cn.err).
func (ci *copyin) setBad(err error) {
ci.cn.err.set(err)
}
// getBad returns the connection-level error of the underlying connection
// (ci.cn.err), as stored by setBad.
func (ci *copyin) getBad() error {
return ci.cn.err.get()
}
// err returns the error recorded for this COPY, or nil if none has been set.
func (ci *copyin) err() error {
	ci.mu.Lock()
	defer ci.mu.Unlock()
	return ci.mu.err
}
// setError records err for this COPY unless an error is already recorded, so
// the first error wins. The caller must not hold ci.mu.
func (ci *copyin) setError(err error) {
	ci.mu.Lock()
	defer ci.mu.Unlock()
	if ci.mu.err != nil {
		return
	}
	ci.mu.err = err
}
// setResult stores the command result reported by the server.
func (ci *copyin) setResult(result driver.Result) {
	ci.mu.Lock()
	defer ci.mu.Unlock()
	ci.mu.Result = result
}
// getResult returns the stored command result, or driver.RowsAffected(0)
// when no result has been stored.
func (ci *copyin) getResult() driver.Result {
	ci.mu.Lock()
	defer ci.mu.Unlock()
	if res := ci.mu.Result; res != nil {
		return res
	}
	return driver.RowsAffected(0)
}
// NumInput always returns -1, the value database/sql/driver documents for a
// statement whose number of placeholders is not known to the driver.
func (ci *copyin) NumInput() int {
return -1
}
// Query is not available on a COPY statement; it always returns
// ErrNotSupported and no rows.
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
return nil, ErrNotSupported
}
// Exec appends one row to the COPY stream. Values are text-encoded,
// separated by tabs and terminated by a newline. Rows are buffered and only
// written once the buffer exceeds ciBufferFlushSize, so an error returned
// here may stem from an earlier Exec call on the same statement.
//
// Calling Exec with no values closes the statement, which syncs the COPY
// stream and reports any pending error; Stmt.Close() alone cannot return
// those errors to the user.
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
	if ci.closed {
		return nil, errCopyInClosed
	}
	if badErr := ci.getBad(); badErr != nil {
		return nil, badErr
	}
	defer ci.cn.errRecover(&err)

	if prevErr := ci.err(); prevErr != nil {
		return nil, prevErr
	}

	if len(v) == 0 {
		if closeErr := ci.Close(); closeErr != nil {
			return driver.RowsAffected(0), closeErr
		}
		return ci.getResult(), nil
	}

	last := len(v) - 1
	for i := range v {
		ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, v[i])
		if i != last {
			ci.buffer = append(ci.buffer, '\t')
		}
	}
	ci.buffer = append(ci.buffer, '\n')

	if len(ci.buffer) > ciBufferFlushSize {
		ci.flush(ci.buffer)
		// reset buffer, keep bytes for message identifier and length
		ci.buffer = ci.buffer[:5]
	}

	return driver.RowsAffected(0), nil
}
// CopyData appends a raw, already-formatted line to the COPY stream,
// followed by a newline. Like Exec it only buffers the data, so an error
// returned here may stem from an earlier call on the same statement.
//
// Call Exec(nil) afterwards to sync the COPY stream and collect any pending
// error; Stmt.Close() alone cannot return those errors to the user.
func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
	if ci.closed {
		return nil, errCopyInClosed
	}

	if finish := ci.cn.watchCancel(ctx); finish != nil {
		defer finish()
	}

	if badErr := ci.getBad(); badErr != nil {
		return nil, badErr
	}
	defer ci.cn.errRecover(&err)

	if prevErr := ci.err(); prevErr != nil {
		return nil, prevErr
	}

	ci.buffer = append(ci.buffer, line...)
	ci.buffer = append(ci.buffer, '\n')

	if len(ci.buffer) > ciBufferFlushSize {
		ci.flush(ci.buffer)
		// reset buffer, keep bytes for message identifier and length
		ci.buffer = ci.buffer[:5]
	}

	return driver.RowsAffected(0), nil
}
// Close finishes the COPY: it flushes the buffered data, sends the 'c'
// message, waits for the response loop to signal completion and returns the
// error recorded for the COPY, if any. Closing an already closed statement
// is a no-op that returns nil.
func (ci *copyin) Close() (err error) {
	if ci.closed {
		// Don't do anything, we're already closed
		return nil
	}
	ci.closed = true

	if badErr := ci.getBad(); badErr != nil {
		return badErr
	}
	defer ci.cn.errRecover(&err)

	if len(ci.buffer) > 0 {
		ci.flush(ci.buffer)
	}

	// Avoid touching the scratch buffer as resploop could be using it.
	if sendErr := ci.cn.sendSimpleMessage('c'); sendErr != nil {
		return sendErr
	}

	<-ci.done
	ci.cn.inCopy = false

	return ci.err()
}
================================================
FILE: gossh/pgsql/encode.go
================================================
package pgsql
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"sync"
"time"
"gossh/pgsql/oid"
)
// time2400Regex matches a time value that starts with "24:00", optionally
// followed by ":00" and a fraction of zeros, and optionally by a time zone
// suffix starting with 'Z', '+' or '-'. Capture group 1 is the "24:00..."
// part without the zone; mustParse uses its length to splice in "00:00:00".
var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)
// binaryEncode returns x unchanged when it is a []byte; every other value is
// passed to encode with oid.T_unknown.
func binaryEncode(parameterStatus *parameterStatus, x any) []byte {
	if raw, ok := x.([]byte); ok {
		return raw
	}
	return encode(parameterStatus, x, oid.T_unknown)
}
// encode converts x to the byte representation sent to the server.
//
// Supported types are int64, float64, []byte, string, bool and time.Time.
// []byte and string values are bytea-encoded when pgtypOid is oid.T_bytea and
// passed through otherwise. Any other type panics through errorf.
func encode(parameterStatus *parameterStatus, x any, pgtypOid oid.Oid) []byte {
	wantBytea := pgtypOid == oid.T_bytea

	switch val := x.(type) {
	case bool:
		return strconv.AppendBool(nil, val)
	case int64:
		return strconv.AppendInt(nil, val, 10)
	case float64:
		return strconv.AppendFloat(nil, val, 'f', -1, 64)
	case time.Time:
		return formatTs(val)
	case string:
		if !wantBytea {
			return []byte(val)
		}
		return encodeBytea(parameterStatus.serverVersion, []byte(val))
	case []byte:
		if !wantBytea {
			return val
		}
		return encodeBytea(parameterStatus.serverVersion, val)
	default:
		errorf("encode: unknown type for %T", val)
		panic("not reached")
	}
}
// decode dispatches on the wire format f: binary values go to binaryDecode,
// text values to textDecode. Any other format panics.
func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) any {
	if f == formatBinary {
		return binaryDecode(parameterStatus, s, typ)
	}
	if f == formatText {
		return textDecode(parameterStatus, s, typ)
	}
	panic("not reached")
}
// binaryDecode converts a binary-format value of type typ.
//
// bytea is returned as-is, int2/int4/int8 are read as big-endian signed
// integers and widened to int64, and uuid is converted by decodeUUIDBinary.
// Any other type panics through errorf. s is not length-checked here.
func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) any {
	switch typ {
	case oid.T_int2:
		return int64(int16(binary.BigEndian.Uint16(s)))
	case oid.T_int4:
		return int64(int32(binary.BigEndian.Uint32(s)))
	case oid.T_int8:
		return int64(binary.BigEndian.Uint64(s))
	case oid.T_bytea:
		return s
	case oid.T_uuid:
		decoded, uuidErr := decodeUUIDBinary(s)
		if uuidErr != nil {
			panic(uuidErr)
		}
		return decoded
	default:
		errorf("don't know how to decode binary parameter of type %d", uint32(typ))
		panic("not reached")
	}
}
// textDecode converts a text-format value of type typ into a Go value.
//
// Character types become string, bytea becomes []byte, the date/time types
// go through parseTs or mustParse, bool is true when the first byte is 't',
// the integer types become int64 and the float types float64. Values of any
// other type are returned unchanged. Conversion failures panic through
// errorf.
func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) any {
	switch typ {
	case oid.T_bool:
		return s[0] == 't'
	case oid.T_char, oid.T_varchar, oid.T_text:
		return string(s)
	case oid.T_int8, oid.T_int4, oid.T_int2:
		n, convErr := strconv.ParseInt(string(s), 10, 64)
		if convErr != nil {
			errorf("%s", convErr)
		}
		return n
	case oid.T_float4, oid.T_float8:
		// We always use 64 bit parsing, regardless of whether the input text is for
		// a float4 or float8, because clients expect float64s for all float datatypes
		// and returning a 32-bit parsed float64 produces lossy results.
		x, convErr := strconv.ParseFloat(string(s), 64)
		if convErr != nil {
			errorf("%s", convErr)
		}
		return x
	case oid.T_bytea:
		raw, byteaErr := parseBytea(s)
		if byteaErr != nil {
			errorf("%s", byteaErr)
		}
		return raw
	case oid.T_timestamptz:
		return parseTs(parameterStatus.currentLocation, string(s))
	case oid.T_timestamp, oid.T_date:
		return parseTs(nil, string(s))
	case oid.T_time:
		return mustParse("15:04:05", typ, s)
	case oid.T_timetz:
		return mustParse("15:04:05-07", typ, s)
	default:
		return s
	}
}
// appendEncodedText appends x to buf in the text representation used for
// COPY data and returns the extended slice.
//
// Strings and bytea-encoded []byte values are escaped with appendEscapedText,
// nil is written as \N, and numbers, booleans and times are written in their
// plain text form. Any other type panics through errorf.
func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x any) []byte {
	switch val := x.(type) {
	case nil:
		return append(buf, "\\N"...)
	case string:
		return appendEscapedText(buf, val)
	case []byte:
		bytea := encodeBytea(parameterStatus.serverVersion, val)
		return appendEscapedText(buf, string(bytea))
	case bool:
		return strconv.AppendBool(buf, val)
	case int64:
		return strconv.AppendInt(buf, val, 10)
	case float64:
		return strconv.AppendFloat(buf, val, 'f', -1, 64)
	case time.Time:
		return append(buf, formatTs(val)...)
	default:
		errorf("encode: unknown type for %T", val)
		panic("not reached")
	}
}
// appendEscapedText appends text to buf, replacing backslash, newline,
// carriage return and tab with the two-character sequences \\, \n, \r and
// \t. All other bytes are copied unchanged.
func appendEscapedText(buf []byte, text string) []byte {
	// pending marks the start of the run of bytes that need no escaping and
	// have not been copied yet.
	pending := 0
	for i := 0; i < len(text); i++ {
		var code byte
		switch text[i] {
		case '\\':
			code = '\\'
		case '\n':
			code = 'n'
		case '\r':
			code = 'r'
		case '\t':
			code = 't'
		default:
			continue
		}
		buf = append(buf, text[pending:i]...)
		buf = append(buf, '\\', code)
		pending = i + 1
	}
	return append(buf, text[pending:]...)
}
// mustParse parses the text value s of a time type using the time layout f
// and returns the result. It panics through errorf when the value does not
// match the layout.
//
// For the time-zone-carrying types the layout is extended with ":00" for
// each ':' found three and six bytes before the end of the value, i.e. when
// the zone offset includes minutes and seconds. A leading "24:00" (which the
// time package rejects) is parsed as "00:00:00" and 24 hours are added to the
// result.
func mustParse(f string, typ oid.Oid, s []byte) time.Time {
	str := string(s)

	// Check for a minute and second offset in the timezone.
	if typ == oid.T_timestamptz || typ == oid.T_timetz {
		for i := 3; i <= 6; i += 3 {
			// The length guard keeps a value shorter than i bytes from
			// indexing out of range; such a value is rejected by
			// time.Parse below instead.
			if len(str) < i || str[len(str)-i] != ':' {
				break
			}
			f += ":00"
		}
	}

	// Special case for 24:00 time.
	// Unfortunately, golang does not parse 24:00 as a proper time.
	// In this case, we want to try "round to the next day", to differentiate.
	// As such, we find if the 24:00 time matches at the beginning; if so,
	// we default it back to 00:00 but add a day later.
	var is2400Time bool
	switch typ {
	case oid.T_timetz, oid.T_time:
		if matches := time2400Regex.FindStringSubmatch(str); matches != nil {
			// Concatenate timezone information at the back.
			str = "00:00:00" + str[len(matches[1]):]
			is2400Time = true
		}
	}

	t, err := time.Parse(f, str)
	if err != nil {
		errorf("decode: %s", err)
	}
	if is2400Time {
		t = t.Add(24 * time.Hour)
	}
	return t
}
// errInvalidTimestamp is reported when a timestamp is too short or otherwise
// structurally impossible to index into.
var errInvalidTimestamp = errors.New("invalid timestamp")

// timestampParser carries the first error hit while picking a timestamp
// apart. Once err is set, every further expect / mustAtoi call is a no-op, so
// callers can run a whole sequence of checks and inspect err once at the end.
type timestampParser struct {
	err error
}

// expect records an error unless str[pos] equals char. A position outside
// str yields errInvalidTimestamp.
func (p *timestampParser) expect(str string, char byte, pos int) {
	if p.err != nil {
		return
	}
	if pos < 0 || pos >= len(str) {
		p.err = errInvalidTimestamp
		return
	}
	if c := str[pos]; c != char {
		// %q prints the bytes as quoted characters; %v would print their
		// numeric codes.
		p.err = fmt.Errorf("expected %q at position %v; got %q", char, pos, c)
	}
}

// mustAtoi returns str[begin:end] converted to an int. It returns 0 and
// records an error when the range is invalid (errInvalidTimestamp) or the
// substring is not a number.
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
	if p.err != nil {
		return 0
	}
	if begin < 0 || end < 0 || begin > end || end > len(str) {
		p.err = errInvalidTimestamp
		return 0
	}
	result, err := strconv.Atoi(str[begin:end])
	if err != nil {
		p.err = fmt.Errorf("expected number; got '%v'", str)
		return 0
	}
	return result
}
// locationCache maps a UTC offset in seconds to the *time.Location created
// for it, so each offset is turned into a fixed zone only once.
type locationCache struct {
	cache map[int]*time.Location
	lock  sync.Mutex
}

// globalLocationCache is the single cache shared by all connections.
// Benchmarking shows that about 5% speed could be gained by putting the cache
// in the connection and losing the mutex, at the cost of a small amount of
// memory and a somewhat significant increase in code complexity.
var globalLocationCache = newLocationCache()

// newLocationCache returns an empty, ready-to-use cache.
func newLocationCache() *locationCache {
	c := new(locationCache)
	c.cache = map[int]*time.Location{}
	return c
}

// getLocation returns the cached location for offset, creating and caching
// an unnamed fixed zone on first use.
func (c *locationCache) getLocation(offset int) *time.Location {
	c.lock.Lock()
	defer c.lock.Unlock()

	if loc, found := c.cache[offset]; found {
		return loc
	}
	loc := time.FixedZone("", offset)
	c.cache[offset] = loc
	return loc
}
// infinityTsEnabled reports whether EnableInfinityTs has been called; while
// it is false, parseTs and formatTs leave infinity values untouched.
var infinityTsEnabled = false
// infinityTsNegative is the time that stands in for "-infinity".
var infinityTsNegative time.Time
// infinityTsPositive is the time that stands in for "infinity".
var infinityTsPositive time.Time
// Panic messages used by EnableInfinityTs.
const (
infinityTsEnabledAlready = "pq: infinity timestamp enabled already"
infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
)
// EnableInfinityTs turns on translation of Postgres' "-infinity" and
// "infinity" timestamps to and from the two given times.
//
// Without it, "-infinity" and "infinity" are returned as []byte("-infinity")
// and []byte("infinity"), which can fail with "sql: Scan error on column
// index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time" when
// scanned into a time.Time.
//
// After the call, every connection created by this driver decodes
// "-infinity" and "infinity" for "timestamp", "timestamp with time zone" and
// "date" to negative and positive respectively. When encoding, any time at
// or before negative is sent as "-infinity" and any time at or after positive
// as "infinity".
//
// The function panics when it is called more than once or when negative is
// not before positive. Calling it after a connection has been established
// results in undefined behavior.
func EnableInfinityTs(negative time.Time, positive time.Time) {
	switch {
	case infinityTsEnabled:
		panic(infinityTsEnabledAlready)
	case !negative.Before(positive):
		panic(infinityTsNegativeMustBeSmaller)
	}

	infinityTsNegative, infinityTsPositive = negative, positive
	infinityTsEnabled = true
}
// disableInfinityTs clears infinityTsEnabled so that EnableInfinityTs can be
// called again. It exists because tests might want to toggle the setting; the
// stored negative/positive times are left as they are.
func disableInfinityTs() {
infinityTsEnabled = false
}
// parseTs decodes a timestamp in the Postgres default DateStyle ("ISO,
// MDY"), the only one currently supported. It accounts for the differences
// between what time.Parse accepts and Postgres' date formatting quirks.
//
// "-infinity" and "infinity" map to the configured infinity times when
// EnableInfinityTs has been called and are returned as []byte otherwise.
// Everything else goes through ParseTimestamp; a parse error panics.
func parseTs(currentLocation *time.Location, str string) any {
	if str == "-infinity" || str == "infinity" {
		if !infinityTsEnabled {
			return []byte(str)
		}
		if str == "infinity" {
			return infinityTsPositive
		}
		return infinityTsNegative
	}

	ts, err := ParseTimestamp(currentLocation, str)
	if err != nil {
		panic(err)
	}
	return ts
}
// ParseTimestamp parses Postgres' text format. It returns a time.Time in
// currentLocation iff that time's offset agrees with the offset sent from the
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
// fixed offset provided by the Postgres server.
//
// The value is read as a date (year, '-', two-digit month, '-', two-digit
// day), optionally followed by " HH:MM:SS", a '.' fraction, a time zone
// ('+' or '-' with hours and optional ":MM" and ":SS", or 'Z') and a " BC"
// suffix. The year is everything before the first '-', so it may have any
// number of digits.
//
// Errors found by the timestampParser helper are returned together with
// whatever time was assembled; the time is only meaningful when the error is
// nil.
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
p := timestampParser{}
// monSep is the index of the first '-', which ends the year. All later
// separators are located relative to it.
monSep := strings.IndexRune(str, '-')
// this is Gregorian year, not ISO Year
// In Gregorian system, the year 1 BC is followed by AD 1
year := p.mustAtoi(str, 0, monSep)
daySep := monSep + 3
month := p.mustAtoi(str, monSep+1, daySep)
p.expect(str, '-', daySep)
timeSep := daySep + 3
day := p.mustAtoi(str, daySep+1, timeSep)
// minLen is the length of a date-only value (including " BC" when
// present); anything longer is expected to carry a time of day.
minLen := monSep + len("01-01") + 1
isBC := strings.HasSuffix(str, " BC")
if isBC {
minLen += 3
}
var hour, minute, second int
if len(str) > minLen {
p.expect(str, ' ', timeSep)
minSep := timeSep + 3
p.expect(str, ':', minSep)
hour = p.mustAtoi(str, timeSep+1, minSep)
secSep := minSep + 3
p.expect(str, ':', secSep)
minute = p.mustAtoi(str, minSep+1, secSep)
secEnd := secSep + 3
second = p.mustAtoi(str, secSep+1, secEnd)
}
// remainderIdx points just past the seconds field and is advanced over
// each optional section that is consumed below.
remainderIdx := monSep + len("01-01 00:00:00") + 1
// Three optional (but ordered) sections follow: the
// fractional seconds, the time zone offset, and the BC
// designation. We set them up here and adjust the other
// offsets if the preceding sections exist.
nanoSec := 0
tzOff := 0
if remainderIdx < len(str) && str[remainderIdx] == '.' {
fracStart := remainderIdx + 1
fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
if fracOff < 0 {
fracOff = len(str) - fracStart
}
fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
// Scale the fracOff fraction digits up to nanoseconds.
nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
remainderIdx += fracOff + 1
}
if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
// time zone separator is always '-' or '+' or 'Z' (UTC is +00)
var tzSign int
switch c := str[tzStart]; c {
case '-':
tzSign = -1
case '+':
tzSign = +1
default:
// Not reachable: the enclosing condition already limits the
// byte to '-' or '+'.
return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
}
tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
remainderIdx += 3
var tzMin, tzSec int
if remainderIdx < len(str) && str[remainderIdx] == ':' {
tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
remainderIdx += 3
}
if remainderIdx < len(str) && str[remainderIdx] == ':' {
tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
remainderIdx += 3
}
// tzOff is the zone offset in seconds east of UTC.
tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
} else if tzStart < len(str) && str[tzStart] == 'Z' {
// time zone Z separator indicates UTC is +00
remainderIdx += 1
}
var isoYear int
if isBC {
// Year 1 BC becomes year 0, 2 BC becomes -1, and so on.
isoYear = 1 - year
remainderIdx += 3
} else {
isoYear = year
}
// Anything left over after the optional sections is an error.
if remainderIdx < len(str) {
return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
}
t := time.Date(isoYear, time.Month(month), day,
hour, minute, second, nanoSec,
globalLocationCache.getLocation(tzOff))
if currentLocation != nil {
// Set the location of the returned Time based on the session's
// TimeZone value, but only if the local time zone database agrees with
// the remote database on the offset.
lt := t.In(currentLocation)
_, newOff := lt.Zone()
if newOff == tzOff {
t = lt
}
}
return t, p.err
}
// formatTs formats t into a format postgres understands. When infinity
// timestamps are enabled, a time at or before the configured minimum is
// written as "-infinity" and a time at or after the configured maximum as
// "infinity"; everything else is formatted by FormatTimestamp.
func formatTs(t time.Time) []byte {
	if !infinityTsEnabled {
		return FormatTimestamp(t)
	}
	switch {
	case !t.After(infinityTsNegative):
		// t <= -infinity
		return []byte("-infinity")
	case !t.Before(infinityTsPositive):
		// t >= infinity
		return []byte("infinity")
	}
	return FormatTimestamp(t)
}
// FormatTimestamp formats t into Postgres' text format for timestamps:
// "YYYY-MM-DD HH:MM:SS[.fraction]" followed by the zone ("Z" for a zero
// offset, otherwise "+HH:MM" / "-HH:MM" with ":SS" appended when the offset
// has a seconds part) and " BC" for years before 0001 A.D.
//
// The year is written separately from the rest of the date so that BC dates
// keep their month and day. Shifting the whole time into an A.D. year first
// would move Feb 29 of a BC leap year (e.g. 1 BC) to Mar 1.
func FormatTimestamp(t time.Time) []byte {
	// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
	// minus sign preferred by Go.
	// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
	year := t.Year()
	bc := year <= 0
	if bc {
		// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
		year = 1 - year
	}

	b := make([]byte, 0, len("2006-01-02 15:04:05.999999999-07:00:00 BC"))

	// Zero-pad the year to at least four digits, as the "2006" layout does.
	for limit := 1000; limit > year && limit > 1; limit /= 10 {
		b = append(b, '0')
	}
	b = strconv.AppendInt(b, int64(year), 10)
	b = append(b, '-')
	b = t.AppendFormat(b, "01-02 15:04:05.999999999Z07:00")

	// The layout only prints hours and minutes of the zone offset; add the
	// seconds part by hand when there is one.
	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		// The layout already printed the sign.
		if offset < 0 {
			offset = -offset
		}
		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}
// parseBytea decodes a bytea value received from the server. Values that
// start with `\x` are treated as the "hex" output format; everything else is
// treated as the legacy "escape" format, in which `\\` stands for a
// backslash and `\` followed by three octal digits stands for one byte.
func parseBytea(s []byte) (result []byte, err error) {
	if len(s) >= 2 && s[0] == '\\' && s[1] == 'x' {
		// bytea_output = hex
		digits := s[2:]
		result = make([]byte, hex.DecodedLen(len(digits)))
		if _, decErr := hex.Decode(result, digits); decErr != nil {
			return nil, decErr
		}
		return result, nil
	}

	// bytea_output = escape
	for len(s) > 0 {
		if s[0] != '\\' {
			// Raw bytes: copy everything up to the next backslash in one go.
			next := bytes.IndexByte(s, '\\')
			if next == -1 {
				result = append(result, s...)
				break
			}
			result = append(result, s[:next]...)
			s = s[next:]
			continue
		}

		// escaped '\\'
		if len(s) >= 2 && s[1] == '\\' {
			result = append(result, '\\')
			s = s[2:]
			continue
		}

		// '\\' followed by an octal number
		if len(s) < 4 {
			return nil, fmt.Errorf("invalid bytea sequence %v", s)
		}
		octal, convErr := strconv.ParseUint(string(s[1:4]), 8, 8)
		if convErr != nil {
			return nil, fmt.Errorf("could not parse bytea value: %s", convErr.Error())
		}
		result = append(result, byte(octal))
		s = s[4:]
	}
	return result, nil
}
// encodeBytea encodes v as a bytea literal. Servers reporting version 9.0
// (90000) or later get the hex format (`\x` followed by hex digits); older
// servers get the "escape" format, where a backslash is doubled and bytes
// outside 0x20..0x7e are written as a backslash plus three octal digits.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion >= 90000 {
		// Use the hex format if we know that the server supports it
		result = make([]byte, 2+hex.EncodedLen(len(v)))
		copy(result, `\x`)
		hex.Encode(result[2:], v)
		return result
	}

	// .. or resort to "escape"
	for _, b := range v {
		switch {
		case b == '\\':
			result = append(result, `\\`...)
		case b < 0x20 || b > 0x7e:
			result = append(result, fmt.Sprintf("\\%03o", b)...)
		default:
			result = append(result, b)
		}
	}
	return result
}
// NullTime is a time.Time that may be NULL. It implements the sql.Scanner
// interface so it can be used as a scan destination, similar to
// sql.NullString.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface. A time.Time value is stored and
// marked valid; any other value (including nil) leaves the zero time, marked
// invalid. It never returns an error.
func (nt *NullTime) Scan(value any) error {
	ts, isTime := value.(time.Time)
	nt.Time = ts
	nt.Valid = isTime
	return nil
}

// Value implements the driver Valuer interface. It yields the wrapped time
// when Valid is set and nil otherwise.
func (nt NullTime) Value() (driver.Value, error) {
	if nt.Valid {
		return nt.Time, nil
	}
	return nil, nil
}
================================================
FILE: gossh/pgsql/error.go
================================================
package pgsql
import (
"database/sql/driver"
"fmt"
"io"
"net"
"runtime"
)
// Error severities, as they appear in Error.Severity. Error.Fatal compares
// the severity against Efatal.
const (
Efatal = "FATAL"
Epanic = "PANIC"
Ewarning = "WARNING"
Enotice = "NOTICE"
Edebug = "DEBUG"
Einfo = "INFO"
Elog = "LOG"
)
// Error represents an error communicating with the server.
//
// Each field holds the value of one field of the server's error message;
// the one-byte field code that parseError maps to it (and that Get accepts)
// is noted next to the field. Fields the server did not send stay empty.
//
// See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields
type Error struct {
Severity string // 'S'
Code ErrorCode // 'C'
Message string // 'M'
Detail string // 'D'
Hint string // 'H'
Position string // 'P'
InternalPosition string // 'p'
InternalQuery string // 'q'
Where string // 'W'
Schema string // 's'
Table string // 't'
Column string // 'c'
DataTypeName string // 'd'
Constraint string // 'n'
File string // 'F'
Line string // 'L'
Routine string // 'R'
}
// ErrorCode is a five-character error code.
type ErrorCode string
// Name returns a more human friendly rendering of the error code, namely the
// "condition name". The name is looked up in errorCodeNames; a code that is
// not in that table yields the empty string.
//
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
// details.
func (ec ErrorCode) Name() string {
return errorCodeNames[ec]
}
// ErrorClass is only the class part of an error code.
type ErrorClass string
// Name returns the condition name of an error class. It is equivalent to the
// condition name of the "standard" error code (i.e. the one having the last
// three characters "000"). A class whose standard code is not in
// errorCodeNames yields the empty string.
func (ec ErrorClass) Name() string {
return errorCodeNames[ErrorCode(ec+"000")]
}
// Class returns the error class, e.g. "28": the first two characters of the
// code. A code shorter than two characters (for example the empty code of an
// Error that carried no code field) has no class, so the empty ErrorClass is
// returned instead of slicing out of range.
//
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
// details.
func (ec ErrorCode) Class() ErrorClass {
	if len(ec) < 2 {
		return ""
	}
	return ErrorClass(ec[:2])
}
// errorCodeNames is a mapping between the five-character error codes and the
// human readable "condition names". It is derived from the list at
// http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html
//
// ErrorCode.Name looks a code up directly; ErrorClass.Name looks up the
// class followed by "000". Codes missing from the table map to "".
var errorCodeNames = map[ErrorCode]string{
// Class 00 - Successful Completion
"00000": "successful_completion",
// Class 01 - Warning
"01000": "warning",
"0100C": "dynamic_result_sets_returned",
"01008": "implicit_zero_bit_padding",
"01003": "null_value_eliminated_in_set_function",
"01007": "privilege_not_granted",
"01006": "privilege_not_revoked",
"01004": "string_data_right_truncation",
"01P01": "deprecated_feature",
// Class 02 - No Data (this is also a warning class per the SQL standard)
"02000": "no_data",
"02001": "no_additional_dynamic_result_sets_returned",
// Class 03 - SQL Statement Not Yet Complete
"03000": "sql_statement_not_yet_complete",
// Class 08 - Connection Exception
"08000": "connection_exception",
"08003": "connection_does_not_exist",
"08006": "connection_failure",
"08001": "sqlclient_unable_to_establish_sqlconnection",
"08004": "sqlserver_rejected_establishment_of_sqlconnection",
"08007": "transaction_resolution_unknown",
"08P01": "protocol_violation",
// Class 09 - Triggered Action Exception
"09000": "triggered_action_exception",
// Class 0A - Feature Not Supported
"0A000": "feature_not_supported",
// Class 0B - Invalid Transaction Initiation
"0B000": "invalid_transaction_initiation",
// Class 0F - Locator Exception
"0F000": "locator_exception",
"0F001": "invalid_locator_specification",
// Class 0L - Invalid Grantor
"0L000": "invalid_grantor",
"0LP01": "invalid_grant_operation",
// Class 0P - Invalid Role Specification
"0P000": "invalid_role_specification",
// Class 0Z - Diagnostics Exception
"0Z000": "diagnostics_exception",
"0Z002": "stacked_diagnostics_accessed_without_active_handler",
// Class 20 - Case Not Found
"20000": "case_not_found",
// Class 21 - Cardinality Violation
"21000": "cardinality_violation",
// Class 22 - Data Exception
"22000": "data_exception",
"2202E": "array_subscript_error",
"22021": "character_not_in_repertoire",
"22008": "datetime_field_overflow",
"22012": "division_by_zero",
"22005": "error_in_assignment",
"2200B": "escape_character_conflict",
"22022": "indicator_overflow",
"22015": "interval_field_overflow",
"2201E": "invalid_argument_for_logarithm",
"22014": "invalid_argument_for_ntile_function",
"22016": "invalid_argument_for_nth_value_function",
"2201F": "invalid_argument_for_power_function",
"2201G": "invalid_argument_for_width_bucket_function",
"22018": "invalid_character_value_for_cast",
"22007": "invalid_datetime_format",
"22019": "invalid_escape_character",
"2200D": "invalid_escape_octet",
"22025": "invalid_escape_sequence",
"22P06": "nonstandard_use_of_escape_character",
"22010": "invalid_indicator_parameter_value",
"22023": "invalid_parameter_value",
"2201B": "invalid_regular_expression",
"2201W": "invalid_row_count_in_limit_clause",
"2201X": "invalid_row_count_in_result_offset_clause",
"22009": "invalid_time_zone_displacement_value",
"2200C": "invalid_use_of_escape_character",
"2200G": "most_specific_type_mismatch",
"22004": "null_value_not_allowed",
"22002": "null_value_no_indicator_parameter",
"22003": "numeric_value_out_of_range",
"2200H": "sequence_generator_limit_exceeded",
"22026": "string_data_length_mismatch",
"22001": "string_data_right_truncation",
"22011": "substring_error",
"22027": "trim_error",
"22024": "unterminated_c_string",
"2200F": "zero_length_character_string",
"22P01": "floating_point_exception",
"22P02": "invalid_text_representation",
"22P03": "invalid_binary_representation",
"22P04": "bad_copy_file_format",
"22P05": "untranslatable_character",
"2200L": "not_an_xml_document",
"2200M": "invalid_xml_document",
"2200N": "invalid_xml_content",
"2200S": "invalid_xml_comment",
"2200T": "invalid_xml_processing_instruction",
// Class 23 - Integrity Constraint Violation
"23000": "integrity_constraint_violation",
"23001": "restrict_violation",
"23502": "not_null_violation",
"23503": "foreign_key_violation",
"23505": "unique_violation",
"23514": "check_violation",
"23P01": "exclusion_violation",
// Class 24 - Invalid Cursor State
"24000": "invalid_cursor_state",
// Class 25 - Invalid Transaction State
"25000": "invalid_transaction_state",
"25001": "active_sql_transaction",
"25002": "branch_transaction_already_active",
"25008": "held_cursor_requires_same_isolation_level",
"25003": "inappropriate_access_mode_for_branch_transaction",
"25004": "inappropriate_isolation_level_for_branch_transaction",
"25005": "no_active_sql_transaction_for_branch_transaction",
"25006": "read_only_sql_transaction",
"25007": "schema_and_data_statement_mixing_not_supported",
"25P01": "no_active_sql_transaction",
"25P02": "in_failed_sql_transaction",
// Class 26 - Invalid SQL Statement Name
"26000": "invalid_sql_statement_name",
// Class 27 - Triggered Data Change Violation
"27000": "triggered_data_change_violation",
// Class 28 - Invalid Authorization Specification
"28000": "invalid_authorization_specification",
"28P01": "invalid_password",
// Class 2B - Dependent Privilege Descriptors Still Exist
"2B000": "dependent_privilege_descriptors_still_exist",
"2BP01": "dependent_objects_still_exist",
// Class 2D - Invalid Transaction Termination
"2D000": "invalid_transaction_termination",
// Class 2F - SQL Routine Exception
"2F000": "sql_routine_exception",
"2F005": "function_executed_no_return_statement",
"2F002": "modifying_sql_data_not_permitted",
"2F003": "prohibited_sql_statement_attempted",
"2F004": "reading_sql_data_not_permitted",
// Class 34 - Invalid Cursor Name
"34000": "invalid_cursor_name",
// Class 38 - External Routine Exception
"38000": "external_routine_exception",
"38001": "containing_sql_not_permitted",
"38002": "modifying_sql_data_not_permitted",
"38003": "prohibited_sql_statement_attempted",
"38004": "reading_sql_data_not_permitted",
// Class 39 - External Routine Invocation Exception
"39000": "external_routine_invocation_exception",
"39001": "invalid_sqlstate_returned",
"39004": "null_value_not_allowed",
"39P01": "trigger_protocol_violated",
"39P02": "srf_protocol_violated",
// Class 3B - Savepoint Exception
"3B000": "savepoint_exception",
"3B001": "invalid_savepoint_specification",
// Class 3D - Invalid Catalog Name
"3D000": "invalid_catalog_name",
// Class 3F - Invalid Schema Name
"3F000": "invalid_schema_name",
// Class 40 - Transaction Rollback
"40000": "transaction_rollback",
"40002": "transaction_integrity_constraint_violation",
"40001": "serialization_failure",
"40003": "statement_completion_unknown",
"40P01": "deadlock_detected",
// Class 42 - Syntax Error or Access Rule Violation
"42000": "syntax_error_or_access_rule_violation",
"42601": "syntax_error",
"42501": "insufficient_privilege",
"42846": "cannot_coerce",
"42803": "grouping_error",
"42P20": "windowing_error",
"42P19": "invalid_recursion",
"42830": "invalid_foreign_key",
"42602": "invalid_name",
"42622": "name_too_long",
"42939": "reserved_name",
"42804": "datatype_mismatch",
"42P18": "indeterminate_datatype",
"42P21": "collation_mismatch",
"42P22": "indeterminate_collation",
"42809": "wrong_object_type",
"42703": "undefined_column",
"42883": "undefined_function",
"42P01": "undefined_table",
"42P02": "undefined_parameter",
"42704": "undefined_object",
"42701": "duplicate_column",
"42P03": "duplicate_cursor",
"42P04": "duplicate_database",
"42723": "duplicate_function",
"42P05": "duplicate_prepared_statement",
"42P06": "duplicate_schema",
"42P07": "duplicate_table",
"42712": "duplicate_alias",
"42710": "duplicate_object",
"42702": "ambiguous_column",
"42725": "ambiguous_function",
"42P08": "ambiguous_parameter",
"42P09": "ambiguous_alias",
"42P10": "invalid_column_reference",
"42611": "invalid_column_definition",
"42P11": "invalid_cursor_definition",
"42P12": "invalid_database_definition",
"42P13": "invalid_function_definition",
"42P14": "invalid_prepared_statement_definition",
"42P15": "invalid_schema_definition",
"42P16": "invalid_table_definition",
"42P17": "invalid_object_definition",
// Class 44 - WITH CHECK OPTION Violation
"44000": "with_check_option_violation",
// Class 53 - Insufficient Resources
"53000": "insufficient_resources",
"53100": "disk_full",
"53200": "out_of_memory",
"53300": "too_many_connections",
"53400": "configuration_limit_exceeded",
// Class 54 - Program Limit Exceeded
"54000": "program_limit_exceeded",
"54001": "statement_too_complex",
"54011": "too_many_columns",
"54023": "too_many_arguments",
// Class 55 - Object Not In Prerequisite State
"55000": "object_not_in_prerequisite_state",
"55006": "object_in_use",
"55P02": "cant_change_runtime_param",
"55P03": "lock_not_available",
// Class 57 - Operator Intervention
"57000": "operator_intervention",
"57014": "query_canceled",
"57P01": "admin_shutdown",
"57P02": "crash_shutdown",
"57P03": "cannot_connect_now",
"57P04": "database_dropped",
// Class 58 - System Error (errors external to PostgreSQL itself)
"58000": "system_error",
"58030": "io_error",
"58P01": "undefined_file",
"58P02": "duplicate_file",
// Class F0 - Configuration File Error
"F0000": "config_file_error",
"F0001": "lock_file_exists",
// Class HV - Foreign Data Wrapper Error (SQL/MED)
"HV000": "fdw_error",
"HV005": "fdw_column_name_not_found",
"HV002": "fdw_dynamic_parameter_value_needed",
"HV010": "fdw_function_sequence_error",
"HV021": "fdw_inconsistent_descriptor_information",
"HV024": "fdw_invalid_attribute_value",
"HV007": "fdw_invalid_column_name",
"HV008": "fdw_invalid_column_number",
"HV004": "fdw_invalid_data_type",
"HV006": "fdw_invalid_data_type_descriptors",
"HV091": "fdw_invalid_descriptor_field_identifier",
"HV00B": "fdw_invalid_handle",
"HV00C": "fdw_invalid_option_index",
"HV00D": "fdw_invalid_option_name",
"HV090": "fdw_invalid_string_length_or_buffer_length",
"HV00A": "fdw_invalid_string_format",
"HV009": "fdw_invalid_use_of_null_pointer",
"HV014": "fdw_too_many_handles",
"HV001": "fdw_out_of_memory",
"HV00P": "fdw_no_schemas",
"HV00J": "fdw_option_name_not_found",
"HV00K": "fdw_reply_handle",
"HV00Q": "fdw_schema_not_found",
"HV00R": "fdw_table_not_found",
"HV00L": "fdw_unable_to_create_execution",
"HV00M": "fdw_unable_to_create_reply",
"HV00N": "fdw_unable_to_establish_connection",
// Class P0 - PL/pgSQL Error
"P0000": "plpgsql_error",
"P0001": "raise_exception",
"P0002": "no_data_found",
"P0003": "too_many_rows",
// Class XX - Internal Error
"XX000": "internal_error",
"XX001": "data_corrupted",
"XX002": "index_corrupted",
}
// parseError decodes the fields of an error or notice message from r into a
// newly allocated Error.
//
// The payload is consumed as a sequence of (field-type byte, string) pairs and
// ends at the first zero field-type byte. Field types that have no matching
// Error member are read and discarded.
func parseError(r *readBuf) *Error {
	pgErr := &Error{}
	for {
		field := r.byte()
		if field == 0 {
			return pgErr
		}
		value := r.string()

		// Every member except Code is a plain string, so select the
		// destination first and store through it once below.
		var dst *string
		switch field {
		case 'C':
			pgErr.Code = ErrorCode(value)
			continue
		case 'S':
			dst = &pgErr.Severity
		case 'M':
			dst = &pgErr.Message
		case 'D':
			dst = &pgErr.Detail
		case 'H':
			dst = &pgErr.Hint
		case 'P':
			dst = &pgErr.Position
		case 'p':
			dst = &pgErr.InternalPosition
		case 'q':
			dst = &pgErr.InternalQuery
		case 'W':
			dst = &pgErr.Where
		case 's':
			dst = &pgErr.Schema
		case 't':
			dst = &pgErr.Table
		case 'c':
			dst = &pgErr.Column
		case 'd':
			dst = &pgErr.DataTypeName
		case 'n':
			dst = &pgErr.Constraint
		case 'F':
			dst = &pgErr.File
		case 'L':
			dst = &pgErr.Line
		case 'R':
			dst = &pgErr.Routine
		default:
			// Unknown field type: the value has already been consumed.
			continue
		}
		*dst = value
	}
}
// Fatal reports whether the error's Severity field equals Efatal.
func (err *Error) Fatal() bool {
return err.Severity == Efatal
}
// SQLState returns the SQLState of the error, i.e. the Code field converted
// to a plain string.
func (err *Error) SQLState() string {
return string(err.Code)
}
// Get returns the value of the error field identified by the single-byte
// field type k (for example 'S' for Severity or 'M' for Message). An
// unrecognized k yields the empty string.
//
// Get implements the legacy PGError interface. New code should use the fields
// of the Error struct directly.
func (err *Error) Get(k byte) (v string) {
	// v starts out as "" and is only overwritten for a known field type.
	switch k {
	case 'S':
		v = err.Severity
	case 'C':
		v = string(err.Code)
	case 'M':
		v = err.Message
	case 'D':
		v = err.Detail
	case 'H':
		v = err.Hint
	case 'P':
		v = err.Position
	case 'p':
		v = err.InternalPosition
	case 'q':
		v = err.InternalQuery
	case 'W':
		v = err.Where
	case 's':
		v = err.Schema
	case 't':
		v = err.Table
	case 'c':
		v = err.Column
	case 'd':
		v = err.DataTypeName
	case 'n':
		v = err.Constraint
	case 'F':
		v = err.File
	case 'L':
		v = err.Line
	case 'R':
		v = err.Routine
	}
	return v
}
// Error implements the error interface. The result is the Message field
// prefixed with "pq: ".
func (err *Error) Error() string {
return "pq: " + err.Message
}
// PGError is an interface used by previous versions of pq. It is provided
// only to support legacy code. New code should use the Error type.
type PGError interface {
// Error returns the error message.
Error() string
// Fatal reports whether the error is fatal.
Fatal() bool
// Get returns the value of the error field identified by the field type
// byte k.
Get(k byte) (v string)
}
// errorf formats s and args like fmt.Sprintf, prefixes the result with
// "pq: ", and panics with the resulting error value. It never returns
// normally.
func errorf(s string, args ...any) {
	msg := fmt.Sprintf(s, args...)
	panic(fmt.Errorf("pq: %s", msg))
}
// fmterrorf formats s and args like fmt.Sprintf and returns the result as an
// error whose message is prefixed with "pq: ".
//
// TODO(ainar-g) Rename to errorf after removing panics.
func fmterrorf(s string, args ...any) error {
	msg := fmt.Sprintf(s, args...)
	return fmt.Errorf("pq: %s", msg)
}
// errRecoverNoErrBadConn converts a panic into an error stored in *err. It
// must be installed directly with defer, because it calls recover itself.
//
// If no panic is in flight, *err is left untouched. A panic value that is an
// error is stored as-is; any other value is wrapped in a
// "pq: unexpected error" error.
func errRecoverNoErrBadConn(err *error) {
	switch e := recover().(type) {
	case nil:
		// No panic in flight; nothing to report.
	case error:
		*err = e
	default:
		*err = fmt.Errorf("pq: unexpected error: %#v", e)
	}
}
// errRecover converts a recovered panic into an error stored in *err and
// records on the connection when it must be treated as bad. It calls recover
// itself, so it has to be installed directly with defer.
//
// Handling by panic value type:
//   - nil: no panic; *err is left as it was.
//   - runtime.Error: the connection is marked bad and the panic is re-raised.
//   - *Error: a fatal error becomes driver.ErrBadConn, otherwise the *Error
//     itself is returned.
//   - *net.OpError: the connection is marked bad and the error is returned.
//   - *safeRetryError: the connection is marked bad and driver.ErrBadConn is
//     returned.
//   - any other error: io.EOF and the TLS handshake failure message become
//     driver.ErrBadConn; everything else is returned unchanged.
//   - any other value: the connection is marked bad and a new panic
//     describing the value is raised.
func (cn *conn) errRecover(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case runtime.Error:
// Runtime errors are not turned into return values; remember the bad
// connection and let the panic continue.
cn.err.set(driver.ErrBadConn)
panic(v)
case *Error:
if v.Fatal() {
*err = driver.ErrBadConn
} else {
*err = v
}
case *net.OpError:
cn.err.set(driver.ErrBadConn)
*err = v
case *safeRetryError:
cn.err.set(driver.ErrBadConn)
*err = driver.ErrBadConn
case error:
// The handshake failure is matched by its message text.
if v == io.EOF || v.Error() == "remote error: handshake failure" {
*err = driver.ErrBadConn
} else {
*err = v
}
default:
cn.err.set(driver.ErrBadConn)
panic(fmt.Sprintf("unknown error: %#v", e))
}
// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
// mark the connection bad in database/sql.
if *err == driver.ErrBadConn {
cn.err.set(driver.ErrBadConn)
}
}
================================================
FILE: gossh/pgsql/hstore/hstore.go
================================================
package hstore
import (
"database/sql"
"database/sql/driver"
"strings"
)
// Hstore is a wrapper for transferring Hstore values back and forth easily.
//
// A nil Map represents a NULL hstore column value. An entry whose
// sql.NullString is not Valid represents a NULL value stored under that key.
type Hstore struct {
	Map map[string]sql.NullString
}

// scanError is the error type returned by Hstore.Scan when the source value
// has a type that cannot be decoded.
type scanError string

// Error implements the error interface.
func (e scanError) Error() string { return string(e) }

// hQuote escapes and quotes an hstore key or value.
//
// s must be a string or a sql.NullString; any other type panics. An invalid
// sql.NullString is rendered as the unquoted literal NULL. Otherwise
// backslashes and double quotes are backslash-escaped and the result is
// wrapped in double quotes.
func hQuote(s any) string {
	var str string
	switch v := s.(type) {
	case sql.NullString:
		if !v.Valid {
			return "NULL"
		}
		str = v.String
	case string:
		str = v
	default:
		panic("not a string or sql.NullString")
	}

	// Escape backslashes first so the backslashes added for the quotes are
	// not escaped a second time.
	str = strings.ReplaceAll(str, `\`, `\\`)
	return `"` + strings.ReplaceAll(str, `"`, `\"`) + `"`
}

// Scan implements the Scanner interface.
//
// value may be nil, a []byte, or a string holding the textual hstore
// representation. Any other type yields an error and leaves h.Map untouched.
//
// Note h.Map is reallocated before the scan to clear existing values. If the
// hstore column's database value is NULL, then h.Map is set to nil instead.
func (h *Hstore) Scan(value any) error {
	if value == nil {
		h.Map = nil
		return nil
	}

	var data []byte
	switch v := value.(type) {
	case []byte:
		data = v
	case string:
		data = []byte(v)
	default:
		// Report unsupported source types as an error instead of panicking
		// on a failed type assertion.
		return scanError("hstore: cannot scan value: expected []byte or string")
	}

	h.Map = make(map[string]sql.NullString)

	var (
		key, val []byte
		cur      = &key // buffer currently being filled: key before "=>", val after
		inQuote  bool   // inside a double-quoted token
		didQuote bool   // the current token was (at least partly) quoted
		sawSlash bool   // previous byte was an unconsumed backslash
	)

	// store records the pair collected so far. An unquoted value spelled
	// "null" in any letter case denotes a NULL value.
	store := func() {
		if !didQuote && len(val) == 4 && strings.EqualFold(string(val), "null") {
			h.Map[string(key)] = sql.NullString{String: "", Valid: false}
			return
		}
		h.Map[string(key)] = sql.NullString{String: string(val), Valid: true}
	}

	for _, b := range data {
		if sawSlash {
			// The byte after a backslash is always taken literally.
			*cur = append(*cur, b)
			sawSlash = false
			continue
		}

		switch b {
		case '\\':
			sawSlash = true
			continue
		case '"':
			inQuote = !inQuote
			didQuote = true
			continue
		}

		if !inQuote {
			switch b {
			case ' ', '\t', '\n', '\r', '=':
				// Whitespace and the '=' of "=>" carry no data.
				continue
			case '>':
				cur = &val
				didQuote = false
				continue
			case ',':
				store()
				// string(key) and string(val) copied the bytes, so the
				// buffers can be reused for the next pair.
				key, val = key[:0], val[:0]
				cur = &key
				continue
			}
		}

		*cur = append(*cur, b)
	}

	// Flush the final pair; inputs shorter than two bytes cannot hold one.
	if len(data) > 1 {
		store()
	}
	return nil
}

// Value implements the driver Valuer interface. Note if h.Map is nil, the
// database column value will be set to NULL.
//
// The pairs are emitted in map iteration order, which is unspecified.
func (h Hstore) Value() (driver.Value, error) {
	if h.Map == nil {
		return nil, nil
	}
	parts := make([]string, 0, len(h.Map))
	for key, val := range h.Map {
		parts = append(parts, hQuote(key)+"=>"+hQuote(val))
	}
	return []byte(strings.Join(parts, ",")), nil
}
================================================
FILE: gossh/pgsql/krb.go
================================================
package pgsql
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
// newGss holds the provider constructor installed by RegisterGSSProvider. Its
// zero value is nil.
var newGss NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/lib/pq/auth/kerberos"
//
// func init() {
// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() })
// }
//
// NOTE(review): the example above uses the upstream lib/pq import path and
// the pq qualifier, while this package is named pgsql; confirm the right
// import path for this repository.
//
// The registration is a plain assignment with no locking; a later call
// replaces the previously registered provider.
func RegisterGSSProvider(newGssArg NewGSSFunc) {
newGss = newGssArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos). Values of this type
// are produced by a NewGSSFunc registered through RegisterGSSProvider.
type GSS interface {
GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSpn(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
================================================
FILE: gossh/pgsql/notice.go
================================================
//go:build go1.10
// +build go1.10
package pgsql
import (
"context"
"database/sql/driver"
)
// NoticeHandler returns the notice handler on the given connection, if any. A
// runtime panic occurs if c is not a pq connection. This is rarely used
// directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead.
func NoticeHandler(c driver.Conn) func(*Error) {
// Unchecked type assertion: this is the documented panic for foreign
// connection types.
return c.(*conn).noticeHandler
}
// SetNoticeHandler sets the given notice handler on the given connection. A
// runtime panic occurs if c is not a pq connection. A nil handler may be used
// to unset it. This is rarely used directly, use ConnectorNoticeHandler and
// ConnectorWithNoticeHandler instead.
//
// Note: Notice handlers are executed synchronously by pq meaning commands
// won't continue to be processed until the handler returns.
func SetNoticeHandler(c driver.Conn, handler func(*Error)) {
// Unchecked type assertion: this is the documented panic for foreign
// connection types.
c.(*conn).noticeHandler = handler
}
// NoticeHandlerConnector wraps a regular connector and installs a notice
// handler on every connection opened through it.
type NoticeHandlerConnector struct {
	driver.Connector
	noticeHandler func(*Error)
}

// Connect opens a connection through the embedded connector and, when that
// succeeds, sets the notice handler on it. On failure the embedded connector's
// results are returned unchanged.
func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) {
	dc, err := n.Connector.Connect(ctx)
	if err != nil {
		return dc, err
	}
	SetNoticeHandler(dc, n.noticeHandler)
	return dc, nil
}
// ConnectorNoticeHandler returns the notice handler currently set on c, if
// any. When c was not produced by ConnectorWithNoticeHandler (that is, it is
// not a *NoticeHandlerConnector), nil is returned.
func ConnectorNoticeHandler(c driver.Connector) func(*Error) {
	wrapped, ok := c.(*NoticeHandlerConnector)
	if !ok {
		return nil
	}
	return wrapped.noticeHandler
}
// ConnectorWithNoticeHandler attaches handler to the connector c.
//
// If c is already a *NoticeHandlerConnector (the result of an earlier call to
// this function), its handler is replaced in place and the same connector is
// returned. Otherwise a new connector wrapping c is returned. A nil handler
// may be used to unset it.
//
// The returned connector is intended to be used with database/sql.OpenDB.
//
// Note: Notice handlers are executed synchronously by pq meaning commands
// won't continue to be processed until the handler returns.
func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector {
	existing, ok := c.(*NoticeHandlerConnector)
	if !ok {
		return &NoticeHandlerConnector{Connector: c, noticeHandler: handler}
	}
	existing.noticeHandler = handler
	return existing
}
================================================
FILE: gossh/pgsql/notify.go
================================================
package pgsql
// This file is part of a pure Go Postgres driver for the database/sql
// package. It contains the support for Postgres LISTEN/NOTIFY.
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
)
// Notification represents a single notification from the database.
//
// NOTE: keep the field order stable; code in this package may construct the
// struct positionally.
type Notification struct {
// Process ID (PID) of the notifying postgres backend.
BePid int
// Name of the channel the notification was sent on.
Channel string
// Payload, or the empty string if unspecified.
Extra string
}
// recvNotification decodes a notification message from r. The fields are read
// in wire order: the notifying backend's PID, the channel name, and then the
// payload.
func recvNotification(r *readBuf) *Notification {
	n := new(Notification)
	n.BePid = r.int32()
	n.Channel = r.string()
	n.Extra = r.string()
	return n
}
// SetNotificationHandler sets the given notification handler on the given
// connection. A runtime panic occurs if c is not a pq connection. A nil handler
// may be used to unset it.
//
// Note: Notification handlers are executed synchronously by pq meaning commands
// won't continue to be processed until the handler returns.
func SetNotificationHandler(c driver.Conn, handler func(*Notification)) {
// Unchecked type assertion: this is the documented panic for foreign
// connection types.
c.(*conn).notificationHandler = handler
}
// NotificationHandlerConnector wraps a regular connector and installs a
// notification handler on every connection opened through it.
type NotificationHandlerConnector struct {
	driver.Connector
	notificationHandler func(*Notification)
}

// Connect opens a connection through the embedded connector and, when that
// succeeds, sets the notification handler on it. On failure the embedded
// connector's results are returned unchanged.
func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) {
	dc, err := n.Connector.Connect(ctx)
	if err != nil {
		return dc, err
	}
	SetNotificationHandler(dc, n.notificationHandler)
	return dc, nil
}
// ConnectorNotificationHandler returns the notification handler currently set
// on c, if any. When c was not produced by ConnectorWithNotificationHandler
// (that is, it is not a *NotificationHandlerConnector), nil is returned.
func ConnectorNotificationHandler(c driver.Connector) func(*Notification) {
	wrapped, ok := c.(*NotificationHandlerConnector)
	if !ok {
		return nil
	}
	return wrapped.notificationHandler
}
// ConnectorWithNotificationHandler attaches handler to the connector c.
//
// If c is already a *NotificationHandlerConnector (the result of an earlier
// call to this function), its handler is replaced in place and the same
// connector is returned. Otherwise a new connector wrapping c is returned. A
// nil notification handler may be used to unset it.
//
// The returned connector is intended to be used with database/sql.OpenDB.
//
// Note: Notification handlers are executed synchronously by pq meaning commands
// won't continue to be processed until the handler returns.
func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector {
	existing, ok := c.(*NotificationHandlerConnector)
	if !ok {
		return &NotificationHandlerConnector{Connector: c, notificationHandler: handler}
	}
	existing.notificationHandler = handler
	return existing
}
// Protocol states of a ListenerConn, stored in ListenerConn.connState. The
// only transitions setState accepts are Idle -> ExpectResponse ->
// ExpectReadyForQuery -> Idle.
const (
// No query is in flight.
connStateIdle int32 = iota
// A query is about to be, or has been, sent; set by sendSimpleQuery.
connStateExpectResponse
// An 'E', 'C' or 'I' message was received; a 'Z' message is expected next.
connStateExpectReadyForQuery
)
// message is a query reply forwarded from the receiving goroutine to
// ExecSimpleQuery over ListenerConn.replyChan.
type message struct {
// Message type byte; listenerConnLoop forwards 'E' and 'Z'.
typ byte
// Parsed error for an 'E' message, nil for 'Z'.
err error
}
// errListenerConnClosed is recorded and returned by ListenerConn.Close, and
// returned again when Close is called on a connection that already has an
// error set.
var errListenerConnClosed = errors.New("pq: ListenerConn has been closed")
// ListenerConn is a low-level interface for waiting for notifications. You
// should use Listener instead.
type ListenerConn struct {
// guards cn and err
connectionLock sync.Mutex
// underlying driver connection
cn *conn
// first error that terminated the connection; nil while it is usable
err error
// current protocol state (one of the connState* constants); updated
// atomically by setState
connState int32
// the sending goroutine will be holding this lock
senderLock sync.Mutex
// receives every notification read from the server; closed by
// listenerConnMain when the connection is done
notificationChan chan<- *Notification
// carries query replies from listenerConnLoop to ExecSimpleQuery; closed
// by listenerConnMain when the connection is done
replyChan chan message
}
// NewListenerConn creates a new ListenerConn. Use NewListener instead.
//
// name is the connection string; the connection is opened with the default
// dialer. Notifications received on the connection are sent to
// notificationChan.
func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) {
return newDialListenerConn(defaultDialer{}, name, notificationChan)
}
// newDialListenerConn opens a connection to name using the dialer d, wraps it
// in a ListenerConn that forwards notifications to c, and starts the
// goroutine that receives from the connection. If opening the connection
// fails, that error is returned and nothing is started.
func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) {
	dc, err := DialOpen(d, name)
	if err != nil {
		return nil, err
	}

	lc := new(ListenerConn)
	lc.cn = dc.(*conn)
	lc.notificationChan = c
	lc.connState = connStateIdle
	// Buffered so the receiving goroutine can hand over replies without
	// waiting for the sender to pick them up.
	lc.replyChan = make(chan message, 2)

	go lc.listenerConnMain()
	return lc, nil
}
// acquireSenderLock takes senderLock on behalf of a goroutine that wants to
// send on the connection. Only one goroutine at a time may be running a query
// on the connection, so every sender must hold this lock.
//
// It returns an error, with senderLock released again, if an unrecoverable
// error has already been recorded and the ListenerConn should be abandoned.
func (l *ListenerConn) acquireSenderLock() error {
	// we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery
	l.senderLock.Lock()

	l.connectionLock.Lock()
	connErr := l.err
	l.connectionLock.Unlock()

	if connErr == nil {
		return nil
	}
	l.senderLock.Unlock()
	return connErr
}
// releaseSenderLock releases senderLock; it is the counterpart of a
// successful acquireSenderLock.
func (l *ListenerConn) releaseSenderLock() {
l.senderLock.Unlock()
}
// setState advances the protocol state to newState with an atomic
// compare-and-swap. It returns false, leaving the state unchanged, if the
// current state is not the one from which newState may be entered. It panics
// if newState is not one of the connState* constants.
func (l *ListenerConn) setState(newState int32) bool {
	// Each state has exactly one legal predecessor.
	var from int32
	switch newState {
	case connStateExpectResponse:
		from = connStateIdle
	case connStateExpectReadyForQuery:
		from = connStateExpectResponse
	case connStateIdle:
		from = connStateExpectReadyForQuery
	default:
		panic(fmt.Sprintf("unexpected listenerConnState %d", newState))
	}
	return atomic.CompareAndSwapInt32(&l.connState, from, newState)
}
// Main logic is here: receive messages from the postgres backend, forward
// notifications and query replies and keep the internal state in sync with the
// protocol state. Returns when the connection has been lost, is about to go
// away or should be discarded because we couldn't agree on the state with the
// server backend.
//
// The returned error is never nil on a normal exit path: the loop only ends
// on a receive error, a protocol mismatch, or a panic, which the deferred
// errRecoverNoErrBadConn converts into the returned error.
func (l *ListenerConn) listenerConnLoop() (err error) {
defer errRecoverNoErrBadConn(&err)
r := &readBuf{}
for {
// Note: this err shadows the named result; every exit below returns
// its value explicitly.
t, err := l.cn.recvMessage(r)
if err != nil {
return err
}
switch t {
case 'A':
// recvNotification copies all the data so we don't need to worry
// about the scratch buffer being overwritten.
l.notificationChan <- recvNotification(r)
case 'T', 'D':
// only used by tests; ignore
case 'E':
// We might receive an ErrorResponse even when not in a query; it
// is expected that the server will close the connection after
// that, but we should make sure that the error we display is the
// one from the stray ErrorResponse, not io.ErrUnexpectedEOF.
if !l.setState(connStateExpectReadyForQuery) {
return parseError(r)
}
// In-query error: hand it to the goroutine waiting in ExecSimpleQuery.
l.replyChan <- message{t, parseError(r)}
case 'C', 'I':
if !l.setState(connStateExpectReadyForQuery) {
// protocol out of sync
return fmt.Errorf("unexpected CommandComplete")
}
// ExecSimpleQuery doesn't need to know about this message
case 'Z':
if !l.setState(connStateIdle) {
// protocol out of sync
return fmt.Errorf("unexpected ReadyForQuery")
}
// Tell the waiting ExecSimpleQuery that the query is finished.
l.replyChan <- message{t, nil}
case 'S':
// ignore
case 'N':
// Notice: pass it to the connection's notice handler, if one is set.
if n := l.cn.noticeHandler; n != nil {
n(parseError(r))
}
default:
return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t)
}
}
}
// This is the main routine for the goroutine receiving on the database
// connection. Most of the main logic is in listenerConnLoop.
//
// Once the loop returns it records the error (unless one is already set),
// closes the underlying connection, and then closes replyChan and
// notificationChan, in that order.
func (l *ListenerConn) listenerConnMain() {
err := l.listenerConnLoop()
// listenerConnLoop terminated; we're done, but we still have to clean up.
// Make sure nobody tries to start any new queries by making sure the err
// pointer is set. It is important that we do not overwrite its value; a
// connection could be closed by either this goroutine or one sending on
// the connection -- whoever closes the connection is assumed to have the
// more meaningful error message (as the other one will probably get
// net.errClosed), so that goroutine sets the error we expose while the
// other error is discarded. If the connection is lost while two
// goroutines are operating on the socket, it probably doesn't matter which
// error we expose so we don't try to do anything more complex.
l.connectionLock.Lock()
if l.err == nil {
l.err = err
}
l.cn.Close()
l.connectionLock.Unlock()
// There might be a query in-flight; make sure nobody's waiting for a
// response to it, since there's not going to be one.
close(l.replyChan)
// let the listener know we're done
close(l.notificationChan)
// this ListenerConn is done
}
// Listen sends a LISTEN query to the server. See ExecSimpleQuery.
//
// The channel name is passed through QuoteIdentifier before being placed in
// the query. The return values are those of ExecSimpleQuery.
func (l *ListenerConn) Listen(channel string) (bool, error) {
return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel))
}
// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery.
//
// The channel name is passed through QuoteIdentifier before being placed in
// the query. The return values are those of ExecSimpleQuery.
func (l *ListenerConn) Unlisten(channel string) (bool, error) {
return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel))
}
// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery.
// The return values are those of ExecSimpleQuery.
func (l *ListenerConn) UnlistenAll() (bool, error) {
return l.ExecSimpleQuery("UNLISTEN *")
}
// Ping checks that the remote server is alive by executing an empty query.
// A non-nil error means the connection has failed and should be abandoned.
//
// It panics if the empty query was executed but the server answered it with
// an error, which is not expected to happen.
func (l *ListenerConn) Ping() error {
	executed, err := l.ExecSimpleQuery("")
	switch {
	case !executed:
		return err
	case err != nil:
		// shouldn't happen
		panic(err)
	default:
		return nil
	}
}
// Attempt to send a query on the connection. Returns an error if sending the
// query failed, and the caller should initiate closure of this connection.
// The caller must be holding senderLock (see acquireSenderLock and
// releaseSenderLock).
//
// Panics raised while sending are converted into the returned error by the
// deferred errRecoverNoErrBadConn; that includes the "two queries running"
// panic below.
func (l *ListenerConn) sendSimpleQuery(q string) (err error) {
defer errRecoverNoErrBadConn(&err)
// must set connection state before sending the query
if !l.setState(connStateExpectResponse) {
panic("two queries running at the same time")
}
// Can't use l.cn.writeBuf here because it uses the scratch buffer which
// might get overwritten by listenerConnLoop.
// The private buffer starts with the 'Q' message type byte followed by
// four zero bytes (presumably the length placeholder filled in on send;
// confirm against writeBuf).
b := &writeBuf{
buf: []byte("Q\x00\x00\x00\x00"),
pos: 1,
}
b.string(q)
l.cn.send(b)
return nil
}
// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable
// parameters) on the connection. The possible return values are:
// 1. "executed" is true; the query was executed to completion on the
// database server. If the query failed, err will be set to the error
// returned by the database, otherwise err will be nil.
// 2. If "executed" is false, the query could not be executed on the remote
// server. err will be non-nil.
//
// After a call to ExecSimpleQuery has returned an executed=false value, the
// connection has either been closed or will be closed shortly thereafter, and
// all subsequently executed queries will return an error.
//
// senderLock is held for the whole call, so concurrent callers are
// serialized.
func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) {
if err = l.acquireSenderLock(); err != nil {
return false, err
}
defer l.releaseSenderLock()
err = l.sendSimpleQuery(q)
if err != nil {
// We can't know what state the protocol is in, so we need to abandon
// this connection.
l.connectionLock.Lock()
// Set the error pointer if it hasn't been set already; see
// listenerConnMain.
if l.err == nil {
l.err = err
}
l.connectionLock.Unlock()
// Closing the net.Conn wakes up the receiving goroutine.
l.cn.c.Close()
return false, err
}
// now we just wait for a reply..
for {
m, ok := <-l.replyChan
if !ok {
// We lost the connection to server, don't bother waiting for
// a response. err should have been set already.
l.connectionLock.Lock()
err := l.err
l.connectionLock.Unlock()
return false, err
}
switch m.typ {
case 'Z':
// sanity check
if m.err != nil {
panic("m.err != nil")
}
// done; err might or might not be set
return true, err
case 'E':
// sanity check
if m.err == nil {
panic("m.err == nil")
}
// server responded with an error; ReadyForQuery to follow
err = m.err
default:
return false, fmt.Errorf("unknown response for simple query: %q", m.typ)
}
}
}
// Close closes the connection.
//
// If an error has already been recorded on the connection (including by an
// earlier Close), errListenerConnClosed is returned and nothing else happens.
// Otherwise errListenerConnClosed is recorded as the connection's error and
// the result of closing the underlying net.Conn is returned.
func (l *ListenerConn) Close() error {
	l.connectionLock.Lock()
	alreadyDone := l.err != nil
	if !alreadyDone {
		l.err = errListenerConnClosed
	}
	l.connectionLock.Unlock()

	if alreadyDone {
		return errListenerConnClosed
	}
	// We can't send anything on the connection without holding senderLock.
	// Simply close the net.Conn to wake up everyone operating on it.
	return l.cn.c.Close()
}
// Err returns the reason the connection was closed. It is not safe to call
// this function until l.Notify has been closed.
//
// The field is read without taking connectionLock, hence the restriction
// above.
func (l *ListenerConn) Err() error {
return l.err
}
// errListenerClosed is returned by Listener methods once the Listener has
// been closed.
var errListenerClosed = errors.New("pq: Listener has been closed")
// ErrChannelAlreadyOpen is returned from Listen when a channel is already
// open.
var ErrChannelAlreadyOpen = errors.New("pq: channel is already open")
// ErrChannelNotOpen is returned from Unlisten when a channel is not open.
var ErrChannelNotOpen = errors.New("pq: channel is not open")
// ListenerEventType is an enumeration of listener event types. Values are
// passed to the Listener's event callback.
type ListenerEventType int
const (
// ListenerEventConnected is emitted only when the database connection
// has been initially established. The err argument of the callback
// will always be nil.
ListenerEventConnected ListenerEventType = iota
// ListenerEventDisconnected is emitted after a database connection has
// been lost, either because of an error or because Close has been
// called. The err argument will be set to the reason the database
// connection was lost.
ListenerEventDisconnected
// ListenerEventReconnected is emitted after a database connection has
// been re-established after connection loss. The err argument of the
// callback will always be nil. After this event has been emitted, a
// nil pq.Notification is sent on the Listener.Notify channel.
ListenerEventReconnected
// ListenerEventConnectionAttemptFailed is emitted after a connection
// to the database was attempted, but failed. The err argument will be
// set to an error describing why the connection attempt did not
// succeed.
ListenerEventConnectionAttemptFailed
)
// EventCallbackType is the event callback type. See also ListenerEventType
// constants' documentation.
//
// event identifies what happened; err carries the associated error, or nil
// for events that have none.
type EventCallbackType func(event ListenerEventType, err error)
// Listener provides an interface for listening to notifications from a
// PostgreSQL database. For general usage information, see section
// "Notifications".
//
// Listener can safely be used from concurrently running goroutines.
type Listener struct {
// Channel for receiving notifications from the database. In some cases a
// nil value will be sent. See section "Notifications" above.
Notify chan *Notification
// connection string used for every (re)connection attempt
name string
// initial delay between reconnection attempts
minReconnectInterval time.Duration
// upper bound for the doubled reconnection delay
maxReconnectInterval time.Duration
// dialer used to open connections
dialer Dialer
// optional callback invoked on connection state changes; may be nil
eventCallback EventCallbackType
// taken by the Listener's methods before they touch the fields below
lock sync.Mutex
// set by Close
isClosed bool
// condition variable on lock; broadcast after a successful connect and on
// Close
reconnectCond *sync.Cond
// current connection, or nil while disconnected
cn *ListenerConn
// notification channel of the current connection
connNotificationChan <-chan *Notification
// set of channel names the Listener should be listening on
channels map[string]struct{}
}
// NewListener creates a new database connection dedicated to LISTEN / NOTIFY.
//
// name should be set to a connection string to be used to establish the
// database connection (see section "Connection String Parameters" above).
//
// minReconnectInterval controls the duration to wait before trying to
// re-establish the database connection after connection loss. After each
// consecutive failure this interval is doubled, until maxReconnectInterval is
// reached. Successfully completing the connection establishment procedure
// resets the interval back to minReconnectInterval.
//
// The last parameter eventCallback can be set to a function which will be
// called by the Listener when the state of the underlying database connection
// changes. This callback will be called by the goroutine which dispatches the
// notifications over the Notify channel, so you should try to avoid doing
// potentially time-consuming operations from the callback.
//
// NewListener is NewDialListener with the default dialer.
func NewListener(name string,
minReconnectInterval time.Duration,
maxReconnectInterval time.Duration,
eventCallback EventCallbackType) *Listener {
return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback)
}
// NewDialListener is like NewListener but it takes a Dialer, which is used
// for every connection attempt.
//
// The returned Listener has an empty channel set and a Notify channel with a
// buffer of 32 notifications. Its background goroutine is started before the
// function returns.
func NewDialListener(d Dialer,
	name string,
	minReconnectInterval time.Duration,
	maxReconnectInterval time.Duration,
	eventCallback EventCallbackType) *Listener {

	l := new(Listener)
	l.name = name
	l.minReconnectInterval = minReconnectInterval
	l.maxReconnectInterval = maxReconnectInterval
	l.dialer = d
	l.eventCallback = eventCallback
	l.channels = make(map[string]struct{})
	l.Notify = make(chan *Notification, 32)
	// The condition variable shares the Listener's own mutex.
	l.reconnectCond = sync.NewCond(&l.lock)

	go l.listenerMain()
	return l
}
// NotificationChannel returns the notification channel for this listener.
// This is the same channel as Notify, and will not be recreated during the
// life time of the Listener. It is returned as a receive-only channel.
func (l *Listener) NotificationChannel() <-chan *Notification {
return l.Notify
}
// Listen starts listening for notifications on a channel. Calls to this
// function will block until an acknowledgement has been received from the
// server. Note that Listener automatically re-establishes the connection
// after connection loss, so this function may block indefinitely if the
// connection can not be re-established.
//
// Listen will only fail in three conditions:
// 1. The channel is already open. The returned error will be
// ErrChannelAlreadyOpen.
// 2. The query was executed on the remote server, but PostgreSQL returned an
// error message in response to the query. The returned error will be a
// pq.Error containing the information the server supplied.
// 3. Close is called on the Listener before the request could be completed.
//
// The channel name is case-sensitive.
func (l *Listener) Listen(channel string) error {
l.lock.Lock()
defer l.lock.Unlock()
if l.isClosed {
return errListenerClosed
}
// The server allows you to issue a LISTEN on a channel which is already
// open, but it seems useful to be able to detect this case to spot for
// mistakes in application logic. If the application genuinely doesn't
// care, it can check the exported error and ignore it.
_, exists := l.channels[channel]
if exists {
return ErrChannelAlreadyOpen
}
if l.cn != nil {
// If gotResponse is true but error is set, the query was executed on
// the remote server, but resulted in an error. This should be
// relatively rare, so it's fine if we just pass the error to our
// caller. However, if gotResponse is false, we could not complete the
// query on the remote server and our underlying connection is about
// to go away, so we only add channel to l.channels, and wait for
// resync() to take care of the rest.
gotResponse, err := l.cn.Listen(channel)
if gotResponse && err != nil {
return err
}
}
l.channels[channel] = struct{}{}
// While there is no connection, wait for connect() or Close() to
// broadcast on reconnectCond.
for l.cn == nil {
l.reconnectCond.Wait()
// we let go of the mutex for a while
if l.isClosed {
return errListenerClosed
}
}
return nil
}
// Unlisten removes a channel from the Listener's channel list. Returns
// ErrChannelNotOpen if the Listener is not listening on the specified channel.
// Returns immediately with no error if there is no connection. Note that you
// might still get notifications for this channel even after Unlisten has
// returned.
//
// The channel name is case-sensitive.
func (l *Listener) Unlisten(channel string) error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return errListenerClosed
	}

	// Similarly to LISTEN, this is not an error in Postgres, but it seems
	// useful to distinguish from the normal conditions.
	if _, open := l.channels[channel]; !open {
		return ErrChannelNotOpen
	}

	if l.cn != nil {
		// Similarly to Listen (see comment in that function), the caller
		// should only be bothered with an error if it came from the backend as
		// a response to our query.
		if executed, err := l.cn.Unlisten(channel); executed && err != nil {
			return err
		}
	}

	// Don't bother waiting for resync if there's no connection.
	delete(l.channels, channel)
	return nil
}
// UnlistenAll removes all channels from the Listener's channel list. Returns
// immediately with no error if there is no connection. Note that you might
// still get notifications for any of the deleted channels even after
// UnlistenAll has returned.
func (l *Listener) UnlistenAll() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return errListenerClosed
	}

	if l.cn != nil {
		// Similarly to Listen (see comment in that function), the caller
		// should only be bothered with an error if it came from the backend as
		// a response to our query.
		if executed, err := l.cn.UnlistenAll(); executed && err != nil {
			return err
		}
	}

	// Don't bother waiting for resync if there's no connection.
	l.channels = make(map[string]struct{})
	return nil
}
// Ping checks that the remote server is alive. A non-nil return value means
// that there is no active connection: the Listener has been closed, it is
// currently disconnected, or the ping on the current connection failed.
func (l *Listener) Ping() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	switch {
	case l.isClosed:
		return errListenerClosed
	case l.cn == nil:
		return errors.New("no connection")
	}
	return l.cn.Ping()
}
// Clean up after losing the server connection. Returns l.cn.Err(), which
// should have the reason the connection was lost.
//
// It panics if connNotificationChan has not been closed yet. On return l.cn
// is nil.
func (l *Listener) disconnectCleanup() error {
l.lock.Lock()
defer l.lock.Unlock()
// sanity check; can't look at Err() until the channel has been closed
select {
case _, ok := <-l.connNotificationChan:
if ok {
panic("connNotificationChan not closed")
}
default:
// A receive that would block means the channel is still open.
panic("connNotificationChan not closed")
}
err := l.cn.Err()
l.cn.Close()
l.cn = nil
return err
}
// Synchronize the list of channels we want to be listening on with the server
// after the connection has been established.
//
// A helper goroutine issues LISTEN for every entry in l.channels on cn while
// this function drains notificationChan. The result is nil on success, the
// server's error if a LISTEN was executed but failed, or cn.Err() if the
// server could not be reached.
func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error {
doneChan := make(chan error)
go func(notificationChan <-chan *Notification) {
for channel := range l.channels {
// If we got a response, return that error to our caller as it's
// going to be more descriptive than cn.Err().
gotResponse, err := cn.Listen(channel)
if gotResponse && err != nil {
doneChan <- err
return
}
// If we couldn't reach the server, wait for notificationChan to
// close and then return the error message from the connection, as
// per ListenerConn's interface.
if err != nil {
for range notificationChan {
}
doneChan <- cn.Err()
return
}
}
doneChan <- nil
}(notificationChan)
// Ignore notifications while synchronization is going on to avoid
// deadlocks. We have to send a nil notification over Notify anyway as
// we can't possibly know which notifications (if any) were lost while
// the connection was down, so there's no reason to try and process
// these messages at all.
for {
select {
case _, ok := <-notificationChan:
if !ok {
// A nil channel blocks forever, so this case is never selected
// again once the channel has been closed.
notificationChan = nil
}
case err := <-doneChan:
return err
}
}
}
// closed reports whether Close has been called on the Listener.
//
// caller should NOT be holding l.lock
func (l *Listener) closed() bool {
l.lock.Lock()
defer l.lock.Unlock()
return l.isClosed
}
// connect opens a new ListenerConn, re-issues LISTEN for every channel in
// l.channels, and on success installs the connection as l.cn and wakes up the
// goroutines waiting on reconnectCond. If dialing or resynchronizing fails
// the error is returned and l.cn is left unchanged.
func (l *Listener) connect() error {
notificationChan := make(chan *Notification, 32)
cn, err := newDialListenerConn(l.dialer, l.name, notificationChan)
if err != nil {
return err
}
// The lock is held for the whole resync, so the channel set cannot change
// while it is being replayed.
l.lock.Lock()
defer l.lock.Unlock()
err = l.resync(cn, notificationChan)
if err != nil {
cn.Close()
return err
}
l.cn = cn
l.connNotificationChan = notificationChan
l.reconnectCond.Broadcast()
return nil
}
// Close disconnects the Listener from the database and shuts it down.
// Subsequent calls to its methods will return an error. Close returns an
// error if the connection has already been closed.
func (l *Listener) Close() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return errListenerClosed
	}
	l.isClosed = true

	if lc := l.cn; lc != nil {
		lc.Close()
	}

	// Unblock calls to Listen()
	l.reconnectCond.Broadcast()
	return nil
}
// emitEvent forwards event and err to the Listener's event callback.
// It is a no-op when no callback has been configured.
func (l *Listener) emitEvent(event ListenerEventType, err error) {
	callback := l.eventCallback
	if callback == nil {
		return
	}
	callback(event, err)
}
// Main logic here: maintain a connection to the server when possible, wait
// for notifications and emit events.
//
// The loop returns only once closed() reports true, which is checked after a
// failed connection attempt and after a lost connection.
func (l *Listener) listenerConnLoop() {
// nextReconnect stays zero until the first successful connect; that is how
// the first connection is told apart from a reconnection below.
var nextReconnect time.Time
reconnectInterval := l.minReconnectInterval
for {
// Keep trying to connect, doubling the delay after every failure up to
// l.maxReconnectInterval.
for {
err := l.connect()
if err == nil {
break
}
if l.closed() {
return
}
l.emitEvent(ListenerEventConnectionAttemptFailed, err)
time.Sleep(reconnectInterval)
reconnectInterval *= 2
if reconnectInterval > l.maxReconnectInterval {
reconnectInterval = l.maxReconnectInterval
}
}
if nextReconnect.IsZero() {
l.emitEvent(ListenerEventConnected, nil)
} else {
l.emitEvent(ListenerEventReconnected, nil)
// A nil notification is sent on Notify after every reconnection.
l.Notify <- nil
}
// Connected: reset the backoff and record the earliest time at which the
// next connection attempt may start.
reconnectInterval = l.minReconnectInterval
nextReconnect = time.Now().Add(reconnectInterval)
// Forward notifications until the connection's channel is closed.
for {
notification, ok := <-l.connNotificationChan
if !ok {
// lost connection, loop again
break
}
l.Notify <- notification
}
err := l.disconnectCleanup()
if l.closed() {
return
}
l.emitEvent(ListenerEventDisconnected, err)
// Wait out the remainder of the minimum interval measured from the last
// successful connect before trying again.
time.Sleep(time.Until(nextReconnect))
}
}
// listenerMain runs the connection loop until it returns and then closes
// l.Notify, so receivers ranging over Notify see the channel end.
func (l *Listener) listenerMain() {
l.listenerConnLoop()
close(l.Notify)
}
================================================
FILE: gossh/pgsql/oid/doc.go
================================================
// Package oid contains OID constants
// as defined by the Postgres server.
package oid
// Oid is a Postgres Object ID, represented as an unsigned 32-bit integer.
type Oid uint32
================================================
FILE: gossh/pgsql/oid/gen.go
================================================
//go:build ignore
// +build ignore
// Generate the table of OID values
// Run with 'go run gen.go'.
package main
import (
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strings"
_ "gossh/pgsql"
)
// OID represents a postgres Object Identifier Type: one row read from
// pg_type by main.
type OID struct {
// ID is scanned from the pg_type "oid" column.
ID int
// Type is scanned from the pg_type "typname" column.
Type string
}
// Name returns an upper case version of the oid type. main writes it as the
// string value of the generated TypeName map.
func (o OID) Name() string {
return strings.ToUpper(o.Type)
}
// main regenerates types.go in the current directory.
//
// It connects through database/sql using the "postgres" driver and an empty
// connection string (so the PG* environment variables apply; PGDATABASE and
// PGSSLMODE are defaulted when unset), reads every pg_type row with
// oid < 10000, and pipes the generated source through gofmt into types.go.
// Any failure, including a non-zero gofmt exit, aborts with log.Fatal so a
// broken run is not mistaken for a successful one.
func main() {
	// Only fill in defaults; never override what the caller exported.
	if os.Getenv("PGDATABASE") == "" {
		if err := os.Setenv("PGDATABASE", "pqgotest"); err != nil {
			log.Fatal(err)
		}
	}
	if os.Getenv("PGSSLMODE") == "" {
		if err := os.Setenv("PGSSLMODE", "disable"); err != nil {
			log.Fatal(err)
		}
	}

	db, err := sql.Open("postgres", "")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query(`
SELECT typname, oid
FROM pg_type WHERE oid < 10000
ORDER BY oid;
`)
	if err != nil {
		log.Fatal(err)
	}
	oids := make([]*OID, 0)
	for rows.Next() {
		var oid OID
		if err = rows.Scan(&oid.Type, &oid.ID); err != nil {
			log.Fatal(err)
		}
		oids = append(oids, &oid)
	}
	if err = rows.Err(); err != nil {
		log.Fatal(err)
	}
	if err = rows.Close(); err != nil {
		log.Fatal(err)
	}

	// gofmt reads the raw source on stdin and writes the formatted result
	// straight into types.go.
	cmd := exec.Command("gofmt")
	cmd.Stderr = os.Stderr
	w, err := cmd.StdinPipe()
	if err != nil {
		log.Fatal(err)
	}
	f, err := os.Create("types.go")
	if err != nil {
		log.Fatal(err)
	}
	cmd.Stdout = f
	if err = cmd.Start(); err != nil {
		log.Fatal(err)
	}

	fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.")
	fmt.Fprintln(w, "\npackage oid")
	fmt.Fprintln(w, "const (")
	for _, oid := range oids {
		fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID)
	}
	fmt.Fprintln(w, ")")
	fmt.Fprintln(w, "var TypeName = map[Oid]string{")
	for _, oid := range oids {
		fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name())
	}
	fmt.Fprintln(w, "}")

	// Closing stdin lets gofmt finish; a failed write to the pipe shows up
	// here or as a gofmt failure in Wait.
	if err = w.Close(); err != nil {
		log.Fatal(err)
	}
	if err = cmd.Wait(); err != nil {
		log.Fatal(err)
	}
	if err = f.Close(); err != nil {
		log.Fatal(err)
	}
}
================================================
FILE: gossh/pgsql/oid/types.go
================================================
// Code generated by gen.go. DO NOT EDIT.
package oid
// Type OIDs, one constant per pg_type row emitted by gen.go.
// Each name is "T_" followed by the row's typname, so names with a double
// underscore (for example T__int4) come from typnames that themselves start
// with an underscore.
const (
T_bool Oid = 16
T_bytea Oid = 17
T_char Oid = 18
T_name Oid = 19
T_int8 Oid = 20
T_int2 Oid = 21
T_int2vector Oid = 22
T_int4 Oid = 23
T_regproc Oid = 24
T_text Oid = 25
T_oid Oid = 26
T_tid Oid = 27
T_xid Oid = 28
T_cid Oid = 29
T_oidvector Oid = 30
T_pg_ddl_command Oid = 32
T_pg_type Oid = 71
T_pg_attribute Oid = 75
T_pg_proc Oid = 81
T_pg_class Oid = 83
T_json Oid = 114
T_xml Oid = 142
T__xml Oid = 143
T_pg_node_tree Oid = 194
T__json Oid = 199
T_smgr Oid = 210
T_index_am_handler Oid = 325
T_point Oid = 600
T_lseg Oid = 601
T_path Oid = 602
T_box Oid = 603
T_polygon Oid = 604
T_line Oid = 628
T__line Oid = 629
T_cidr Oid = 650
T__cidr Oid = 651
T_float4 Oid = 700
T_float8 Oid = 701
T_abstime Oid = 702
T_reltime Oid = 703
T_tinterval Oid = 704
T_unknown Oid = 705
T_circle Oid = 718
T__circle Oid = 719
T_money Oid = 790
T__money Oid = 791
T_macaddr Oid = 829
T_inet Oid = 869
T__bool Oid = 1000
T__bytea Oid = 1001
T__char Oid = 1002
T__name Oid = 1003
T__int2 Oid = 1005
T__int2vector Oid = 1006
T__int4 Oid = 1007
T__regproc Oid = 1008
T__text Oid = 1009
T__tid Oid = 1010
T__xid Oid = 1011
T__cid Oid = 1012
T__oidvector Oid = 1013
T__bpchar Oid = 1014
T__varchar Oid = 1015
T__int8 Oid = 1016
T__point Oid = 1017
T__lseg Oid = 1018
T__path Oid = 1019
T__box Oid = 1020
T__float4 Oid = 1021
T__float8 Oid = 1022
T__abstime Oid = 1023
T__reltime Oid = 1024
T__tinterval Oid = 1025
T__polygon Oid = 1027
T__oid Oid = 1028
T_aclitem Oid = 1033
T__aclitem Oid = 1034
T__macaddr Oid = 1040
T__inet Oid = 1041
T_bpchar Oid = 1042
T_varchar Oid = 1043
T_date Oid = 1082
T_time Oid = 1083
T_timestamp Oid = 1114
T__timestamp Oid = 1115
T__date Oid = 1182
T__time Oid = 1183
T_timestamptz Oid = 1184
T__timestamptz Oid = 1185
T_interval Oid = 1186
T__interval Oid = 1187
T__numeric Oid = 1231
T_pg_database Oid = 1248
T__cstring Oid = 1263
T_timetz Oid = 1266
T__timetz Oid = 1270
T_bit Oid = 1560
T__bit Oid = 1561
T_varbit Oid = 1562
T__varbit Oid = 1563
T_numeric Oid = 1700
T_refcursor Oid = 1790
T__refcursor Oid = 2201
T_regprocedure Oid = 2202
T_regoper Oid = 2203
T_regoperator Oid = 2204
T_regclass Oid = 2205
T_regtype Oid = 2206
T__regprocedure Oid = 2207
T__regoper Oid = 2208
T__regoperator Oid = 2209
T__regclass Oid = 2210
T__regtype Oid = 2211
T_record Oid = 2249
T_cstring Oid = 2275
T_any Oid = 2276
T_anyarray Oid = 2277
T_void Oid = 2278
T_trigger Oid = 2279
T_language_handler Oid = 2280
T_internal Oid = 2281
T_opaque Oid = 2282
T_anyelement Oid = 2283
T__record Oid = 2287
T_anynonarray Oid = 2776
T_pg_authid Oid = 2842
T_pg_auth_members Oid = 2843
T__txid_snapshot Oid = 2949
T_uuid Oid = 2950
T__uuid Oid = 2951
T_txid_snapshot Oid = 2970
T_fdw_handler Oid = 3115
T_pg_lsn Oid = 3220
T__pg_lsn Oid = 3221
T_tsm_handler Oid = 3310
T_anyenum Oid = 3500
T_tsvector Oid = 3614
T_tsquery Oid = 3615
T_gtsvector Oid = 3642
T__tsvector Oid = 3643
T__gtsvector Oid = 3644
T__tsquery Oid = 3645
T_regconfig Oid = 3734
T__regconfig Oid = 3735
T_regdictionary Oid = 3769
T__regdictionary Oid = 3770
T_jsonb Oid = 3802
T__jsonb Oid = 3807
T_anyrange Oid = 3831
T_event_trigger Oid = 3838
T_int4range Oid = 3904
T__int4range Oid = 3905
T_numrange Oid = 3906
T__numrange Oid = 3907
T_tsrange Oid = 3908
T__tsrange Oid = 3909
T_tstzrange Oid = 3910
T__tstzrange Oid = 3911
T_daterange Oid = 3912
T__daterange Oid = 3913
T_int8range Oid = 3926
T__int8range Oid = 3927
T_pg_shseclabel Oid = 4066
T_regnamespace Oid = 4089
T__regnamespace Oid = 4090
T_regrole Oid = 4096
T__regrole Oid = 4097
)
// TypeName maps each generated OID constant to the upper-cased typname it was
// generated from (see OID.Name in gen.go). Looking up an Oid that is not in
// the table yields the empty string.
var TypeName = map[Oid]string{
T_bool: "BOOL",
T_bytea: "BYTEA",
T_char: "CHAR",
T_name: "NAME",
T_int8: "INT8",
T_int2: "INT2",
T_int2vector: "INT2VECTOR",
T_int4: "INT4",
T_regproc: "REGPROC",
T_text: "TEXT",
T_oid: "OID",
T_tid: "TID",
T_xid: "XID",
T_cid: "CID",
T_oidvector: "OIDVECTOR",
T_pg_ddl_command: "PG_DDL_COMMAND",
T_pg_type: "PG_TYPE",
T_pg_attribute: "PG_ATTRIBUTE",
T_pg_proc: "PG_PROC",
T_pg_class: "PG_CLASS",
T_json: "JSON",
T_xml: "XML",
T__xml: "_XML",
T_pg_node_tree: "PG_NODE_TREE",
T__json: "_JSON",
T_smgr: "SMGR",
T_index_am_handler: "INDEX_AM_HANDLER",
T_point: "POINT",
T_lseg: "LSEG",
T_path: "PATH",
T_box: "BOX",
T_polygon: "POLYGON",
T_line: "LINE",
T__line: "_LINE",
T_cidr: "CIDR",
T__cidr: "_CIDR",
T_float4: "FLOAT4",
T_float8: "FLOAT8",
T_abstime: "ABSTIME",
T_reltime: "RELTIME",
T_tinterval: "TINTERVAL",
T_unknown: "UNKNOWN",
T_circle: "CIRCLE",
T__circle: "_CIRCLE",
T_money: "MONEY",
T__money: "_MONEY",
T_macaddr: "MACADDR",
T_inet: "INET",
T__bool: "_BOOL",
T__bytea: "_BYTEA",
T__char: "_CHAR",
T__name: "_NAME",
T__int2: "_INT2",
T__int2vector: "_INT2VECTOR",
T__int4: "_INT4",
T__regproc: "_REGPROC",
T__text: "_TEXT",
T__tid: "_TID",
T__xid: "_XID",
T__cid: "_CID",
T__oidvector: "_OIDVECTOR",
T__bpchar: "_BPCHAR",
T__varchar: "_VARCHAR",
T__int8: "_INT8",
T__point: "_POINT",
T__lseg: "_LSEG",
T__path: "_PATH",
T__box: "_BOX",
T__float4: "_FLOAT4",
T__float8: "_FLOAT8",
T__abstime: "_ABSTIME",
T__reltime: "_RELTIME",
T__tinterval: "_TINTERVAL",
T__polygon: "_POLYGON",
T__oid: "_OID",
T_aclitem: "ACLITEM",
T__aclitem: "_ACLITEM",
T__macaddr: "_MACADDR",
T__inet: "_INET",
T_bpchar: "BPCHAR",
T_varchar: "VARCHAR",
T_date: "DATE",
T_time: "TIME",
T_timestamp: "TIMESTAMP",
T__timestamp: "_TIMESTAMP",
T__date: "_DATE",
T__time: "_TIME",
T_timestamptz: "TIMESTAMPTZ",
T__timestamptz: "_TIMESTAMPTZ",
T_interval: "INTERVAL",
T__interval: "_INTERVAL",
T__numeric: "_NUMERIC",
T_pg_database: "PG_DATABASE",
T__cstring: "_CSTRING",
T_timetz: "TIMETZ",
T__timetz: "_TIMETZ",
T_bit: "BIT",
T__bit: "_BIT",
T_varbit: "VARBIT",
T__varbit: "_VARBIT",
T_numeric: "NUMERIC",
T_refcursor: "REFCURSOR",
T__refcursor: "_REFCURSOR",
T_regprocedure: "REGPROCEDURE",
T_regoper: "REGOPER",
T_regoperator: "REGOPERATOR",
T_regclass: "REGCLASS",
T_regtype: "REGTYPE",
T__regprocedure: "_REGPROCEDURE",
T__regoper: "_REGOPER",
T__regoperator: "_REGOPERATOR",
T__regclass: "_REGCLASS",
T__regtype: "_REGTYPE",
T_record: "RECORD",
T_cstring: "CSTRING",
T_any: "ANY",
T_anyarray: "ANYARRAY",
T_void: "VOID",
T_trigger: "TRIGGER",
T_language_handler: "LANGUAGE_HANDLER",
T_internal: "INTERNAL",
T_opaque: "OPAQUE",
T_anyelement: "ANYELEMENT",
T__record: "_RECORD",
T_anynonarray: "ANYNONARRAY",
T_pg_authid: "PG_AUTHID",
T_pg_auth_members: "PG_AUTH_MEMBERS",
T__txid_snapshot: "_TXID_SNAPSHOT",
T_uuid: "UUID",
T__uuid: "_UUID",
T_txid_snapshot: "TXID_SNAPSHOT",
T_fdw_handler: "FDW_HANDLER",
T_pg_lsn: "PG_LSN",
T__pg_lsn: "_PG_LSN",
T_tsm_handler: "TSM_HANDLER",
T_anyenum: "ANYENUM",
T_tsvector: "TSVECTOR",
T_tsquery: "TSQUERY",
T_gtsvector: "GTSVECTOR",
T__tsvector: "_TSVECTOR",
T__gtsvector: "_GTSVECTOR",
T__tsquery: "_TSQUERY",
T_regconfig: "REGCONFIG",
T__regconfig: "_REGCONFIG",
T_regdictionary: "REGDICTIONARY",
T__regdictionary: "_REGDICTIONARY",
T_jsonb: "JSONB",
T__jsonb: "_JSONB",
T_anyrange: "ANYRANGE",
T_event_trigger: "EVENT_TRIGGER",
T_int4range: "INT4RANGE",
T__int4range: "_INT4RANGE",
T_numrange: "NUMRANGE",
T__numrange: "_NUMRANGE",
T_tsrange: "TSRANGE",
T__tsrange: "_TSRANGE",
T_tstzrange: "TSTZRANGE",
T__tstzrange: "_TSTZRANGE",
T_daterange: "DATERANGE",
T__daterange: "_DATERANGE",
T_int8range: "INT8RANGE",
T__int8range: "_INT8RANGE",
T_pg_shseclabel: "PG_SHSECLABEL",
T_regnamespace: "REGNAMESPACE",
T__regnamespace: "_REGNAMESPACE",
T_regrole: "REGROLE",
T__regrole: "_REGROLE",
}
================================================
FILE: gossh/pgsql/rows.go
================================================
package pgsql
import (
"math"
"reflect"
"time"
"gossh/pgsql/oid"
)
// headerSize is subtracted from fieldDesc.Mod before the modifier is
// interpreted by Length and PrecisionScale.
const headerSize = 4
// fieldDesc describes the data type of one result column; the
// ColumnType* methods on rows answer their queries from it.
type fieldDesc struct {
// The object ID of the data type.
OID oid.Oid
// The data type size (see pg_type.typlen).
// Note that negative values denote variable-width types.
Len int
// The type modifier (see pg_attribute.atttypmod).
// The meaning of the modifier is type-specific.
Mod int
}
// Type returns the Go type used to scan values of this column:
// int64/int32/int16 for int8/int4/int2, string for varchar and text, bool for
// bool, time.Time for the date and time types, []byte for bytea, and the
// empty interface type for every other OID.
func (fd fieldDesc) Type() reflect.Type {
	var sample any
	switch fd.OID {
	case oid.T_int8:
		sample = int64(0)
	case oid.T_int4:
		sample = int32(0)
	case oid.T_int2:
		sample = int16(0)
	case oid.T_varchar, oid.T_text:
		sample = ""
	case oid.T_bool:
		sample = false
	case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz:
		sample = time.Time{}
	case oid.T_bytea:
		sample = []byte(nil)
	default:
		// TypeOf on an interface value sees only its dynamic type, so the
		// interface type itself is obtained through a pointer to it.
		return reflect.TypeOf((*any)(nil)).Elem()
	}
	return reflect.TypeOf(sample)
}
// Name returns the oid.TypeName entry for the column's OID, or the empty
// string when the OID is not in that table.
func (fd fieldDesc) Name() string {
return oid.TypeName[fd.OID]
}
// Length reports the maximum length of a variable-length column.
//
// text and bytea are unbounded and report math.MaxInt64. varchar and bpchar
// report the declared length, which is the type modifier minus headerSize.
// When the modifier is smaller than headerSize there is no declared length to
// decode (presumably the -1 "no modifier" value PostgreSQL uses; confirm
// against the protocol docs), so the column is reported as unbounded rather
// than with a negative length. For every other type ok is false.
func (fd fieldDesc) Length() (length int64, ok bool) {
	switch fd.OID {
	case oid.T_text, oid.T_bytea:
		return math.MaxInt64, true
	case oid.T_varchar, oid.T_bpchar:
		if fd.Mod < headerSize {
			return math.MaxInt64, true
		}
		return int64(fd.Mod - headerSize), true
	default:
		return 0, false
	}
}
// PrecisionScale reports the precision and scale of a numeric column (or
// numeric array column).
//
// Both values are packed into the type modifier after headerSize is removed:
// precision in the upper 16 bits, scale in the lower 16 bits. When the
// modifier is smaller than headerSize nothing is encoded in it (presumably
// the -1 "no modifier" value of an unconstrained numeric; confirm against the
// protocol docs), so ok is false instead of decoding meaningless bits. ok is
// also false for every non-numeric type.
func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) {
	switch fd.OID {
	case oid.T_numeric, oid.T__numeric:
		if fd.Mod < headerSize {
			return 0, 0, false
		}
		mod := fd.Mod - headerSize
		precision = int64((mod >> 16) & 0xffff)
		scale = int64(mod & 0xffff)
		return precision, scale, true
	default:
		return 0, 0, false
	}
}
// ColumnTypeScanType reports the Go type that values of the column at index
// can be scanned into, as decided by fieldDesc.Type.
func (rs *rows) ColumnTypeScanType(index int) reflect.Type {
	column := rs.colTyps[index]
	return column.Type()
}
// ColumnTypeDatabaseTypeName returns the database system type name of the
// column at index, as reported by fieldDesc.Name.
func (rs *rows) ColumnTypeDatabaseTypeName(index int) string {
	column := rs.colTyps[index]
	return column.Name()
}
// ColumnTypeLength returns the length of the column at index when it has a
// variable length type. For any other type ok is false. The answer comes
// from fieldDesc.Length.
func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) {
	column := rs.colTyps[index]
	return column.Length()
}
// ColumnTypePrecisionScale returns the precision and scale of the column at
// index for decimal types. When not applicable ok is false. The answer comes
// from fieldDesc.PrecisionScale.
func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	column := rs.colTyps[index]
	return column.PrecisionScale()
}
================================================
FILE: gossh/pgsql/scram/scram.go
================================================
// Copyright (c) 2014 - Gustavo Niemeyer
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802.
//
// http://tools.ietf.org/html/rfc5802
package scram
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"encoding/base64"
"fmt"
"hash"
"strconv"
"strings"
)
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
//
// A Client is driven by calling Step once per message exchanged with the
// server, three times in total. Step reports completion rather than
// continuation (it returns false while further steps remain), so it cannot be
// used directly as a loop condition. A conversation resembles:
//
//	client := scram.NewClient(sha256.New, user, pass)
//	client.Step(nil)
//	// send client.Out() to the server, read its reply into in
//	client.Step(in)
//	// send client.Out() to the server, read its reply into in
//	client.Step(in)
//	if client.Err() != nil {
//		// auth failed
//	}
type Client struct {
// newHash constructs the hash used for every HMAC and digest.
newHash func() hash.Hash
user string
pass string
// step counts the steps started so far (0 before the first Step call).
step int
// out holds the message produced by the most recent step; see Out.
out bytes.Buffer
// err is the first error returned by a step; see Err.
err error
clientNonce []byte
serverNonce []byte
// saltedPass is computed by saltPassword during the second step.
saltedPass []byte
// authMsg accumulates the text that clientProof and serverSignature
// authenticate.
authMsg bytes.Buffer
}
// NewClient creates a SCRAM-* client that authenticates user with pass, using
// newHash as the underlying hash algorithm.
//
// For SCRAM-SHA-256, for example, use:
//
//	client := scram.NewClient(sha256.New, user, pass)
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
	client := new(Client)
	client.newHash = newHash
	client.user = user
	client.pass = pass
	// Reserve buffer space up front for the messages built by the steps.
	client.out.Grow(256)
	client.authMsg.Grow(256)
	return client
}
// Out returns the data to be sent to the server in the current step, or nil
// when the current step produced no output.
func (c *Client) Out() []byte {
	if c.out.Len() > 0 {
		return c.out.Bytes()
	}
	return nil
}
// Err returns the error recorded by the most recently executed step, or nil
// if there were no errors. Once an error is recorded, Step runs no further
// steps.
func (c *Client) Err() error {
return c.err
}
// SetNonce sets the client nonce to the provided value.
// If not set (or set to an empty slice), the nonce is generated automatically
// out of crypto/rand on the first step. The slice is stored as given, without
// copying or encoding.
func (c *Client) SetNonce(nonce []byte) {
c.clientNonce = nonce
}
// escaper rewrites "=" as "=3D" and "," as "=2C"; step1 applies it to the
// user name before writing it into the first message.
var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
// Step processes the incoming data from the server and makes the
// next round of data for the server available via Client.Out.
//
// The previous output is always cleared first. If the conversation has
// already run all three steps, or an earlier step failed, Step does nothing
// more and returns false. Otherwise it runs the next step and returns true
// when that step was the last one or recorded an error, and false when there
// are no errors and more data is still expected.
func (c *Client) Step(in []byte) bool {
	finished := func() bool {
		return c.step > 2 || c.err != nil
	}

	c.out.Reset()
	if finished() {
		return false
	}

	c.step++
	switch c.step {
	case 1:
		c.err = c.step1(in)
	case 2:
		c.err = c.step2(in)
	case 3:
		c.err = c.step3(in)
	}
	return finished()
}
// step1 builds the first client message. It generates a client nonce from
// crypto/rand (16 random bytes, base64 encoded) unless one was already set,
// records "n=<escaped user>,r=<nonce>" in authMsg and emits the same text
// prefixed with "n,," as output. The in argument is not used.
func (c *Client) step1(in []byte) error {
	if len(c.clientNonce) == 0 {
		const rawLen = 16
		// One allocation holds both the raw bytes and their encoded form.
		scratch := make([]byte, rawLen+b64.EncodedLen(rawLen))
		raw, encoded := scratch[:rawLen], scratch[rawLen:]
		if _, err := rand.Read(raw); err != nil {
			return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err)
		}
		b64.Encode(encoded, raw)
		c.clientNonce = encoded
	}

	c.authMsg.WriteString("n=")
	escaper.WriteString(&c.authMsg, c.user)
	c.authMsg.WriteString(",r=")
	c.authMsg.Write(c.clientNonce)

	c.out.WriteString("n,,")
	c.out.Write(c.authMsg.Bytes())
	return nil
}
// b64 is the base64 encoding used for the nonce, salt, client proof and
// server signature.
var b64 = base64.StdEncoding
// step2 handles the server's first message, which must consist of exactly
// three comma-separated fields: "r=<nonce>", "s=<base64 salt>" and
// "i=<iteration count>". The message is appended to authMsg before it is
// validated. The server nonce must start with the client nonce. On success
// the salted password is derived and the reply
// "c=biws,r=<server nonce>,p=<client proof>" is written to the output.
func (c *Client) step2(in []byte) error {
	c.authMsg.WriteByte(',')
	c.authMsg.Write(in)

	parts := bytes.Split(in, []byte(","))
	if len(parts) != 3 {
		return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(parts), in)
	}
	nonceField, saltField, iterField := parts[0], parts[1], parts[2]
	switch {
	case !bytes.HasPrefix(nonceField, []byte("r=")) || len(nonceField) < 2:
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", nonceField)
	case !bytes.HasPrefix(saltField, []byte("s=")) || len(saltField) < 6:
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", saltField)
	case !bytes.HasPrefix(iterField, []byte("i=")) || len(iterField) < 6:
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", iterField)
	}

	c.serverNonce = nonceField[2:]
	if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
		return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
	}

	encodedSalt := saltField[2:]
	salt := make([]byte, b64.DecodedLen(len(encodedSalt)))
	decoded, err := b64.Decode(salt, encodedSalt)
	if err != nil {
		return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", saltField)
	}
	salt = salt[:decoded]

	iterations, err := strconv.Atoi(string(iterField[2:]))
	if err != nil {
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", iterField)
	}
	c.saltPassword(salt, iterations)

	// authMsg must be complete before clientProof is computed from it.
	c.authMsg.WriteString(",c=biws,r=")
	c.authMsg.Write(c.serverNonce)

	c.out.WriteString("c=biws,r=")
	c.out.Write(c.serverNonce)
	c.out.WriteString(",p=")
	c.out.Write(c.clientProof())
	return nil
}
// step3 handles the server's final message, which must be a single field.
// An "e=<reason>" field is returned as an authentication error, anything
// other than "v=<signature>" is rejected as unsupported, and a "v=" field is
// accepted only when it matches the signature this client computes itself.
//
// The signatures are compared with hmac.Equal so that the time taken does
// not depend on how many leading bytes of the expected signature match.
func (c *Client) step3(in []byte) error {
	var isv, ise bool
	fields := bytes.Split(in, []byte(","))
	if len(fields) == 1 {
		isv = bytes.HasPrefix(fields[0], []byte("v="))
		ise = bytes.HasPrefix(fields[0], []byte("e="))
	}
	if ise {
		return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:])
	} else if !isv {
		return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in)
	}
	if !hmac.Equal(c.serverSignature(), fields[0][2:]) {
		return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:])
	}
	return nil
}
// saltPassword derives c.saltedPass from the password. Using an HMAC keyed
// with the password, it computes U1 = HMAC(salt || 0x00000001) and
// Ui = HMAC(U(i-1)) for the remaining iterations, and stores the XOR of all
// the U values. An iterCount of 1 or less yields U1 unchanged.
func (c *Client) saltPassword(salt []byte, iterCount int) {
	prf := hmac.New(c.newHash, []byte(c.pass))
	prf.Write(salt)
	prf.Write([]byte{0, 0, 0, 1})
	block := prf.Sum(nil)

	acc := append([]byte(nil), block...)
	for round := 1; round < iterCount; round++ {
		prf.Reset()
		prf.Write(block)
		// Sum appends to block[:0], overwriting block in place with the next U.
		prf.Sum(block[:0])
		for i := range acc {
			acc[i] ^= block[i]
		}
	}
	c.saltedPass = acc
}
// clientProof returns the base64-encoded client proof. The client key is
// HMAC(saltedPass, "Client Key"), the stored key is the hash of the client
// key, and the proof is the client key XORed with HMAC(storedKey, authMsg).
func (c *Client) clientProof() []byte {
	keyMAC := hmac.New(c.newHash, c.saltedPass)
	keyMAC.Write([]byte("Client Key"))
	clientKey := keyMAC.Sum(nil)

	digest := c.newHash()
	digest.Write(clientKey)
	storedKey := digest.Sum(nil)

	sigMAC := hmac.New(c.newHash, storedKey)
	sigMAC.Write(c.authMsg.Bytes())
	proof := sigMAC.Sum(nil)
	for i := range clientKey {
		proof[i] ^= clientKey[i]
	}

	encoded := make([]byte, b64.EncodedLen(len(proof)))
	b64.Encode(encoded, proof)
	return encoded
}
// serverSignature returns the base64-encoded signature the server is
// expected to send: HMAC(serverKey, authMsg), where serverKey is
// HMAC(saltedPass, "Server Key").
func (c *Client) serverSignature() []byte {
	keyMAC := hmac.New(c.newHash, c.saltedPass)
	keyMAC.Write([]byte("Server Key"))
	serverKey := keyMAC.Sum(nil)

	sigMAC := hmac.New(c.newHash, serverKey)
	sigMAC.Write(c.authMsg.Bytes())
	signature := sigMAC.Sum(nil)

	out := make([]byte, b64.EncodedLen(len(signature)))
	b64.Encode(out, signature)
	return out
}
================================================
FILE: gossh/pgsql/ssl.go
================================================
package pgsql
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
)
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
// related settings. The function is nil when no upgrade should take place.
//
// An unsupported sslmode, or a failure to load the client certificate or the
// root CA, is returned as an error. Note that o may be modified: in "require"
// mode an "sslrootcert" entry that cannot be stat'ed is deleted from it.
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
// verifyCaOnly selects the manual chain verification performed in the
// returned function after crypto/tls's own verification has been skipped.
verifyCaOnly := false
tlsConf := tls.Config{}
switch mode := o["sslmode"]; mode {
// "require" is the default.
case "", "require":
// We must skip TLS's own verification since it requires full
// verification since Go 1.3.
tlsConf.InsecureSkipVerify = true
// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
//
// Note: For backwards compatibility with earlier versions of
// PostgreSQL, if a root CA file exists, the behavior of
// sslmode=require will be the same as that of verify-ca, meaning the
// server certificate is validated against the CA. Relying on this
// behavior is discouraged, and applications that need certificate
// validation should always use verify-ca or verify-full.
if sslrootcert, ok := o["sslrootcert"]; ok {
if _, err := os.Stat(sslrootcert); err == nil {
verifyCaOnly = true
} else {
delete(o, "sslrootcert")
}
}
case "verify-ca":
// We must skip TLS's own verification since it requires full
// verification since Go 1.3.
tlsConf.InsecureSkipVerify = true
verifyCaOnly = true
case "verify-full":
tlsConf.ServerName = o["host"]
case "disable":
return nil, nil
default:
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// By default SNI is on, any value which is not starting with "1" disables
// SNI -- that is the same check vanilla libpq uses.
if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") {
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4
// or IPv6). This check is already implemented by crypto/tls
// (hostnameInSNI), so just always set ServerName here and let crypto/tls
// do the filtering.
tlsConf.ServerName = o["host"]
}
err := sslClientCertificates(&tlsConf, o)
if err != nil {
return nil, err
}
err = sslCertificateAuthority(&tlsConf, o)
if err != nil {
return nil, err
}
// Accept renegotiation requests initiated by the backend.
//
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
// the default configuration of older versions has it enabled. Redshift
// also initiates renegotiations and cannot be reconfigured.
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
// The returned closure captures tlsConf and verifyCaOnly; each call wraps the
// given connection in a TLS client using that shared configuration.
return func(conn net.Conn) (net.Conn, error) {
client := tls.Client(conn, &tlsConf)
if verifyCaOnly {
err := sslVerifyCertificateAuthority(client, &tlsConf)
if err != nil {
return nil, err
}
}
return client, nil
}, nil
}
// sslClientCertificates installs the client certificate on tlsConf.
//
// With sslinline=true, "sslcert" and "sslkey" hold the PEM data themselves.
// Otherwise they are file paths, defaulting to postgresql.crt and
// postgresql.key in the .postgresql directory of the current user's home
// directory. No certificate is configured, and nil is returned, when no
// certificate path can be determined or the certificate file does not exist.
// A key path, when known, must pass sslKeyPermissions.
func sslClientCertificates(tlsConf *tls.Config, o values) error {
	if o["sslinline"] == "true" {
		pair, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"]))
		if err != nil {
			return err
		}
		tlsConf.Certificates = []tls.Certificate{pair}
		return nil
	}

	// user.Current() might fail when cross-compiling. The error is ignored on
	// purpose: without a home directory there are simply no defaults to load.
	current, _ := user.Current()

	// As in libpq, the client certificate is only loaded if the setting is
	// not blank.
	certPath := o["sslcert"]
	if certPath == "" && current != nil {
		certPath = filepath.Join(current.HomeDir, ".postgresql", "postgresql.crt")
	}
	if certPath == "" {
		return nil
	}
	if _, err := os.Stat(certPath); err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return err
	}

	// As in libpq, the ssl key is only loaded if the setting is not blank.
	keyPath := o["sslkey"]
	if keyPath == "" && current != nil {
		keyPath = filepath.Join(current.HomeDir, ".postgresql", "postgresql.key")
	}
	if keyPath != "" {
		if err := sslKeyPermissions(keyPath); err != nil {
			return err
		}
	}

	pair, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return err
	}
	tlsConf.Certificates = []tls.Certificate{pair}
	return nil
}
// sslCertificateAuthority loads the root CA named by the "sslrootcert"
// setting into tlsConf.RootCAs. As in libpq, nothing is loaded when the
// setting is blank. With sslinline=true the setting is the PEM data itself;
// otherwise it is the path of a PEM file.
func sslCertificateAuthority(tlsConf *tls.Config, o values) error {
	rootCert := o["sslrootcert"]
	if len(rootCert) == 0 {
		return nil
	}

	tlsConf.RootCAs = x509.NewCertPool()

	pemData := []byte(rootCert)
	if o["sslinline"] != "true" {
		var err error
		if pemData, err = ioutil.ReadFile(rootCert); err != nil {
			return err
		}
	}

	if !tlsConf.RootCAs.AppendCertsFromPEM(pemData) {
		return fmterrorf("couldn't parse pem in sslrootcert")
	}
	return nil
}
// sslVerifyCertificateAuthority performs the TLS handshake on client and then
// verifies the server's leaf certificate against tlsConf.RootCAs, using the
// remaining certificates the server presented as intermediates. When RootCAs
// is nil, x509 falls back to the system roots.
func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error {
	if err := client.Handshake(); err != nil {
		return err
	}

	state := client.ConnectionState()
	chain := state.PeerCertificates

	intermediates := x509.NewCertPool()
	for idx := range chain {
		// Index 0 is the leaf being verified, not an intermediate.
		if idx > 0 {
			intermediates.AddCert(chain[idx])
		}
	}

	_, err := chain[0].Verify(x509.VerifyOptions{
		DNSName:       state.ServerName,
		Intermediates: intermediates,
		Roots:         tlsConf.RootCAs,
	})
	return err
}
================================================
FILE: gossh/pgsql/ssl_permissions.go
================================================
//go:build !windows
// +build !windows
package pgsql
import (
"errors"
"os"
"syscall"
)
// Limits applied by hasCorrectPermissions to SSL private key files.
const (
// rootUserID is the owner uid that qualifies a key file for the more
// lenient root-owned permission limit.
rootUserID = uint32(0)
// The maximum permissions that a private key file owned by a regular user
// is allowed to have. This translates to u=rw.
maxUserOwnedKeyPermissions os.FileMode = 0600
// The maximum permissions that a private key file owned by root is allowed
// to have. This translates to u=rw,g=r.
maxRootOwnedKeyPermissions os.FileMode = 0640
)
// Internal results of hasCorrectPermissions. sslKeyPermissions translates
// both into ErrSSLKeyHasWorldPermissions before returning to its caller.
var (
errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less")
errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less")
)
// sslKeyPermissions checks the permissions on user-supplied ssl key files.
// The key file should have very little access.
//
// It returns the os.Stat error if the file cannot be inspected,
// ErrSSLKeyHasWorldPermissions if the permissions are too open, and any other
// error from hasCorrectPermissions unchanged.
//
// libpq does not check key file permissions on Windows.
func sslKeyPermissions(sslkey string) error {
	info, err := os.Stat(sslkey)
	if err != nil {
		return err
	}

	// Both internal permission errors are reported as
	// ErrSSLKeyHasWorldPermissions for backwards compatibility with existing
	// code.
	switch err = hasCorrectPermissions(info); err {
	case errSSLKeyHasUnacceptableUserPermissions, errSSLKeyHasUnacceptableRootPermissions:
		return ErrSSLKeyHasWorldPermissions
	}
	return err
}
// hasCorrectPermissions checks the file info (and the unix-specific stat_t
// output) to verify that the permissions on the file are correct.
//
// A file whose permission bits are within 0600 (u=rw) is always accepted,
// whoever owns it. Anything more permissive is accepted only when the file is
// owned by root (uid 0) and its permission bits are within 0640 (u=rw,g=r).
//
// It returns ErrSSLKeyUnknownOwnership when the owner cannot be read from the
// file info, errSSLKeyHasUnacceptableRootPermissions for a root-owned file
// that is too open, and errSSLKeyHasUnacceptableUserPermissions for any other
// file that is too open.
func hasCorrectPermissions(info os.FileInfo) error {
	perm := info.Mode().Perm()

	// Regardless of who owns the file, 0600 or stricter is acceptable.
	if perm&^maxUserOwnedKeyPermissions == 0 {
		return nil
	}

	// The owner is needed from here on. If the platform-specific stat data is
	// missing or of an unexpected type, fail rather than guess.
	stat, ok := info.Sys().(*syscall.Stat_t)
	if !ok {
		return ErrSSLKeyUnknownOwnership
	}

	if stat.Uid != rootUserID {
		return errSSLKeyHasUnacceptableUserPermissions
	}

	// Root-owned files may additionally be group readable, to match what
	// Postgres does.
	if perm&^maxRootOwnedKeyPermissions != 0 {
		return errSSLKeyHasUnacceptableRootPermissions
	}
	return nil
}
================================================
FILE: gossh/pgsql/ssl_windows.go
================================================
//go:build windows
// +build windows
package pgsql
// sslKeyPermissions is the Windows counterpart of the key file permission
// check. It performs no check and always returns nil.
//
// libpq does not check key file permissions on Windows.
func sslKeyPermissions(string) error { return nil }
================================================
FILE: gossh/pgsql/url.go
================================================
package pgsql
import (
"fmt"
"net"
nurl "net/url"
"sort"
"strings"
)
// ParseURL no longer needs to be used by clients of this library since supplying a URL as a
// connection string to sql.Open() is now supported:
//
//	sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
//
// It remains exported here for backwards-compatibility.
//
// ParseURL converts a "postgres://" or "postgresql://" URL into a connection
// string for driver.Open. Every non-empty setting becomes a key='value' pair,
// with single quotes and backslashes in the value escaped by a backslash; the
// pairs are sorted and joined with single spaces. Example:
//
//	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
//	"dbname='mydb' host='1.2.3.4' password='secret' port='5432' sslmode='verify-full' user='bob'"
//
// A minimal example:
//
//	"postgres://"
//
// This will be blank, causing driver.Open to use all of the defaults.
//
// An error is returned when the URL cannot be parsed or uses another scheme.
func ParseURL(url string) (string, error) {
	parsed, err := nurl.Parse(url)
	if err != nil {
		return "", err
	}
	switch parsed.Scheme {
	case "postgres", "postgresql":
	default:
		return "", fmt.Errorf("invalid connection protocol: %s", parsed.Scheme)
	}

	quote := strings.NewReplacer(`'`, `\'`, `\`, `\\`)
	var pairs []string
	// add records one setting; empty values are dropped entirely.
	add := func(key, value string) {
		if value == "" {
			return
		}
		pairs = append(pairs, key+"='"+quote.Replace(value)+"'")
	}

	if info := parsed.User; info != nil {
		add("user", info.Username())
		password, _ := info.Password()
		add("password", password)
	}

	// Without a port SplitHostPort fails and the whole authority is the host.
	host, port, splitErr := net.SplitHostPort(parsed.Host)
	if splitErr != nil {
		add("host", parsed.Host)
	} else {
		add("host", host)
		add("port", port)
	}

	if parsed.Path != "" {
		// Drop the leading "/" of the path.
		add("dbname", parsed.Path[1:])
	}

	params := parsed.Query()
	for key := range params {
		add(key, params.Get(key))
	}

	sort.Strings(pairs) // Makes testing easier (not a performance concern)
	return strings.Join(pairs, " "), nil
}
================================================
FILE: gossh/pgsql/user_other.go
================================================
// Package pq is a pure Go Postgres driver for the database/sql package.
//go:build js || android || hurd || zos
// +build js android hurd zos
package pgsql
// userCurrent is the fallback for platforms selected by this file's build
// constraints (js, android, hurd, zos): it performs no lookup and always
// returns an empty name together with ErrCouldNotDetectUsername.
func userCurrent() (string, error) {
return "", ErrCouldNotDetectUsername
}
================================================
FILE: gossh/pgsql/user_posix.go
================================================
// Package pq is a pure Go Postgres driver for the database/sql package.
//go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos
// +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos
package pgsql
import (
"os"
"os/user"
)
// userCurrent returns the name of the current user. It first asks os/user;
// if that lookup fails it falls back to a non-empty USER environment
// variable, and otherwise reports ErrCouldNotDetectUsername.
func userCurrent() (string, error) {
	if u, err := user.Current(); err == nil {
		return u.Username, nil
	}

	if fromEnv := os.Getenv("USER"); fromEnv != "" {
		return fromEnv, nil
	}

	return "", ErrCouldNotDetectUsername
}
================================================
FILE: gossh/pgsql/user_windows.go
================================================
// Package pq is a pure Go Postgres driver for the database/sql package.
package pgsql
import (
"path/filepath"
"syscall"
)
// userCurrent performs the Windows user name lookup identically to libpq.
//
// The PostgreSQL code makes use of the legacy Win32 function GetUserName,
// which has not been imported into stock Go. GetUserNameEx is available
// though, the difference being that a wider range of names is available.
// To get the same output as GetUserName, only the base (last) component of
// the SAM-compatible name is returned.
//
// Any failure of GetUserNameEx is reported as ErrCouldNotDetectUsername.
func userCurrent() (string, error) {
	nameBuf := make([]uint16, 128)
	// Leave one slot free for the terminating NUL.
	nameLen := uint32(len(nameBuf)) - 1
	if err := syscall.GetUserNameEx(syscall.NameSamCompatible, &nameBuf[0], &nameLen); err != nil {
		return "", ErrCouldNotDetectUsername
	}
	return filepath.Base(syscall.UTF16ToString(nameBuf)), nil
}
================================================
FILE: gossh/pgsql/uuid.go
================================================
package pgsql
import (
"encoding/hex"
"fmt"
)
// decodeUUIDBinary converts the 16-byte binary form of a uuid into its
// 36-byte textual form (lowercase hex groups of 8-4-4-4-12 digits separated
// by '-'). Any other input length is rejected with an error.
func decodeUUIDBinary(src []byte) ([]byte, error) {
	if n := len(src); n != 16 {
		return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", n)
	}

	// Source byte ranges of the five dash-separated groups.
	groups := [...]struct{ from, to int }{{0, 4}, {4, 6}, {6, 8}, {8, 10}, {10, 16}}

	text := make([]byte, 36)
	pos := 0
	for i, g := range groups {
		if i > 0 {
			text[pos] = '-'
			pos++
		}
		pos += hex.Encode(text[pos:], src[g.from:g.to])
	}
	return text, nil
}
================================================
FILE: gossh/pty/asm_solaris_amd64.s
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
//+build gc
#include "textflag.h"
//
// System calls for amd64, Solaris are implemented in runtime/syscall_solaris.go
//
// sysvicall6 is a trampoline: it does no work of its own and jumps straight
// to syscall·sysvicall6. It is declared NOSPLIT with a zero-byte frame and
// 88 bytes of arguments/results.
TEXT ·sysvicall6(SB),NOSPLIT,$0-88
JMP syscall·sysvicall6(SB)
// rawSysvicall6 is a trampoline that jumps straight to
// syscall·rawSysvicall6, using the same NOSPLIT, $0-88 frame layout as
// sysvicall6.
TEXT ·rawSysvicall6(SB),NOSPLIT,$0-88
JMP syscall·rawSysvicall6(SB)
================================================
FILE: gossh/pty/cmd_windows.go
================================================
//go:build windows
// +build windows
package pty
import (
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
"unicode/utf16"
"unsafe"
"gossh/sys/windows"
)
// The logic in this file is copied from os/exec.Cmd for platform
// compatibility. Pty support needs startupInfoEx, but os/exec.Cmd only
// supports startupInfo on Windows, so some of its internal logic is rewritten
// here for Windows while keeping its behavior compatible with other platforms.
//
// windowExecCmd represents an external command being prepared or run.
//
// A cmd cannot be reused after calling its Run, Output or CombinedOutput
// methods.
type windowExecCmd struct {
// cmd is the wrapped command; argv reads its Args and Path.
cmd *exec.Cmd
// NOTE(review): presumably records that Wait was already called — confirm
// against the code that sets it (not visible in this file section).
waitCalled bool
// conPty and conTty are both closed by close.
conPty *WindowsPty
conTty *WindowsTty
// attrList is deleted by close.
attrList *windows.ProcThreadAttributeListContainer
}
// close deletes the process/thread attribute list and then closes the pty
// and the tty, in that order. Close errors are deliberately discarded (best
// effort), so the method always returns nil.
func (c *windowExecCmd) close() error {
	c.attrList.Delete()

	pty, tty := c.conPty, c.conTty
	_ = pty.Close()
	_ = tty.Close()

	return nil
}
// argv returns the argument vector of the wrapped command: cmd.Args when it
// is non-empty, otherwise a single-element slice holding cmd.Path.
func (c *windowExecCmd) argv() []string {
	args := c.cmd.Args
	if len(args) == 0 {
		return []string{c.cmd.Path}
	}
	return args
}
//
// Helpers for working with Windows. These are exact copies of the same utilities found in the go stdlib.
//
// lookExtensions resolves path via exec.LookPath and returns it with any
// extension LookPath added.
//
// A bare file name is first joined with "." so it is treated as relative.
// When dir is empty, or path carries a volume name, or path starts with a
// path separator, the lookup is done on path directly. Otherwise the lookup
// is done on dir joined with path, and the suffix LookPath appended is added
// to path.
func lookExtensions(path, dir string) (string, error) {
	if filepath.Base(path) == path {
		path = filepath.Join(".", path)
	}

	rooted := len(path) > 1 && os.IsPathSeparator(path[0])
	if dir == "" || filepath.VolumeName(path) != "" || rooted {
		return exec.LookPath(path)
	}

	joined := filepath.Join(dir, path)
	// We assume that LookPath will only add file extension.
	resolved, err := exec.LookPath(joined)
	if err != nil {
		return "", err
	}
	return path + strings.TrimPrefix(resolved, joined), nil
}
// dedupEnvCase returns env with duplicate keys removed, keeping the last
// occurrence of each key and preserving the relative order of the survivors.
// When caseInsensitive is true, keys are compared after lower-casing.
//
// Entries without "=" are kept untouched unless they are empty; an entry
// that begins with "=" is keyed up to its second "=".
func dedupEnvCase(caseInsensitive bool, env []string) []string {
	// Walk backwards so the first time a key is met is its last occurrence.
	kept := make([]string, 0, len(env))
	seen := make(map[string]bool, len(env))
	for idx := len(env) - 1; idx >= 0; idx-- {
		entry := env[idx]

		sep := strings.Index(entry, "=")
		if sep == 0 {
			// We observe in practice keys with a single leading "=" on Windows.
			// TODO(#49886): Should we consume only the first leading "=" as part
			// of the key, or parse through arbitrarily many of them until a non-"="?
			sep = strings.Index(entry[1:], "=") + 1
		}
		if sep < 0 {
			// The entry is not of the form "key=value" (as it is required to be).
			// Leave it as-is for now.
			// TODO(#52436): should we strip or reject these bogus entries?
			if entry != "" {
				kept = append(kept, entry)
			}
			continue
		}

		key := entry[:sep]
		if caseInsensitive {
			key = strings.ToLower(key)
		}
		if !seen[key] {
			seen[key] = true
			kept = append(kept, entry)
		}
	}

	// Undo the backwards walk to restore the original order.
	for lo, hi := 0, len(kept)-1; lo < hi; lo, hi = lo+1, hi-1 {
		kept[lo], kept[hi] = kept[hi], kept[lo]
	}
	return kept
}
// addCriticalEnv makes sure env carries a SYSTEMROOT entry. If a key equal
// to "SYSTEMROOT" (ignoring case) is already present, env is returned
// unchanged; otherwise the current process's SYSTEMROOT value is appended.
// Entries without "=" are ignored during the search.
func addCriticalEnv(env []string) []string {
	for _, entry := range env {
		sep := strings.IndexByte(entry, '=')
		if sep >= 0 && strings.EqualFold(entry[:sep], "SYSTEMROOT") {
			// We already have it.
			return env
		}
	}
	return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}
// execEnvDefault returns the default environment for a new process.
//
// Without process attributes, or with a zero Token, it returns the current
// process environment (syscall.Environ). Otherwise it asks Windows for the
// environment block of sys.Token and decodes that block into "key=value"
// strings. The block is destroyed before returning.
func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) {
if sys == nil || sys.Token == 0 {
return syscall.Environ(), nil
}
var block *uint16
err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false)
if err != nil {
return nil, err
}
defer windows.DestroyEnvironmentBlock(block)
// blockp is the address of the entry currently being decoded.
blockp := uintptr(unsafe.Pointer(block))
for {
// find NUL terminator
end := unsafe.Pointer(blockp)
for *(*uint16)(end) != 0 {
end = unsafe.Pointer(uintptr(end) + 2)
}
// n is the entry length in UTF-16 code units (2 bytes each).
n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2
if n == 0 {
// environment block ends with empty string
break
}
// View the n code units at blockp as a slice with length and capacity n.
entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n]
env = append(env, string(utf16.Decode(entry)))
// Step over the entry and its NUL terminator.
blockp += 2 * (uintptr(len(entry)) + 1)
}
return
}
// createEnvBlock encodes envv as a UTF-16 environment block: every entry is
// followed by a NUL and the whole block ends with one more NUL. An empty
// envv produces a block of two NULs. The result points at the first code
// unit of the block.
func createEnvBlock(envv []string) *uint16 {
	if len(envv) == 0 {
		return &utf16.Encode([]rune("\x00\x00"))[0]
	}

	// entry NUL entry NUL ... entry NUL NUL
	block := strings.Join(envv, "\x00") + "\x00\x00"
	return &utf16.Encode([]rune(block))[0]
}
// makeCmdLine builds a command line by escaping every argument with
// windows.EscapeArg and joining the results with single spaces. A separator
// is only written once the line built so far is non-empty.
func makeCmdLine(args []string) string {
	var line strings.Builder
	for _, arg := range args {
		if line.Len() > 0 {
			line.WriteByte(' ')
		}
		line.WriteString(windows.EscapeArg(arg))
	}
	return line.String()
}
// isSlash reports whether c is a backslash or a forward slash.
func isSlash(c uint8) bool {
	switch c {
	case '\\', '/':
		return true
	}
	return false
}
// normalizeDir expands dir to a full path via syscall.FullPath. A result
// longer than two characters that starts with two slashes
// (\\server\share\path form) is rejected with syscall.EINVAL.
func normalizeDir(dir string) (name string, err error) {
	full, err := syscall.FullPath(dir)
	switch {
	case err != nil:
		return "", err
	case len(full) > 2 && isSlash(full[0]) && isSlash(full[1]):
		// dir cannot have \\server\share\path form
		return "", syscall.EINVAL
	}
	return full, nil
}
// volToUpper maps an ASCII lowercase letter to its uppercase form and
// returns every other value unchanged.
func volToUpper(ch int) int {
	if ch < 'a' || ch > 'z' {
		return ch
	}
	return ch + ('A' - 'a')
}
// joinExeDirAndFName combines the directory dir with the executable name p.
//
//   - An empty p, or a bare drive such as "C:", yields syscall.EINVAL.
//   - A \\server\share\path name, or a drive letter followed by a slash, is
//     returned unchanged.
//   - A drive-relative name ("C:foo") is resolved against dir when both are
//     on the same drive, otherwise on its own.
//   - A name without a drive letter is resolved against dir: against dir's
//     first two characters when p starts with a slash, against all of dir
//     otherwise.
func joinExeDirAndFName(dir, p string) (name string, err error) {
	if len(p) == 0 {
		return "", syscall.EINVAL
	}
	if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
		// \\server\share\path form
		return p, nil
	}

	hasDrive := len(p) > 1 && p[1] == ':'
	if hasDrive {
		if len(p) == 2 {
			return "", syscall.EINVAL
		}
		if isSlash(p[2]) {
			return p, nil
		}
	}

	// Every remaining form needs the normalized directory.
	base, err := normalizeDir(dir)
	if err != nil {
		return "", err
	}

	switch {
	case hasDrive && volToUpper(int(p[0])) == volToUpper(int(base[0])):
		return syscall.FullPath(base + "\\" + p[2:])
	case hasDrive:
		return syscall.FullPath(p)
	case isSlash(p[0]):
		return windows.FullPath(base[:2] + p)
	default:
		return windows.FullPath(base + "\\" + p)
	}
}
================================================
FILE: gossh/pty/doc.go
================================================
// Package pty provides functions for working with Unix terminals.
package pty
import (
"errors"
"io"
"time"
)
// ErrUnsupported is returned when a function is not available on the
// current platform.
var ErrUnsupported = errors.New("unsupported")
// Open opens a pty and its corresponding tty by delegating to the
// platform-specific open implementation, returning its results unchanged.
func Open() (Pty, Tty, error) {
	p, t, err := open()
	return p, t, err
}
// FdHolder exposes the Fd() method of the underlying handle.
type FdHolder interface {
// Fd returns the underlying descriptor or handle as a uintptr.
Fd() uintptr
}
// DeadlineHolder exposes the SetDeadline() method, which sets the read and
// write deadlines.
type DeadlineHolder interface {
SetDeadline(t time.Time) error
}
// Pty is the terminal-control side, used in the current process.
//
// - For Unix systems, the real type is *os.File.
// - For Windows, the real type is a *WindowsPty wrapping a ConPTY handle.
type Pty interface {
// FdHolder is intended for resizing / controlling, via ioctls from the
// current process, the TTY of the child process.
FdHolder
// Name returns the name of the underlying file.
Name() string
// WriteString is only used to tell Pty and Tty apart.
WriteString(s string) (n int, err error)
io.ReadWriteCloser
}
// Tty is the data I/O side, used by the child process.
//
// - For Unix systems, the real type is *os.File.
// - For Windows, the real type is a *WindowsTty, which combines two pipe files.
type Tty interface {
// FdHolder: Fd is only intended for a manual InheritSize from the Pty.
FdHolder
// Name returns the name of the underlying file.
Name() string
io.ReadWriteCloser
}
================================================
FILE: gossh/pty/ioctl.go
================================================
//go:build !windows && go1.12
// +build !windows,go1.12
package pty
import "os"
// ioctl issues the ioctl request cmd with argument ptr on f.
//
// The request runs inside the file's SyscallConn Control callback. If no
// SyscallConn can be obtained, it falls back to calling ioctlInner on f.Fd()
// directly (the old, blocking behavior). An error from Control takes
// precedence; otherwise the result of ioctlInner is returned.
func ioctl(f *os.File, cmd, ptr uintptr) error {
	conn, err := f.SyscallConn()
	if err != nil {
		return ioctlInner(f.Fd(), cmd, ptr) // Fall back to blocking io (old behavior).
	}

	// Control invokes the callback before it returns, so the result can be
	// handed back through a captured variable.
	var ioErr error
	if err := conn.Control(func(fd uintptr) { ioErr = ioctlInner(fd, cmd, ptr) }); err != nil {
		return err
	}
	return ioErr
}
================================================
FILE: gossh/pty/ioctl_bsd.go
================================================
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package pty
// Bit-field constants used by _IOC and _IOC_PARM_LEN to build and decode
// BSD ioctl request numbers.
// NOTE(review): the original "from" comment is truncated; presumably these
// mirror <sys/ioccom.h> — confirm.
const (
_IOC_VOID uintptr = 0x20000000
_IOC_OUT uintptr = 0x40000000
_IOC_IN uintptr = 0x80000000
_IOC_IN_OUT uintptr = _IOC_OUT | _IOC_IN
_IOC_DIRMASK = _IOC_VOID | _IOC_OUT | _IOC_IN
// The parameter length occupies _IOC_PARAM_SHIFT (13) bits.
_IOC_PARAM_SHIFT = 13
_IOC_PARAM_MASK = (1 << _IOC_PARAM_SHIFT) - 1
)
// _IOC_PARM_LEN extracts the parameter length encoded in an ioctl request
// number: the bits above bit 16, masked with _IOC_PARAM_MASK.
func _IOC_PARM_LEN(ioctl uintptr) uintptr {
	lenBits := ioctl >> 16
	return lenBits & _IOC_PARAM_MASK
}
// _IOC assembles an ioctl request number from the direction bits inout, the
// masked parameter length (shifted left by 16), the group byte (shifted left
// by 8) and the command number.
func _IOC(inout uintptr, group byte, ioctl_num uintptr, param_len uintptr) uintptr {
	lenBits := (param_len & _IOC_PARAM_MASK) << 16
	groupBits := uintptr(group) << 8
	return inout | lenBits | groupBits | ioctl_num
}
// _IO builds a request number with the _IOC_VOID direction and a zero
// parameter length.
func _IO(group byte, ioctl_num uintptr) uintptr {
return _IOC(_IOC_VOID, group, ioctl_num, 0)
}
// _IOR builds a request number with the _IOC_OUT direction and the given
// parameter length.
func _IOR(group byte, ioctl_num uintptr, param_len uintptr) uintptr {
return _IOC(_IOC_OUT, group, ioctl_num, param_len)
}
// _IOW builds a request number with the _IOC_IN direction and the given
// parameter length.
func _IOW(group byte, ioctl_num uintptr, param_len uintptr) uintptr {
return _IOC(_IOC_IN, group, ioctl_num, param_len)
}
// _IOWR builds a request number with both direction bits set (_IOC_IN_OUT)
// and the given parameter length.
func _IOWR(group byte, ioctl_num uintptr, param_len uintptr) uintptr {
return _IOC(_IOC_IN_OUT, group, ioctl_num, param_len)
}
================================================
FILE: gossh/pty/ioctl_inner.go
================================================
//go:build !windows && !solaris && !aix
// +build !windows,!solaris,!aix
package pty
import "syscall"
// Local aliases of the syscall package's window-size ioctl request numbers.
const (
TIOCGWINSZ = syscall.TIOCGWINSZ
TIOCSWINSZ = syscall.TIOCSWINSZ
)
// ioctlInner performs the raw SYS_IOCTL system call on fd with request cmd
// and argument ptr. It returns the errno as an error when it is non-zero and
// nil otherwise.
func ioctlInner(fd, cmd, ptr uintptr) error {
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, cmd, ptr); errno != 0 {
		return errno
	}
	return nil
}
================================================
FILE: gossh/pty/ioctl_legacy.go
================================================
//go:build !windows && !go1.12
// +build !windows,!go1.12
package pty
import "os"
// ioctl is the variant built for Go versions before 1.12: it issues the
// request directly on the file's descriptor.
func ioctl(f *os.File, cmd, ptr uintptr) error {
	fd := f.Fd()
	return ioctlInner(fd, cmd, ptr) // fall back to blocking io (old behavior)
}
================================================
FILE: gossh/pty/ioctl_solaris.go
================================================
//go:build solaris
// +build solaris
package pty
import (
"syscall"
"unsafe"
)
// procioctl is bound to libc's ioctl symbol: the directives below import
// ioctl dynamically from libc.so as libc_ioctl and link this variable to it.
// Its address is passed to sysvicall6 by ioctlInner.
//go:cgo_import_dynamic libc_ioctl ioctl "libc.so"
//go:linkname procioctl libc_ioctl
var procioctl uintptr
// ioctl request numbers and STREAMS pty commands used by the Solaris
// implementation. Each value is a letter shifted left by 8 bits, OR-ed with
// a command number.
const (
// see /usr/include/sys/stropts.h
// STREAMS requests ('S' group); the numbers are octal literals.
I_PUSH = uintptr((int32('S')<<8 | 002))
I_STR = uintptr((int32('S')<<8 | 010))
I_FIND = uintptr((int32('S')<<8 | 013))
// see /usr/include/sys/ptms.h
// Pseudo-terminal commands ('P' group), sent through I_STR via strioctl.icCmd.
ISPTM = (int32('P') << 8) | 1
UNLKPT = (int32('P') << 8) | 2
PTSSTTY = (int32('P') << 8) | 3
ZONEPT = (int32('P') << 8) | 4
OWNERPT = (int32('P') << 8) | 5
// see /usr/include/sys/termios.h
// Window-size requests ('T' group).
TIOCSWINSZ = (uint32('T') << 8) | 103
TIOCGWINSZ = (uint32('T') << 8) | 104
)
// strioctl is the argument structure passed with the I_STR ioctl.
type strioctl struct {
icCmd int32 // command to perform, e.g. ISPTM, UNLKPT or OWNERPT
icTimeout int32 // timeout; always set to 0 by the callers in this file
icLen int32 // length of the data icDP points to
icDP unsafe.Pointer // command data, or nil when the command carries none
}
// sysvicall6 is implemented in asm_solaris_amd64.s, where it jumps to the
// syscall package's sysvicall6. ioctlInner calls it with the address of
// procioctl as trap, the argument count as nargs, and the arguments in
// a1..a6.
func sysvicall6(trap, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno)
// ioctlInner calls libc's ioctl through sysvicall6 with the three arguments
// fd, cmd and ptr. It returns the errno as an error when it is non-zero and
// nil otherwise.
func ioctlInner(fd, cmd, ptr uintptr) error {
	_, _, errno := sysvicall6(uintptr(unsafe.Pointer(&procioctl)), 3, fd, cmd, ptr, 0, 0, 0)
	if errno == 0 {
		return nil
	}
	return errno
}
================================================
FILE: gossh/pty/ioctl_unsupported.go
================================================
//go:build aix
// +build aix
package pty
// Placeholder window-size request numbers for AIX; both are zero because
// ioctlInner in this file always returns ErrUnsupported.
const (
TIOCGWINSZ = 0
TIOCSWINSZ = 0
)
// ioctlInner is the AIX stub: it ignores its arguments and always returns
// ErrUnsupported.
func ioctlInner(fd, cmd, ptr uintptr) error {
return ErrUnsupported
}
================================================
FILE: gossh/pty/pty_darwin.go
================================================
//go:build darwin
// +build darwin
package pty
import (
"errors"
"os"
"syscall"
"unsafe"
)
// open opens /dev/ptmx, resolves the name of the matching tty, grants and
// unlocks it, and finally opens the tty itself. On any failure after
// /dev/ptmx has been opened, that descriptor is closed before returning.
func open() (pty, tty *os.File, err error) {
	fd, err := syscall.Open("/dev/ptmx", syscall.O_RDWR|syscall.O_CLOEXEC, 0)
	if err != nil {
		return nil, nil, err
	}
	master := os.NewFile(uintptr(fd), "/dev/ptmx")
	// In case of error after this point, make sure we close the ptmx fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}
	if err = grantpt(master); err != nil {
		return nil, nil, err
	}
	if err = unlockpt(master); err != nil {
		return nil, nil, err
	}

	slave, err := os.OpenFile(slaveName, os.O_RDWR|syscall.O_NOCTTY, 0)
	if err != nil {
		return nil, nil, err
	}
	return master, slave, nil
}
// ptsname asks the kernel (TIOCPTYGNAME) for the name of the tty that
// belongs to the pty f. The reply buffer is sized from the parameter length
// encoded in the request number; the name is the buffer content up to the
// first NUL byte. A buffer without NUL is reported as an error.
func ptsname(f *os.File) (string, error) {
	buf := make([]byte, _IOC_PARM_LEN(syscall.TIOCPTYGNAME))
	if err := ioctl(f, syscall.TIOCPTYGNAME, uintptr(unsafe.Pointer(&buf[0]))); err != nil {
		return "", err
	}

	for end := 0; end < len(buf); end++ {
		if buf[end] == 0 {
			return string(buf[:end]), nil
		}
	}
	return "", errors.New("TIOCPTYGNAME string not NUL-terminated")
}
// grantpt issues the TIOCPTYGRANT ioctl on f with a zero argument and
// returns its result.
func grantpt(f *os.File) error {
return ioctl(f, syscall.TIOCPTYGRANT, 0)
}
// unlockpt issues the TIOCPTYUNLK ioctl on f with a zero argument and
// returns its result.
func unlockpt(f *os.File) error {
return ioctl(f, syscall.TIOCPTYUNLK, 0)
}
================================================
FILE: gossh/pty/pty_dragonfly.go
================================================
//go:build dragonfly
// +build dragonfly
package pty
import (
"errors"
"os"
"strings"
"syscall"
"unsafe"
)
// open follows the same sequence as pty_darwin.go: open /dev/ptmx, resolve
// the tty name, grant and unlock it, then open the tty. On any failure after
// /dev/ptmx has been opened, that file is closed before returning.
func open() (pty, tty *os.File, err error) {
	master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
	if err != nil {
		return nil, nil, err
	}
	// In case of error after this point, make sure we close the ptmx fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}
	if err = grantpt(master); err != nil {
		return nil, nil, err
	}
	if err = unlockpt(master); err != nil {
		return nil, nil, err
	}

	slave, err := os.OpenFile(slaveName, os.O_RDWR, 0)
	if err != nil {
		return nil, nil, err
	}
	return master, slave, nil
}
// grantpt only verifies that f is a pty master (see isptmaster) and returns
// the error of that check, if any.
func grantpt(f *os.File) error {
	if _, err := isptmaster(f); err != nil {
		return err
	}
	return nil
}
// unlockpt only verifies that f is a pty master (see isptmaster) and returns
// the error of that check, if any.
func unlockpt(f *os.File) error {
	if _, err := isptmaster(f); err != nil {
		return err
	}
	return nil
}
// isptmaster issues the TIOCISPTMASTER ioctl on f. It returns true when the
// ioctl succeeds, and false together with the error otherwise.
func isptmaster(f *os.File) (bool, error) {
	if err := ioctl(f, syscall.TIOCISPTMASTER, 0); err != nil {
		return false, err
	}
	return true, nil
}
var (
// emptyFiodgnameArg exists only so that its size can be taken below.
emptyFiodgnameArg fiodgnameArg
// ioctl_FIODNAME is the request number used by ptsname: an _IOW request in
// group 'f', number 120, with a fiodgnameArg-sized parameter.
ioctl_FIODNAME = _IOW('f', 120, unsafe.Sizeof(emptyFiodgnameArg))
)
// ptsname asks the kernel (ioctl_FIODNAME) for the device name of the pty
// f and derives the path of the matching tty from it: the name up to the
// first NUL byte is prefixed with "/dev/" and every "ptm" is replaced by
// "pts". A reply without NUL is reported as an error.
func ptsname(f *os.File) (string, error) {
	name := make([]byte, _C_SPECNAMELEN)
	fa := fiodgnameArg{Name: (*byte)(unsafe.Pointer(&name[0])), Len: _C_SPECNAMELEN, Pad_cgo_0: [4]byte{0, 0, 0, 0}}

	err := ioctl(f, ioctl_FIODNAME, uintptr(unsafe.Pointer(&fa)))
	if err != nil {
		return "", err
	}

	for i, c := range name {
		if c == 0 {
			s := "/dev/" + string(name[:i])
			return strings.Replace(s, "ptm", "pts", -1), nil
		}
	}
	// The message names the ioctl actually issued here (it previously named
	// TIOCPTYGNAME, which this function never uses).
	return "", errors.New("FIODNAME string not NUL-terminated")
}
================================================
FILE: gossh/pty/pty_freebsd.go
================================================
//go:build freebsd
// +build freebsd
package pty
import (
"errors"
"os"
"syscall"
"unsafe"
)
// posixOpenpt invokes the SYS_POSIX_OPENPT system call with the given open
// flags. It returns the first syscall result as the descriptor and, when the
// errno is non-zero, that errno as the error.
func posixOpenpt(oflag int) (fd int, err error) {
	r, _, errno := syscall.Syscall(syscall.SYS_POSIX_OPENPT, uintptr(oflag), 0, 0)
	if errno != 0 {
		return int(r), errno
	}
	return int(r), nil
}
// open allocates a pty through posixOpenpt, resolves the name of the
// matching tty and opens it under /dev/. On any failure after the pty has
// been opened, that descriptor is closed before returning.
func open() (pty, tty *os.File, err error) {
	masterFD, err := posixOpenpt(syscall.O_RDWR | syscall.O_CLOEXEC)
	if err != nil {
		return nil, nil, err
	}
	master := os.NewFile(uintptr(masterFD), "/dev/pts")
	// In case of error after this point, make sure we close the pts fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}

	slave, err := os.OpenFile("/dev/"+slaveName, os.O_RDWR, 0)
	if err != nil {
		return nil, nil, err
	}
	return master, slave, nil
}
// isptmaster issues the TIOCPTMASTER ioctl on f. It returns true when the
// ioctl succeeds, and false together with the error otherwise.
func isptmaster(f *os.File) (bool, error) {
	if err := ioctl(f, syscall.TIOCPTMASTER, 0); err != nil {
		return false, err
	}
	return true, nil
}
var (
// emptyFiodgnameArg exists only so that its size can be taken below.
emptyFiodgnameArg fiodgnameArg
// ioctlFIODGNAME is the request number used by ptsname: an _IOW request in
// group 'f', number 120, with a fiodgnameArg-sized parameter.
ioctlFIODGNAME = _IOW('f', 120, unsafe.Sizeof(emptyFiodgnameArg))
)
// ptsname returns the device name of the tty that belongs to the pty f.
//
// It first checks that f is a pty master, then queries the name with
// ioctlFIODGNAME into a buffer of _C_SPECNAMELEN+1 bytes. The name is the
// buffer content up to the first NUL byte; a reply without NUL is reported
// as an error.
func ptsname(f *os.File) (string, error) {
	isMaster, err := isptmaster(f)
	if err != nil {
		return "", err
	}
	if !isMaster {
		return "", syscall.EINVAL
	}

	const bufLen = _C_SPECNAMELEN + 1
	nameBuf := make([]byte, bufLen)
	req := fiodgnameArg{Len: bufLen, Buf: (*byte)(unsafe.Pointer(&nameBuf[0]))}
	if err := ioctl(f, ioctlFIODGNAME, uintptr(unsafe.Pointer(&req))); err != nil {
		return "", err
	}

	for end := 0; end < len(nameBuf); end++ {
		if nameBuf[end] == 0 {
			return string(nameBuf[:end]), nil
		}
	}
	return "", errors.New("FIODGNAME string not NUL-terminated")
}
================================================
FILE: gossh/pty/pty_linux.go
================================================
//go:build linux
// +build linux
package pty
import (
"os"
"strconv"
"syscall"
"unsafe"
)
// open opens /dev/ptmx, resolves the path of the matching tty, unlocks it
// and opens it without making it the controlling terminal. On any failure
// after /dev/ptmx has been opened, that file is closed before returning.
func open() (pty, tty *os.File, err error) {
	master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
	if err != nil {
		return nil, nil, err
	}
	// In case of error after this point, make sure we close the ptmx fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}
	if err = unlockpt(master); err != nil {
		return nil, nil, err
	}

	slave, err := os.OpenFile(slaveName, os.O_RDWR|syscall.O_NOCTTY, 0) //nolint:gosec // Expected Open from a variable.
	if err != nil {
		return nil, nil, err
	}
	return master, slave, nil
}
// ptsname queries the pty number of f with TIOCGPTN and returns the tty
// path "/dev/pts/<number>".
func ptsname(f *os.File) (string, error) {
	var ptyNum _C_uint
	if err := ioctl(f, syscall.TIOCGPTN, uintptr(unsafe.Pointer(&ptyNum))); err != nil { //nolint:gosec // Expected unsafe pointer for Syscall call.
		return "", err
	}
	return "/dev/pts/" + strconv.Itoa(int(ptyNum)), nil
}
// unlockpt clears the lock on the pty f by issuing TIOCSPTLCK with a
// pointer to a zero value.
func unlockpt(f *os.File) error {
	var lockFlag _C_int // zero value: clear the lock
	err := ioctl(f, syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&lockFlag))) //nolint:gosec // Expected unsafe pointer for Syscall call.
	return err
}
================================================
FILE: gossh/pty/pty_netbsd.go
================================================
//go:build netbsd
// +build netbsd
package pty
import (
"errors"
"os"
"syscall"
"unsafe"
)
// open opens /dev/ptmx, resolves the name of the matching tty, grants it
// and opens it without making it the controlling terminal. On any failure
// after /dev/ptmx has been opened, that file is closed before returning.
func open() (pty, tty *os.File, err error) {
	master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
	if err != nil {
		return nil, nil, err
	}
	// In case of error after this point, make sure we close the ptmx fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}
	if err = grantpt(master); err != nil {
		return nil, nil, err
	}

	// In NetBSD unlockpt() does nothing, so it isn't called here.
	slave, err := os.OpenFile(slaveName, os.O_RDWR|syscall.O_NOCTTY, 0)
	if err != nil {
		return nil, nil, err
	}
	return master, slave, nil
}
// ptsname returns the tty name reported for the pty f.
//
// From ptsname(3): the ptsname() function is equivalent to
//
//	struct ptmget pm;
//	ioctl(fd, TIOCPTSNAME, &pm) == -1 ? NULL : pm.sn;
//
// The name is the content of ptm.Sn up to the first NUL; a field without
// NUL is reported as an error.
func ptsname(f *os.File) (string, error) {
	var ptm ptmget
	if err := ioctl(f, uintptr(ioctl_TIOCPTSNAME), uintptr(unsafe.Pointer(&ptm))); err != nil {
		return "", err
	}

	name := make([]byte, 0, len(ptm.Sn))
	for _, c := range ptm.Sn {
		if c == 0 {
			return string(name), nil
		}
		name = append(name, byte(c))
	}
	return "", errors.New("TIOCPTSNAME string not NUL-terminated")
}
// grantpt issues the TIOCGRANTPT ioctl on f with a zero argument and
// returns its result.
func grantpt(f *os.File) error {
/*
* from grantpt(3): Calling grantpt() is equivalent to:
* ioctl(fd, TIOCGRANTPT, 0);
*/
return ioctl(f, uintptr(ioctl_TIOCGRANTPT), 0)
}
================================================
FILE: gossh/pty/pty_openbsd.go
================================================
//go:build openbsd
// +build openbsd
package pty
import (
"os"
"syscall"
"unsafe"
)
// cInt8ToString converts a C-style character array into a Go string: the
// elements before the first zero (or all of them when there is no zero) are
// taken as raw bytes.
func cInt8ToString(in []int8) string {
	end := len(in)
	for i, v := range in {
		if v == 0 {
			end = i
			break
		}
	}

	raw := make([]byte, end)
	for i := range raw {
		raw[i] = byte(in[i])
	}
	return string(raw)
}
// open allocates a pty/tty pair through the PTMGET ioctl on /dev/ptm.
//
// From ptm(4): the PTMGET command allocates a free pseudo terminal, changes
// its ownership to the caller, revokes the access privileges for all
// previous users, opens the file descriptors for the pty and tty devices and
// returns them to the caller in struct ptmget.
//
// /dev/ptm itself is closed again before returning.
func open() (pty, tty *os.File, err error) {
	ptmDev, err := os.OpenFile("/dev/ptm", os.O_RDWR|syscall.O_CLOEXEC, 0)
	if err != nil {
		return nil, nil, err
	}
	defer ptmDev.Close()

	var ptm ptmget
	err = ioctl(ptmDev, uintptr(ioctl_PTMGET), uintptr(unsafe.Pointer(&ptm)))
	if err != nil {
		return nil, nil, err
	}

	master := os.NewFile(uintptr(ptm.Cfd), cInt8ToString(ptm.Cn[:]))
	slave := os.NewFile(uintptr(ptm.Sfd), cInt8ToString(ptm.Sn[:]))
	return master, slave, nil
}
================================================
FILE: gossh/pty/pty_solaris.go
================================================
//go:build solaris
// +build solaris
package pty
/* based on:
http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libc/port/gen/pt.c
*/
import (
"errors"
"os"
"strconv"
"syscall"
"unsafe"
)
// open opens /dev/ptmx, resolves, grants and unlocks the matching tty,
// opens it, and pushes the terminal STREAMS modules onto it. On failure,
// whichever of the two files has already been opened is closed before
// returning.
func open() (pty, tty *os.File, err error) {
	masterFD, err := syscall.Open("/dev/ptmx", syscall.O_RDWR|syscall.O_NOCTTY, 0)
	if err != nil {
		return nil, nil, err
	}
	master := os.NewFile(uintptr(masterFD), "/dev/ptmx")
	// In case of error after this point, make sure we close the ptmx fd.
	defer func() {
		if err != nil {
			_ = master.Close() // Best effort.
		}
	}()

	slaveName, err := ptsname(master)
	if err != nil {
		return nil, nil, err
	}
	if err = grantpt(master); err != nil {
		return nil, nil, err
	}
	if err = unlockpt(master); err != nil {
		return nil, nil, err
	}

	slaveFD, err := syscall.Open(slaveName, os.O_RDWR|syscall.O_NOCTTY, 0)
	if err != nil {
		return nil, nil, err
	}
	slave := os.NewFile(uintptr(slaveFD), slaveName)
	// In case of error after this point, make sure we close the pts fd.
	defer func() {
		if err != nil {
			_ = slave.Close() // Best effort.
		}
	}()

	// pushing terminal driver STREAMS modules as per pts(7)
	for _, module := range [...]string{"ptem", "ldterm", "ttcompat"} {
		if err = streamsPush(slave, module); err != nil {
			return nil, nil, err
		}
	}
	return master, slave, nil
}
// ptsname returns the path "/dev/pts/<n>" of the tty that belongs to the
// pty f, where n is the device number reported by ptsdev. The path is only
// returned if syscall.Access confirms it with mode 0.
func ptsname(f *os.File) (string, error) {
	dev, err := ptsdev(f)
	if err != nil {
		return "", err
	}

	path := "/dev/pts/" + strconv.FormatInt(int64(dev), 10)
	if accessErr := syscall.Access(path, 0); accessErr != nil {
		return "", accessErr
	}
	return path, nil
}
// unlockpt sends the UNLKPT command to f through the I_STR ioctl. The
// command carries no data and uses a zero timeout.
func unlockpt(f *os.File) error {
	// Timeout, length and data pointer keep their zero values.
	req := strioctl{icCmd: UNLKPT}
	return ioctl(f, I_STR, uintptr(unsafe.Pointer(&req)))
}
// minor returns the low eight bits of x.
func minor(x uint64) uint64 {
	const lowByte = 0377
	return x & lowByte
}
// ptsdev returns the device number used to name the tty of the pty f.
//
// It first sends the ISPTM command through the I_STR ioctl and fails if
// that is rejected. It then fstats the descriptor inside the file's
// SyscallConn Control callback and returns minor(Rdev) of the result.
// Errors from the ioctl, from SyscallConn/Control, or from Fstat are
// returned with a zero device number.
func ptsdev(f *os.File) (uint64, error) {
	istr := strioctl{
		icCmd:     ISPTM,
		icTimeout: 0,
		icLen:     0,
		icDP:      nil,
	}
	if err := ioctl(f, I_STR, uintptr(unsafe.Pointer(&istr))); err != nil {
		return 0, err
	}

	sc, err := f.SyscallConn()
	if err != nil {
		return 0, err
	}

	// Control runs the callback before it returns, so results are handed back
	// through captured variables. The earlier channel-based version kept
	// sending after an Fstat failure and blocked forever on the second send
	// to an already-full one-slot channel.
	var (
		dev     uint64
		statErr error
	)
	err = sc.Control(func(fd uintptr) {
		var status syscall.Stat_t
		if statErr = syscall.Fstat(int(fd), &status); statErr != nil {
			return
		}
		dev = uint64(minor(status.Rdev))
	})
	if err != nil {
		return 0, err
	}
	if statErr != nil {
		return 0, statErr
	}
	return dev, nil
}
// ptOwn is the payload of the OWNERPT command sent by grantpt.
type ptOwn struct {
rUID int32 // filled by grantpt with os.Getuid()
rGID int32 // filled by grantpt with os.Getgid()
}
// grantpt assigns ownership of the tty belonging to the pty f to the
// calling process's user and group by sending the OWNERPT command through
// the I_STR ioctl.
//
// It fails with the error from ptsdev when f does not pass that check, and
// with "access denied" when the ioctl itself is rejected.
func grantpt(f *os.File) error {
	if _, err := ptsdev(f); err != nil {
		return err
	}
	pto := ptOwn{
		rUID: int32(os.Getuid()),
		// XXX should first attempt to get gid of DEFAULT_TTY_GROUP="tty"
		rGID: int32(os.Getgid()),
	}
	istr := strioctl{
		icCmd:     OWNERPT,
		icTimeout: 0,
		// The length must describe the payload icDP points to (the ptOwn),
		// not the strioctl header; the larger header size made the request
		// claim more bytes than pto provides.
		icLen: int32(unsafe.Sizeof(pto)),
		icDP:  unsafe.Pointer(&pto),
	}
	if err := ioctl(f, I_STR, uintptr(unsafe.Pointer(&istr))); err != nil {
		return errors.New("access denied")
	}
	return nil
}
// streamsPush pushes the STREAMS module mod onto f if not already done so.
//
// It first probes with I_FIND; when that probe returns an error the
// function returns nil without pushing. Otherwise the module is pushed with
// I_PUSH and the result of that ioctl is returned.
func streamsPush(f *os.File, mod string) error {
	// The ioctls read the module name as a C string, so terminate it
	// explicitly instead of relying on whatever follows the bytes in memory.
	buf := append([]byte(mod), 0)

	// XXX I_FIND is not returning an error when the module
	// is already pushed even though truss reports a return
	// value of 1. A bug in the Go Solaris syscall interface?
	// XXX without this we are at risk of the issue
	// https://www.illumos.org/issues/9042
	// but since we are not using libc or XPG4.2, we should not be
	// double-pushing modules
	if err := ioctl(f, I_FIND, uintptr(unsafe.Pointer(&buf[0]))); err != nil {
		return nil
	}
	return ioctl(f, I_PUSH, uintptr(unsafe.Pointer(&buf[0])))
}
================================================
FILE: gossh/pty/pty_unsupported.go
================================================
//go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris && !windows
// +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris,!windows
package pty
import (
"os"
)
// open is the stub for platforms excluded by this file's build constraints
// from every real implementation: it always returns ErrUnsupported.
func open() (pty, tty *os.File, err error) {
return nil, nil, ErrUnsupported
}
================================================
FILE: gossh/pty/pty_windows.go
================================================
//go:build windows
// +build windows
package pty
import (
"fmt"
"os"
"syscall"
"time"
"unsafe"
"gossh/sys/windows"
)
const (
// PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE is the attribute key passed to the
// process/thread attribute list by WindowsPty.UpdateProcThreadAttribute.
// Ref: https://pkg.go.dev/golang.org/x/sys/windows#pkg-constants
PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016
)
// WindowsPty is the Pty implementation for Windows. It pairs a pseudo
// console handle with the two pipe ends kept by the current process.
type WindowsPty struct {
handle windows.Handle // pseudo console handle; returned by Fd and closed by Close
r, w *os.File // pipe ends used by Read and by Write/WriteString respectively
}
// WindowsTty is the Tty implementation for Windows. It holds the same
// pseudo console handle as its WindowsPty plus the console-side pipe ends.
type WindowsTty struct {
handle windows.Handle // pseudo console handle; returned by Fd
r, w *os.File // pipe ends used by Read and Write respectively
}
var (
// NOTE(security): As noted in the documentation of syscall.NewLazyDLL and
// syscall.LoadDLL, users need to call
// internal/syscall/windows/sysdll.Add("kernel32.dll") to make sure that
// kernel32.dll is loaded from the Windows system path.
//
// Ref: https://pkg.go.dev/syscall@go1.13?GOOS=windows#LoadDLL
kernel32DLL = windows.NewLazySystemDLL("kernel32.dll")
// Lazily resolved kernel32 procedures of the pseudo console API.
// https://docs.microsoft.com/en-us/windows/console/createpseudoconsole
createPseudoConsole = kernel32DLL.NewProc("CreatePseudoConsole")
closePseudoConsole = kernel32DLL.NewProc("ClosePseudoConsole")
resizePseudoConsole = kernel32DLL.NewProc("ResizePseudoConsole")
getConsoleScreenBufferInfo = kernel32DLL.NewProc("GetConsoleScreenBufferInfo")
)
// open creates two pipes and a pseudo console connected to them. The
// returned WindowsPty keeps the read end of the first pipe and the write end
// of the second; the WindowsTty keeps the other two ends. Both share the
// pseudo console handle. On failure, every pipe end opened so far is closed.
func open() (_ Pty, _ Tty, err error) {
	// closeAll closes the given files in order. Best effort.
	closeAll := func(files ...*os.File) {
		for _, f := range files {
			_ = f.Close()
		}
	}

	pr, consoleW, err := os.Pipe()
	if err != nil {
		return nil, nil, err
	}

	consoleR, pw, err := os.Pipe()
	if err != nil {
		closeAll(consoleW, pr)
		return nil, nil, err
	}

	var hpc windows.Handle
	// TODO: As we removed the use of `.Fd()` on Unix (https://github.com/creack/pty/pull/168), we need to check if we should do the same here.
	err = procCreatePseudoConsole(windows.Handle(consoleR.Fd()), windows.Handle(consoleW.Fd()), 0, &hpc)
	if err != nil {
		closeAll(consoleW, pr, pw, consoleR)
		return nil, nil, err
	}

	pty := &WindowsPty{handle: hpc, r: pr, w: pw}
	tty := &WindowsTty{handle: hpc, r: consoleR, w: consoleW}
	return pty, tty, nil
}
// Name returns the name of the pty's read pipe.
func (p *WindowsPty) Name() string {
return p.r.Name()
}
// Fd returns the pseudo console handle as a uintptr.
func (p *WindowsPty) Fd() uintptr {
return uintptr(p.handle)
}
// Read reads from the pty's read pipe into data.
func (p *WindowsPty) Read(data []byte) (int, error) {
	n, err := p.r.Read(data)
	return n, err
}
// Write writes data to the pty's write pipe.
func (p *WindowsPty) Write(data []byte) (int, error) {
	n, err := p.w.Write(data)
	return n, err
}
// WriteString writes s to the pty's write pipe.
func (p *WindowsPty) WriteString(s string) (int, error) {
	n, err := p.w.WriteString(s)
	return n, err
}
// UpdateProcThreadAttribute registers the pseudo console handle in attrList
// under PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE. A failure of the update is
// returned wrapped with context.
func (p *WindowsPty) UpdateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error {
	err := attrList.Update(
		PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
		unsafe.Pointer(p.handle),
		unsafe.Sizeof(p.handle),
	)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err)
}
// Close closes both pipe ends held by the pty (best effort) and then closes
// the pseudo console. It returns an error only when the ClosePseudoConsole
// procedure cannot be resolved.
func (p *WindowsPty) Close() error {
	// Best effort.
	_ = p.r.Close()
	_ = p.w.Close()

	if err := closePseudoConsole.Find(); err != nil {
		return err
	}

	// ClosePseudoConsole has no return value, and the error returned by Call
	// is always non-nil (it is built from GetLastError), so it carries no
	// failure information here. Returning it made Close report an error even
	// on success.
	_, _, _ = closePseudoConsole.Call(uintptr(p.handle))
	return nil
}
// SetDeadline is not supported by WindowsPty: it ignores value and always
// returns os.ErrNoDeadline.
func (p *WindowsPty) SetDeadline(value time.Time) error {
return os.ErrNoDeadline
}
// Name returns the name of the tty's read pipe.
func (t *WindowsTty) Name() string {
return t.r.Name()
}
// Fd returns the pseudo console handle as a uintptr.
func (t *WindowsTty) Fd() uintptr {
return uintptr(t.handle)
}
// Read reads from the tty's read pipe into p.
func (t *WindowsTty) Read(p []byte) (int, error) {
	n, err := t.r.Read(p)
	return n, err
}
// Write writes p to the tty's write pipe.
func (t *WindowsTty) Write(p []byte) (int, error) {
	n, err := t.w.Write(p)
	return n, err
}
// Close closes both pipe ends of the tty. An error from closing the read
// end is discarded; the result of closing the write end is returned.
func (t *WindowsTty) Close() error {
	_ = t.r.Close() // Best effort.
	writeErr := t.w.Close()
	return writeErr
}
// SetDeadline is not supported by WindowsTty: it ignores value and always
// returns os.ErrNoDeadline.
func (t *WindowsTty) SetDeadline(value time.Time) error {
return os.ErrNoDeadline
}
// procCreatePseudoConsole calls CreatePseudoConsole with a default 80x30
// window, the given input/output handles and flags, and stores the new
// console handle in *consoleHandle.
//
// It fails when the procedure cannot be resolved, or when the first return
// value is negative as an int32. In the latter case the value is returned as
// a syscall.Errno, reduced to its low 16 bits when the bits selected by
// 0x1fff0000 equal 0x00070000.
func procCreatePseudoConsole(hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, consoleHandle *windows.Handle) error {
	if err := createPseudoConsole.Find(); err != nil {
		return err
	}

	// TODO: Check if it is expected to ignore `err` here.
	hr, _, _ := createPseudoConsole.Call(
		(windowsCoord{X: 80, Y: 30}).Pack(),    // Size: default 80x30 window.
		uintptr(hInput),                        // Console input.
		uintptr(hOutput),                       // Console output.
		uintptr(dwFlags),                       // Console flags, currently only PSEUDOCONSOLE_INHERIT_CURSOR supported.
		uintptr(unsafe.Pointer(consoleHandle)), // Console handler value return.
	)

	// Non-negative result (S_OK is 0): success.
	if int32(hr) >= 0 {
		return nil
	}
	if hr&0x1fff0000 == 0x00070000 {
		hr &= 0xffff
	}
	return syscall.Errno(hr)
}
================================================
FILE: gossh/pty/run.go
================================================
package pty
import (
"os/exec"
)
// Start assigns a pseudo-terminal tty os.File to cmd.Stdin, cmd.Stdout,
// and cmd.Stderr, calls cmd.Start, and returns the File of the tty's
// corresponding pty.
//
// Starts the process in a new session and sets the controlling terminal.
// It is StartWithSize called without a size (nil).
func Start(cmd *exec.Cmd) (Pty, error) {
return StartWithSize(cmd, nil)
}
================================================
FILE: gossh/pty/run_unix.go
================================================
//go:build !windows
// +build !windows
package pty
import (
"os/exec"
"syscall"
)
// StartWithSize starts c attached to a new pseudo-terminal, resizing the Pty
// to sz first when sz is non-nil, and returns the Pty.
//
// The command is started in a new session with the tty as its controlling
// terminal: Setsid and Setctty are forced on in c.SysProcAttr, which is
// allocated when the caller left it nil.
func StartWithSize(c *exec.Cmd, sz *Winsize) (Pty, error) {
	attrs := c.SysProcAttr
	if attrs == nil {
		attrs = &syscall.SysProcAttr{}
		c.SysProcAttr = attrs
	}
	attrs.Setsid = true
	attrs.Setctty = true

	return StartWithAttrs(c, sz, attrs)
}
// StartWithAttrs opens a pty/tty pair, wires the tty into any of c.Stdin,
// c.Stdout and c.Stderr that the caller left nil, starts c and returns the Pty.
//
// When sz is non-nil the Pty is resized before the command starts. attrs
// replaces whatever was set in c.SysProcAttr.
//
// Most callers want StartWithSize instead; this entry point exists for edge
// cases such as creating a pty without a controlling terminal.
func StartWithAttrs(c *exec.Cmd, sz *Winsize, attrs *syscall.SysProcAttr) (Pty, error) {
	pty, tty, err := open()
	if err != nil {
		return nil, err
	}
	// The tty is used by the other process, so our copy is always closed on
	// the way out; the pty is kept so the caller can resize the tty.
	defer func() { _ = tty.Close() }()

	if sz != nil {
		if resizeErr := Setsize(pty, sz); resizeErr != nil {
			_ = pty.Close()
			return nil, resizeErr
		}
	}

	// Streams the caller already configured are left untouched.
	if c.Stdin == nil {
		c.Stdin = tty
	}
	if c.Stdout == nil {
		c.Stdout = tty
	}
	if c.Stderr == nil {
		c.Stderr = tty
	}
	c.SysProcAttr = attrs

	if startErr := c.Start(); startErr != nil {
		_ = pty.Close()
		return nil, startErr
	}
	return pty, nil
}
================================================
FILE: gossh/pty/run_windows.go
================================================
//go:build windows
// +build windows
package pty
import (
"errors"
"fmt"
"os"
"os/exec"
"syscall"
"unsafe"
"gossh/sys/windows"
)
// StartWithSize starts c attached to a new pseudo console, resizing the Pty to
// sz first when sz is non-nil, and returns the Pty.
//
// It forwards to StartWithAttrs using the attributes already present in
// c.SysProcAttr.
func StartWithSize(c *exec.Cmd, sz *Winsize) (Pty, error) {
	attrs := c.SysProcAttr
	return StartWithAttrs(c, sz, attrs)
}
// StartWithAttrs opens a pty/tty pair, starts c attached to it and returns the
// Pty.
//
// When sz is non-nil the Pty is resized before the command starts. attrs
// replaces whatever was set in c.SysProcAttr.
//
// Unlike the unix implementation, c.Stdin, c.Stdout and c.Stderr are not
// touched and the tty is not closed on success. If resizing or starting the
// command fails, the Pty is closed (best effort) before the error is returned.
func StartWithAttrs(c *exec.Cmd, sz *Winsize, attrs *syscall.SysProcAttr) (Pty, error) {
	pty, tty, err := open()
	if err != nil {
		return nil, err
	}

	defer func() {
		// Unlike unix command exec, do not close anything unless an error happened.
		if err != nil {
			_ = pty.Close() // Best effort.
		}
	}()

	if sz != nil {
		// Assign to the outer err (no `:=`): the deferred cleanup above reads
		// it, and a shadowed variable would leave the pty open on failure.
		if err = Setsize(pty, sz); err != nil {
			return nil, err
		}
	}

	// Unlike unix command exec, do not set stdin/stdout/stderr.

	c.SysProcAttr = attrs

	// Do not use os/exec.Start since we need to append console handler to startup info.
	w := windowExecCmd{
		cmd:        c,
		waitCalled: false,
		conPty:     pty.(*WindowsPty),
		conTty:     tty.(*WindowsTty),
	}

	// Same as above: keep the outer err in sync for the deferred cleanup.
	if err = w.Start(); err != nil {
		return nil, err
	}

	return pty, nil
}
// Start starts the wrapped command attached to the pseudo console but does
// not wait for it to complete.
//
// It returns an error if the command was already started, if the executable,
// command line or directory cannot be prepared, or if process creation fails.
// When sys.Token is set the process is created with CreateProcessAsUser,
// otherwise with CreateProcess.
//
// If Start returns successfully, the c.cmd.Process field will be set and a
// goroutine waits for the process to exit and then calls c.close().
func (c *windowExecCmd) Start() error {
	if c.cmd.Process != nil {
		return errors.New("exec: already started")
	}

	argv0 := c.cmd.Path
	var argv0p *uint16
	var argvp *uint16
	var dirp *uint16
	var err error

	sys := c.cmd.SysProcAttr
	if sys == nil {
		sys = &syscall.SysProcAttr{}
	}

	if c.cmd.Env == nil {
		c.cmd.Env, err = execEnvDefault(sys)
		if err != nil {
			return err
		}
	}

	var lp string
	lp, err = lookExtensions(c.cmd.Path, c.cmd.Dir)
	if err != nil {
		return err
	}
	c.cmd.Path = lp

	if len(c.cmd.Dir) != 0 {
		// Windows CreateProcess looks for argv0 relative to the current
		// directory, and, only once the new process is started, it does
		// Chdir(attr.Dir). We are adjusting for that difference here by
		// making argv0 absolute.
		argv0, err = joinExeDirAndFName(c.cmd.Dir, c.cmd.Path)
		if err != nil {
			return err
		}
	}

	argv0p, err = syscall.UTF16PtrFromString(argv0)
	if err != nil {
		return err
	}

	var cmdline string
	// Windows CreateProcess takes the command line as a single string:
	// use attr.CmdLine if set, else build the command line by escaping
	// and joining each argument with spaces.
	if sys.CmdLine != "" {
		cmdline = sys.CmdLine
	} else {
		cmdline = makeCmdLine(c.argv())
	}

	if len(cmdline) != 0 {
		argvp, err = windows.UTF16PtrFromString(cmdline)
		if err != nil {
			return err
		}
	}

	if len(c.cmd.Dir) != 0 {
		dirp, err = windows.UTF16PtrFromString(c.cmd.Dir)
		if err != nil {
			return err
		}
	}

	// Acquire the fork lock so that no other threads
	// create new fds that are not yet close-on-exec
	// before we fork.
	syscall.ForkLock.Lock()
	defer syscall.ForkLock.Unlock()

	siEx := new(windows.StartupInfoEx)
	siEx.Flags = windows.STARTF_USESTDHANDLES

	if sys.HideWindow {
		siEx.Flags |= syscall.STARTF_USESHOWWINDOW
		siEx.ShowWindow = syscall.SW_HIDE
	}

	pi := new(windows.ProcessInformation)

	// Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field.
	flags := sys.CreationFlags | uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT

	c.attrList, err = windows.NewProcThreadAttributeList(1)
	if err != nil {
		return fmt.Errorf("failed to initialize process thread attribute list: %w", err)
	}

	if c.conPty != nil {
		if err = c.conPty.UpdateProcThreadAttribute(c.attrList); err != nil {
			return err
		}
	}

	siEx.ProcThreadAttributeList = c.attrList.List()
	siEx.Cb = uint32(unsafe.Sizeof(*siEx))

	if sys.Token != 0 {
		err = windows.CreateProcessAsUser(
			windows.Token(sys.Token),
			argv0p,
			argvp,
			nil,
			nil,
			false,
			flags,
			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.cmd.Env))),
			dirp,
			&siEx.StartupInfo,
			pi,
		)
	} else {
		err = windows.CreateProcess(
			argv0p,
			argvp,
			nil,
			nil,
			false,
			flags,
			createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.cmd.Env))),
			dirp,
			&siEx.StartupInfo,
			pi,
		)
	}
	if err != nil {
		return err
	}

	defer func() {
		_ = windows.CloseHandle(pi.Thread)
		// The process is tracked through the os.Process looked up by PID
		// below, so the handle returned by process creation has to be
		// released as well or it leaks for every started command. It is
		// closed only after the lookup so the PID stays reserved until then.
		_ = windows.CloseHandle(pi.Process)
	}()

	c.cmd.Process, err = os.FindProcess(int(pi.ProcessId))
	if err != nil {
		return err
	}

	go c.waitProcess(c.cmd.Process)

	return nil
}
// waitProcess blocks until process has exited and then calls c.close().
// The results of both calls are intentionally discarded (best effort); Start
// runs this method in its own goroutine.
func (c *windowExecCmd) waitProcess(process *os.Process) {
_, _ = process.Wait()
_ = c.close()
}
================================================
FILE: gossh/pty/types.go
================================================
//go:build ignore
// +build ignore
package pty
import "C"
// C integer types named here so that `cgo -godefs types.go` can emit their Go
// equivalents; the generated results live in the ztypes_*.go files.
type (
_C_int C.int
_C_uint C.uint
)
================================================
FILE: gossh/pty/types_dragonfly.go
================================================
//go:build ignore
// +build ignore
package pty
/*
#define _KERNEL
#include
#include
#include
*/
import "C"
// Input for `cgo -godefs`: the constant and struct below are resolved from the
// C headers included in the preamble of this file.
const (
_C_SPECNAMELEN = C.SPECNAMELEN /* max length of devicename */
)
// fiodgnameArg takes its layout from the C struct fiodname_args.
type fiodgnameArg C.struct_fiodname_args
================================================
FILE: gossh/pty/types_freebsd.go
================================================
//go:build ignore
// +build ignore
package pty
/*
#include
#include
*/
import "C"
// Input for `cgo -godefs`: the constant and struct below are resolved from the
// C headers included in the preamble of this file.
const (
_C_SPECNAMELEN = C.SPECNAMELEN /* max length of devicename */
)
// fiodgnameArg takes its layout from the C struct fiodgname_arg.
type fiodgnameArg C.struct_fiodgname_arg
================================================
FILE: gossh/pty/types_netbsd.go
================================================
//go:build ignore
// +build ignore
package pty
/*
#include
#include
#include
*/
import "C"
// ptmget takes its layout from the C struct ptmget (input for `cgo -godefs`).
type ptmget C.struct_ptmget
// ioctl request numbers taken from the C macros of the same name.
var (
ioctl_TIOCPTSNAME = C.TIOCPTSNAME
ioctl_TIOCGRANTPT = C.TIOCGRANTPT
)
================================================
FILE: gossh/pty/types_openbsd.go
================================================
//go:build ignore
// +build ignore
package pty
/*
#include
#include
#include
*/
import "C"
// ptmget takes its layout from the C struct ptmget (input for `cgo -godefs`).
type ptmget C.struct_ptmget
// ioctl_PTMGET is the ioctl request number taken from the C macro PTMGET.
var ioctl_PTMGET = C.PTMGET
================================================
FILE: gossh/pty/winsize.go
================================================
package pty
// Winsize describes the terminal size.
//
// Keep the field order and types as they are: on unix a *Winsize is handed
// directly to the TIOCSWINSZ / TIOCGWINSZ ioctls through an unsafe pointer,
// so the struct layout is part of the contract.
type Winsize struct {
Rows uint16 // ws_row: Number of rows (in cells).
Cols uint16 // ws_col: Number of columns (in cells).
X uint16 // ws_xpixel: Width in pixels.
Y uint16 // ws_ypixel: Height in pixels.
}
// InheritSize copies the current terminal size of pty onto tty.
//
// It is meant to be called from a syscall.SIGWINCH handler so the tty follows
// window size changes reported for the pty. The error from reading the size
// is returned as is; otherwise the result of applying it is returned.
func InheritSize(pty Pty, tty Tty) error {
	size, err := GetsizeFull(pty)
	if err == nil {
		err = Setsize(tty, size)
	}
	return err
}
// Getsize returns the number of rows (lines) and cols (positions
// in each line) in terminal t.
//
// When the size cannot be read, rows and cols are 0 and err is the error
// reported by GetsizeFull. GetsizeFull returns a nil *Winsize in that case, so
// the result must not be dereferenced before err has been checked.
func Getsize(t FdHolder) (rows, cols int, err error) {
	ws, err := GetsizeFull(t)
	if err != nil {
		return 0, 0, err
	}
	return int(ws.Rows), int(ws.Cols), nil
}
================================================
FILE: gossh/pty/winsize_unix.go
================================================
//go:build !windows
// +build !windows
package pty
import (
"os"
"syscall"
"unsafe"
)
// Setsize applies the window size ws to terminal t using the TIOCSWINSZ ioctl.
//
// t must hold an *os.File; any other FdHolder makes the type assertion panic.
func Setsize(t FdHolder, ws *Winsize) error {
	f := t.(*os.File)

	//nolint:gosec // Expected unsafe pointer for Syscall call.
	err := ioctl(f, syscall.TIOCSWINSZ, uintptr(unsafe.Pointer(ws)))
	return err
}
// GetsizeFull reads the complete window size of terminal t using the
// TIOCGWINSZ ioctl. On failure it returns a nil size together with the error.
//
// t must hold an *os.File; any other FdHolder makes the type assertion panic.
func GetsizeFull(t FdHolder) (size *Winsize, err error) {
	f := t.(*os.File)
	ws := new(Winsize)

	//nolint:gosec // Expected unsafe pointer for Syscall call.
	if ioctlErr := ioctl(f, syscall.TIOCGWINSZ, uintptr(unsafe.Pointer(ws))); ioctlErr != nil {
		return nil, ioctlErr
	}
	return ws, nil
}
================================================
FILE: gossh/pty/winsize_windows.go
================================================
//go:build windows
// +build windows
package pty
import (
"syscall"
"unsafe"
)
// Types from golang.org/x/sys/windows.
// Ref: https://pkg.go.dev/golang.org/x/sys/windows#Coord
//
// windowsCoord is an X/Y pair of 16-bit values; Pack folds it into a single
// uintptr for passing to a procedure call.
type windowsCoord struct {
X int16
Y int16
}
// Ref: https://pkg.go.dev/golang.org/x/sys/windows#SmallRect
//
// windowsSmallRect describes a rectangle by its four edges. GetsizeFull
// treats the edges as inclusive when it derives rows and columns from it.
type windowsSmallRect struct {
Left int16
Top int16
Right int16
Bottom int16
}
// Ref: https://pkg.go.dev/golang.org/x/sys/windows#ConsoleScreenBufferInfo
//
// windowsConsoleScreenBufferInfo is the structure filled in by the
// getConsoleScreenBufferInfo call in GetsizeFull, which only reads Window.
// The field order and types must stay as they are because a pointer to the
// struct is passed to that call.
type windowsConsoleScreenBufferInfo struct {
Size windowsCoord
CursorPosition windowsCoord
Attributes uint16
Window windowsSmallRect
MaximumWindowSize windowsCoord
}
// Pack folds the coordinate into one 32-bit value, X in the low 16 bits and Y
// in the high 16 bits, returned as a uintptr for use as a call argument.
//
// Each half is converted through uint16 first so that a negative X is not
// sign-extended into (and does not overwrite) the Y half, and a negative Y
// does not spill into the upper bits of the uintptr. For non-negative
// coordinates the result is unchanged.
func (c windowsCoord) Pack() uintptr {
	lo := uint32(uint16(c.X))
	hi := uint32(uint16(c.Y)) << 16
	return uintptr(hi | lo)
}
// Setsize resizes the pseudo console identified by t.Fd() to ws.Cols columns
// by ws.Rows rows. A negative result code from the resize call is converted
// to a syscall.Errno; anything else is treated as success.
func Setsize(t FdHolder, ws *Winsize) error {
	if findErr := resizePseudoConsole.Find(); findErr != nil {
		return findErr
	}

	// TODO: As we removed the use of `.Fd()` on Unix (https://github.com/creack/pty/pull/168), we need to check if we should do the same here.
	fd := t.Fd()
	newSize := windowsCoord{X: int16(ws.Cols), Y: int16(ws.Rows)}

	// TODO: Check if it is expected to ignore `err` here.
	hr, _, _ := resizePseudoConsole.Call(fd, newSize.Pack())

	// S_OK is 0; only negative codes signal failure.
	if int32(hr) >= 0 {
		return nil
	}

	// When the facility bits equal 0x0007, keep only the low 16 bits as the
	// error number.
	if hr&0x1fff0000 == 0x00070000 {
		hr &= 0xffff
	}
	return syscall.Errno(hr)
}
// GetsizeFull returns the terminal size of t, derived from the visible window
// rectangle reported by getConsoleScreenBufferInfo. Only Rows and Cols are
// filled in; the pixel fields stay zero.
//
// The underlying call reports success with a non-zero result and failure with
// zero, so a zero result is turned into an error (the error supplied by the
// call, or syscall.EINVAL when none is available) instead of returning a
// size computed from an unfilled structure.
func GetsizeFull(t FdHolder) (size *Winsize, err error) {
	if err := getConsoleScreenBufferInfo.Find(); err != nil {
		return nil, err
	}

	var info windowsConsoleScreenBufferInfo

	r0, _, callErr := getConsoleScreenBufferInfo.Call(t.Fd(), uintptr(unsafe.Pointer(&info)))
	if r0 == 0 {
		// A missing or zero last-error still has to surface as a failure.
		if errno, ok := callErr.(syscall.Errno); callErr == nil || (ok && errno == 0) {
			return nil, syscall.EINVAL
		}
		return nil, callErr
	}

	// The window rectangle is inclusive on both ends, hence the +1.
	return &Winsize{
		Rows: uint16(info.Window.Bottom - info.Window.Top + 1),
		Cols: uint16(info.Window.Right - info.Window.Left + 1),
	}, nil
}
================================================
FILE: gossh/pty/ztypes_386.go
================================================
//go:build 386
// +build 386
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for 386.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for 386.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_amd64.go
================================================
//go:build amd64
// +build amd64
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for amd64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for amd64.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_arm.go
================================================
//go:build arm
// +build arm
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for arm.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for arm.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_arm64.go
================================================
//go:build arm64
// +build arm64
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for arm64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for arm64.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_dragonfly_amd64.go
================================================
//go:build amd64 && dragonfly
// +build amd64,dragonfly
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types_dragonfly.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for dragonfly/amd64.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodname_args
// (see types_dragonfly.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Name *byte
Len uint32
Pad_cgo_0 [4]byte // emitted by cgo -godefs; presumably trailing alignment padding
}
================================================
FILE: gossh/pty/ztypes_freebsd_386.go
================================================
//go:build 386 && freebsd
// +build 386,freebsd
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/386.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_freebsd_amd64.go
================================================
//go:build amd64 && freebsd
// +build amd64,freebsd
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/amd64.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Pad_cgo_0 [4]byte // emitted by cgo -godefs; presumably alignment padding before Buf
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_freebsd_arm.go
================================================
//go:build arm && freebsd
// +build arm,freebsd
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/arm.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_freebsd_arm64.go
================================================
//go:build arm64 && freebsd
// +build arm64,freebsd
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/arm64. Note that it is 0xff here,
// unlike the 0x3f generated for the other freebsd targets in this package.
const (
_C_SPECNAMELEN = 0xff
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_freebsd_ppc64.go
================================================
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/ppc64.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Pad_cgo_0 [4]byte // emitted by cgo -godefs; presumably alignment padding before Buf
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_freebsd_riscv64.go
================================================
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs types_freebsd.go
package pty
// _C_SPECNAMELEN is the generated value of the C macro SPECNAMELEN
// (max length of devicename) for freebsd/riscv64.
const (
_C_SPECNAMELEN = 0x3f
)
// fiodgnameArg is the generated Go layout of the C struct fiodgname_arg
// (see types_freebsd.go). Do not reorder or retype the fields.
type fiodgnameArg struct {
Len int32
Buf *byte
}
================================================
FILE: gossh/pty/ztypes_loong64.go
================================================
//go:build loong64
// +build loong64
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for loong64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for loong64.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_mipsx.go
================================================
//go:build (mips || mipsle || mips64 || mips64le) && linux
// +build mips mipsle mips64 mips64le
// +build linux
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for the
// linux mips family (mips, mipsle, mips64, mips64le).
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int for the same targets.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_netbsd_32bit_int.go
================================================
//go:build (386 || amd64 || arm || arm64) && netbsd
// +build 386 amd64 arm arm64
// +build netbsd
package pty
// ptmget is the Go layout of the C struct ptmget (see types_netbsd.go) for
// the netbsd targets named in this file's build constraints. Do not reorder
// or retype the fields.
type ptmget struct {
Cfd int32
Sfd int32
Cn [1024]int8
Sn [1024]int8
}
// Numeric values of the TIOCPTSNAME and TIOCGRANTPT ioctl requests that
// types_netbsd.go takes from the C headers.
var (
ioctl_TIOCPTSNAME = 0x48087448
ioctl_TIOCGRANTPT = 0x20007447
)
================================================
FILE: gossh/pty/ztypes_openbsd_32bit_int.go
================================================
//go:build (386 || amd64 || arm || arm64 || mips64) && openbsd
// +build 386 amd64 arm arm64 mips64
// +build openbsd
package pty
// ptmget is the Go layout of the C struct ptmget (see types_openbsd.go) for
// the openbsd targets named in this file's build constraints. Do not reorder
// or retype the fields.
type ptmget struct {
Cfd int32
Sfd int32
Cn [16]int8
Sn [16]int8
}
// ioctl_PTMGET is the numeric value of the PTMGET ioctl request that
// types_openbsd.go takes from the C headers.
var ioctl_PTMGET = 0x40287401
================================================
FILE: gossh/pty/ztypes_ppc.go
================================================
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for ppc.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for ppc.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_ppc64.go
================================================
//go:build ppc64
// +build ppc64
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for ppc64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for ppc64.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_ppc64le.go
================================================
//go:build ppc64le
// +build ppc64le
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for ppc64le.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for ppc64le.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_riscvx.go
================================================
//go:build riscv || riscv64
// +build riscv riscv64
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for
// riscv and riscv64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int for the same targets.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_s390x.go
================================================
//go:build s390x
// +build s390x
// Created by cgo -godefs - DO NOT EDIT
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for s390x.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int as resolved by cgo -godefs for s390x.
type _C_uint uint32
================================================
FILE: gossh/pty/ztypes_sparcx.go
================================================
//go:build sparc || sparc64
// +build sparc sparc64
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs types.go
package pty
// _C_int is the Go equivalent of C's int as resolved by cgo -godefs for
// sparc and sparc64.
type _C_int int32

// _C_uint is the Go equivalent of C's unsigned int for the same targets.
type _C_uint uint32
================================================
FILE: gossh/readme.md
================================================
### Rewrite `interface{}` as `any`
```bash
gofmt -w -r 'interface{} -> any' .
```
### MySQL init
```bash
yum -y install mysql-server
systemctl restart mysqld.service
cat >> /etc/my.cnf << "EOF"
##=========================================================================
[mysqld]
bind-address=0.0.0.0
skip-name-resolve
##=========================================================================
EOF
mysql -u root --skip-password
```
```mysql
ALTER USER 'root'@'localhost' IDENTIFIED BY 'mypwd';
CREATE USER 'root'@'%' identified by 'mypwd';
GRANT ALL ON *.* TO 'root'@'%';
ALTER USER `root`@`%` PASSWORD EXPIRE NEVER;
GRANT Grant Option ON *.* TO `root`@`%`;
flush privileges;
```
### PgSQL init
```bash
yum -y install postgresql-server
postgresql-setup --initdb
cat >> /var/lib/pgsql/data/pg_hba.conf << "EOF"
##=========================================================================
host all all 0.0.0.0/0 scram-sha-256
##=========================================================================
EOF
sed -i "s/^#*listen_addresses = .*$/listen_addresses = '*'/" /var/lib/pgsql/data/postgresql.conf
systemctl restart postgresql.service
su - postgres
psql
postgres=# \password
```
### SQLite init
```shell
go get github.com/glebarez/sqlite
go get github.com/glebarez/go-sqlite
go get modernc.org/sqlite/lib
```

### Nginx reverse proxy

```nginx
location ^~ /myapp/webssh/ {
proxy_pass http://127.0.0.1:8899;
proxy_pass_header Server;
proxy_http_version 1.1;
proxy_redirect off;
#proxy_set_header Host $http_host;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Scheme $scheme;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_connect_timeout 60s;
proxy_read_timeout 120s;
proxy_send_timeout 120s;
client_body_timeout 60s;
}
```
================================================
FILE: gossh/sftp/allocator.go
================================================
package sftp
import (
"sync"
)
// allocator hands out fixed-size byte pages and recycles them, tracking which
// pages are currently in use per request order ID. The embedded mutex guards
// both fields; every method takes it.
type allocator struct {
	sync.Mutex
	// pages ready to be handed out again
	available [][]byte
	// pages currently handed out, keyed by request order
	used map[uint32][][]byte
}

// newAllocator returns an empty allocator whose free list has room for
// SftpServerWorkerCount*2 pages before it needs to grow.
func newAllocator() *allocator {
	a := new(allocator)
	// micro optimization: give the free list an initial capacity
	a.available = make([][]byte, 0, SftpServerWorkerCount*2)
	a.used = make(map[uint32][][]byte)
	return a
}

// GetPage returns a page of maxMsgLength bytes and records it as used by
// requestOrderID. A previously released page is reused when one is available;
// otherwise a new one is allocated. The size suits both receiving new packets
// and reading the files to serve.
func (a *allocator) GetPage(requestOrderID uint32) []byte {
	a.Lock()
	defer a.Unlock()

	var page []byte

	// Pop the most recently released page, if any.
	if last := len(a.available) - 1; last >= 0 {
		page = a.available[last]
		a.available[last] = nil // drop the reference held by the free list
		a.available = a.available[:last]
	}

	// Nothing to reuse: allocate a fresh page.
	if page == nil {
		page = make([]byte, maxMsgLength)
	}

	a.used[requestOrderID] = append(a.used[requestOrderID], page)
	return page
}

// ReleasePages moves every page in use for requestOrderID back to the free
// list and forgets the request.
func (a *allocator) ReleasePages(requestOrderID uint32) {
	a.Lock()
	defer a.Unlock()

	pages := a.used[requestOrderID]
	delete(a.used, requestOrderID)
	a.available = append(a.available, pages...)
}

// Free drops all used and available pages.
// Call it once the allocator is no longer needed.
func (a *allocator) Free() {
	a.Lock()
	defer a.Unlock()

	a.used = make(map[uint32][][]byte)
	a.available = nil
}

// countUsedPages reports how many pages are in use across all requests.
func (a *allocator) countUsedPages() int {
	a.Lock()
	defer a.Unlock()

	total := 0
	for id := range a.used {
		total += len(a.used[id])
	}
	return total
}

// countAvailablePages reports how many pages sit on the free list.
func (a *allocator) countAvailablePages() int {
	a.Lock()
	defer a.Unlock()

	n := len(a.available)
	return n
}

// isRequestOrderIDUsed reports whether any page is recorded for requestOrderID.
func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool {
	a.Lock()
	defer a.Unlock()

	_, inUse := a.used[requestOrderID]
	return inUse
}
================================================
FILE: gossh/sftp/attrs.go
================================================
package sftp
// ssh_FXP_ATTRS support
// see https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5
import (
"os"
"time"
)
// Attribute flag bits for the ssh_FXP_ATTRS structure; each flag marks one
// group of optional fields as present.
const (
	sshFileXferAttrSize        = 1 << 0  // 0x00000001
	sshFileXferAttrUIDGID      = 1 << 1  // 0x00000002
	sshFileXferAttrPermissions = 1 << 2  // 0x00000004
	sshFileXferAttrACmodTime   = 1 << 3  // 0x00000008
	sshFileXferAttrExtended    = 1 << 31 // 0x80000000

	// sshFileXferAttrAll is the union of every flag above.
	sshFileXferAttrAll = sshFileXferAttrSize |
		sshFileXferAttrUIDGID |
		sshFileXferAttrPermissions |
		sshFileXferAttrACmodTime |
		sshFileXferAttrExtended
)
// fileInfo pairs a name with a *FileStat so the two together satisfy
// os.FileInfo.
type fileInfo struct {
	name string
	stat *FileStat
}

// Name returns the name the fileInfo was created with.
func (fi *fileInfo) Name() string {
	return fi.name
}

// Size returns the Size attribute of the underlying FileStat as an int64.
func (fi *fileInfo) Size() int64 {
	size := fi.stat.Size
	return int64(size)
}

// Mode returns the file mode bits of the underlying FileStat.
func (fi *fileInfo) Mode() os.FileMode {
	return fi.stat.FileMode()
}

// ModTime returns the modification time of the underlying FileStat.
func (fi *fileInfo) ModTime() time.Time {
	return fi.stat.ModTime()
}

// IsDir reports whether Mode describes a directory.
func (fi *fileInfo) IsDir() bool {
	mode := fi.Mode()
	return mode.IsDir()
}

// Sys returns the underlying *FileStat.
func (fi *fileInfo) Sys() any {
	return fi.stat
}
// FileStat carries the raw attribute values unmarshalled from a READDIR or
// *STAT reply. It is exported so callers can reach those raw values through
// os.FileInfo.Sys(), and the server side also uses it to hold the
// unmarshalled values for SetStat.
type FileStat struct {
	Size     uint64
	Mode     uint32
	Mtime    uint32
	Atime    uint32
	UID      uint32
	GID      uint32
	Extended []StatExtended
}

// ModTime converts the Mtime attribute (seconds since the Unix epoch) to a
// time.Time.
func (fs *FileStat) ModTime() time.Time {
	secs := int64(fs.Mtime)
	return time.Unix(secs, 0)
}

// AccessTime converts the Atime attribute (seconds since the Unix epoch) to a
// time.Time.
func (fs *FileStat) AccessTime() time.Time {
	secs := int64(fs.Atime)
	return time.Unix(secs, 0)
}

// FileMode converts the Mode attribute to an os.FileMode.
func (fs *FileStat) FileMode() os.FileMode {
	mode := toFileMode(fs.Mode)
	return mode
}
// StatExtended contains additional, extended information for a FileStat.
// Each entry is one type/data pair; FileStat.Extended holds a slice of them.
type StatExtended struct {
ExtType string
ExtData string
}
// fileInfoFromStat wraps stat and name in a *fileInfo and returns it as an
// os.FileInfo. The stat pointer is stored as is, not copied.
func fileInfoFromStat(stat *FileStat, name string) os.FileInfo {
	fi := new(fileInfo)
	fi.name = name
	fi.stat = stat
	return fi
}
// FileInfoUidGid extends os.FileInfo and adds callbacks for Uid and Gid retrieval,
// as an alternative to *syscall.Stat_t objects on unix systems.
//
// fileStatFromInfo checks for this interface and, when it is implemented,
// takes the UID and GID from it.
type FileInfoUidGid interface {
os.FileInfo
Uid() uint32
Gid() uint32
}
// FileInfoExtendedData extends os.FileInfo and adds a callback for extended
// data retrieval.
//
// fileStatFromInfo checks for this interface and, when it is implemented,
// copies the returned entries into FileStat.Extended.
type FileInfoExtendedData interface {
os.FileInfo
Extended() []StatExtended
}
// fileStatFromInfo builds a FileStat from fi and returns it together with the
// attribute flags describing which fields are populated.
//
// Size, permissions and times are always set, with the access time taken from
// the modification time. UID/GID are added either by the OS-specific hook or,
// taking precedence, by a FileInfoUidGid implementation. Extended data is
// copied from a FileInfoExtendedData implementation and flagged only when it
// is non-empty.
func fileStatFromInfo(fi os.FileInfo) (uint32, *FileStat) {
	// os.FileInfo exposes a single timestamp, used for both Mtime and Atime.
	modUnix := fi.ModTime().Unix()

	stat := &FileStat{
		Size:  uint64(fi.Size()),
		Mode:  fromFileMode(fi.Mode()),
		Mtime: uint32(modUnix),
		Atime: uint32(modUnix),
	}
	flags := uint32(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrACmodTime)

	// OS-specific decoding; on unix this may add sshFileXferAttrUIDGID when
	// fi.Sys() is a *syscall.Stat_t.
	fileStatFromInfoOs(fi, &flags, stat)

	// An explicit Uid/Gid implementation overrides whatever was set above.
	if owner, ok := fi.(FileInfoUidGid); ok {
		flags |= sshFileXferAttrUIDGID
		stat.UID = owner.Uid()
		stat.GID = owner.Gid()
	}

	if ext, ok := fi.(FileInfoExtendedData); ok {
		stat.Extended = ext.Extended()
		if len(stat.Extended) != 0 {
			flags |= sshFileXferAttrExtended
		}
	}

	return flags, stat
}
================================================
FILE: gossh/sftp/attrs_stubs.go
================================================
//go:build plan9 || windows || android
// +build plan9 windows android
package sftp
import (
"os"
)
// fileStatFromInfoOs is the variant built for plan9, windows and android.
// It is currently a no-op: flags and fileStat are left untouched.
func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) {
// todo
}
================================================
FILE: gossh/sftp/attrs_unix.go
================================================
//go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js || zos
// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js zos
package sftp
import (
"os"
"syscall"
)
// fileStatFromInfoOs fills in the owner of fileStat when fi.Sys() holds a
// *syscall.Stat_t, and marks the UID/GID attribute in flags. Any other
// Sys() value leaves both arguments untouched.
func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) {
	sysStat, ok := fi.Sys().(*syscall.Stat_t)
	if !ok {
		return
	}

	fileStat.UID = sysStat.Uid
	fileStat.GID = sysStat.Gid
	*flags |= sshFileXferAttrUIDGID
}
================================================
FILE: gossh/sftp/client.go
================================================
package sftp
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"os"
"path"
"sync"
"sync/atomic"
"syscall"
"time"
"gossh/crypto/ssh"
"gossh/sftp/internal/sftpfs"
)
// Sentinel errors exported by the client.
var (
// ErrInternalInconsistency indicates the packets sent and the data queued to be
// written to the file don't match up. It is an unusual error and usually is
// caused by bad behavior server side or connection issues. The error is
// limited in scope to the call where it happened, the client object is still
// OK to use as long as the connection is still open.
ErrInternalInconsistency = errors.New("internal inconsistency")
// InternalInconsistency alias for ErrInternalInconsistency.
// Both names hold the same error value.
//
// Deprecated: please use ErrInternalInconsistency
InternalInconsistency = ErrInternalInconsistency
)
// A ClientOption is a function which applies configuration to a Client.
// It returns a non-nil error when the requested setting is rejected.
type ClientOption func(*Client) error
// MaxPacketChecked returns an option that sets the maximum payload size in
// bytes. Only sizes every server should support are accepted: the option
// fails for values below 1 or above 32768.
//
// If copying a large file fails with "failed to send packet header: EOF",
// try a smaller value.
//
// The default packet size is 32768 bytes.
func MaxPacketChecked(size int) ClientOption {
	return func(c *Client) error {
		switch {
		case size < 1:
			return errors.New("size must be greater or equal to 1")
		case size > 32768:
			return errors.New("sizes larger than 32KB might not work with all servers")
		}
		c.maxPacket = size
		return nil
	}
}
// MaxPacketUnchecked returns an option that sets the maximum payload size in
// bytes without the 32768-byte upper limit that MaxPacketChecked enforces;
// only values below 1 are rejected.
//
// Use a value above 32768 only if the application always talks to the same
// server, or after sufficiently broad testing. If copying a large file fails
// with "failed to send packet header: EOF", try a smaller value.
//
// The default packet size is 32768 bytes.
func MaxPacketUnchecked(size int) ClientOption {
	return func(c *Client) error {
		if size >= 1 {
			c.maxPacket = size
			return nil
		}
		return errors.New("size must be greater or equal to 1")
	}
}
// MaxPacket is a synonym for MaxPacketChecked kept for backward
// compatibility: it sets the maximum payload size in bytes and only accepts
// sizes servers should support, i.e. <= 32768 bytes.
//
// If copying a large file fails with "failed to send packet header: EOF",
// try a smaller value.
//
// The default packet size is 32768 bytes.
func MaxPacket(size int) ClientOption {
	checked := MaxPacketChecked(size)
	return checked
}
// MaxConcurrentRequestsPerFile returns an option that caps the number of
// concurrent requests allowed for a single file. Values below 1 are rejected.
//
// The default maximum concurrent requests is 64.
func MaxConcurrentRequestsPerFile(n int) ClientOption {
	return func(c *Client) error {
		if n >= 1 {
			c.maxConcurrentRequests = n
			return nil
		}
		return errors.New("n must be greater or equal to 1")
	}
}
// UseConcurrentWrites returns an option that switches concurrent writes on or
// off for the Client.
//
// Concurrent writes need care: after an error, a write to a later offset may
// already have landed, leaving the file longer than what was successfully
// written. If `io.Copy` or `io.WriteTo` reports an error while this option is
// enabled, you may need to `Truncate` the target to avoid "holes" in the data.
func UseConcurrentWrites(value bool) ClientOption {
	apply := func(c *Client) error {
		c.useConcurrentWrites = value
		return nil
	}
	return apply
}
// UseConcurrentReads returns an option that switches concurrent reads on or
// off for the Client.
//
// Concurrent reads are generally safe and disabling them degrades
// performance, so they are enabled by default.
//
// When enabled, WriteTo uses Stat/Fstat to learn the file size and decide how
// many concurrent workers to use. Some "read once" servers delete the file
// when they receive a stat call on an open file, which makes the download
// fail; disabling concurrent reads allows downloading from such servers.
// With concurrent reads disabled, the UseFstat option is ignored.
func UseConcurrentReads(value bool) ClientOption {
	// The Client stores the inverse of the option value.
	disabled := !value
	return func(c *Client) error {
		c.disableConcurrentReads = disabled
		return nil
	}
}
// UseFstat returns an option that selects Fstat instead of Stat when
// File.WriteTo is called (usually when copying files).
//
// Some servers limit the number of open files, and calling Stat after opening
// a file makes them return an error. With this flag set, Fstat is used
// instead, which is meant to be called on an already open file handle.
//
// This has been observed with IBM Sterling SFTP servers whose
// "extractability" level is set to 1, meaning only one file can be open at
// any given time.
//
// If the server has trouble with both Stat and Fstat, you can always open the
// file and read it until the end. That is also the way to go on servers that
// delete a file once it has been read in full, as some mainframes map the
// file to a message in a queue.
func UseFstat(value bool) ClientOption {
	apply := func(c *Client) error {
		c.useFstat = value
		return nil
	}
	return apply
}
// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
// Multiple Clients can be active on a single SSH connection, and a Client
// may be called concurrently from multiple Goroutines.
//
// Client implements the github.com/kr/sftpfs.FileSystem interface.
type Client struct {
// clientConn is embedded, so its fields and methods are promoted onto Client.
clientConn
ext map[string]string // Extensions (name -> data), filled in by recvVersion.
maxPacket int // max packet size read or written.
// maxConcurrentRequests is the upper bound on concurrent requests per file.
maxConcurrentRequests int
// nextid holds the last request ID handed out; nextID increments it atomically.
nextid uint32
// write concurrency is… error prone.
// Default behavior should be to not use it.
useConcurrentWrites bool
// useFstat makes File.WriteTo query the open handle (fstat) instead of the path (stat).
useFstat bool
// disableConcurrentReads forces reads and WriteTo onto their sequential code paths.
disableConcurrentReads bool
}
// NewClient creates a new SFTP client on conn, using zero or more option
// functions.
//
// It opens a new session on conn, requests the "sftp" subsystem, and hands the
// session's stdout/stdin pipes to NewClientPipe. If any step after the session
// has been opened fails, the session is closed before the error is returned so
// that the SSH channel is not leaked.
func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
	s, err := conn.NewSession()
	if err != nil {
		return nil, err
	}

	if err := s.RequestSubsystem("sftp"); err != nil {
		// Best-effort cleanup; the setup error is the one worth reporting.
		s.Close()
		return nil, err
	}

	pw, err := s.StdinPipe()
	if err != nil {
		s.Close()
		return nil, err
	}

	pr, err := s.StdoutPipe()
	if err != nil {
		s.Close()
		return nil, err
	}

	return NewClientPipe(pr, pw, opts...)
}
// NewClientPipe builds an SFTP client on top of an arbitrary Reader and
// WriteCloser pair. This makes it possible to talk to an SFTP server over
// TCP/TLS, or through the system's ssh client program (e.g. via exec.Command).
//
// The options are applied first, then the init/version handshake is performed.
// If any of these steps fails, wr is closed and the error is returned.
// On success a background goroutine is started to receive server responses.
func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
	client := &Client{
		clientConn: clientConn{
			conn: conn{
				Reader:      rd,
				WriteCloser: wr,
			},
			inflight: make(map[uint32]chan<- result),
			closed:   make(chan struct{}),
		},

		ext: make(map[string]string),

		maxPacket:             1 << 15,
		maxConcurrentRequests: 64,
	}

	// abort releases the writer and reports a setup failure.
	abort := func(err error) (*Client, error) {
		wr.Close()
		return nil, err
	}

	for _, apply := range opts {
		if err := apply(client); err != nil {
			return abort(err)
		}
	}

	if err := client.sendInit(); err != nil {
		return abort(fmt.Errorf("error sending init packet to server: %w", err))
	}

	if err := client.recvVersion(); err != nil {
		return abort(fmt.Errorf("error receiving version packet from server: %w", err))
	}

	// Receive loop: any terminal error is broadcast to all waiting requests.
	client.clientConn.wg.Add(1)
	go func() {
		defer client.clientConn.wg.Done()

		if err := client.clientConn.recv(); err != nil {
			client.clientConn.broadcastErr(err)
		}
	}()

	return client, nil
}
// Create opens the named file with mode 0666 (before umask), creating it if
// necessary and truncating it if it already exists. On success the returned
// File can be used for I/O; the associated file descriptor has mode O_RDWR.
// For finer control over the flags/mode see client.OpenFile.
//
// Note that some SFTP servers (eg. AWS Transfer) do not support opening files
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
	const flags = os.O_RDWR | os.O_CREATE | os.O_TRUNC
	return c.open(path, toPflags(flags))
}
// sftpProtocolVersion is the SFTP protocol version spoken by this client.
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
// sendInit sends the init packet announcing sftpProtocolVersion to the server.
// See https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
func (c *Client) sendInit() error {
	initPkt := &sshFxInitPacket{
		Version: sftpProtocolVersion,
	}
	return c.clientConn.conn.sendPacket(initPkt)
}
// nextID atomically increments c.nextid and returns the new value,
// which callers use as the ID of their next request.
func (c *Client) nextID() uint32 {
	id := atomic.AddUint32(&c.nextid, 1)
	return id
}
// recvVersion reads the server's version packet, checks that it announces
// sftpProtocolVersion, and records every extension pair that follows in c.ext.
//
// An io.EOF while reading is reported as an unexpected connection close
// wrapping io.ErrUnexpectedEOF.
func (c *Client) recvVersion() error {
	typ, payload, err := c.recvPacket(0)
	switch {
	case err == io.EOF:
		return fmt.Errorf("server unexpectedly closed connection: %w", io.ErrUnexpectedEOF)
	case err != nil:
		return err
	case typ != sshFxpVersion:
		return &unexpectedPacketErr{sshFxpVersion, typ}
	}

	version, rest, err := unmarshalUint32Safe(payload)
	if err != nil {
		return err
	}
	if version != sftpProtocolVersion {
		return &unexpectedVersionErr{sftpProtocolVersion, version}
	}

	// Whatever remains after the version number is a list of extension pairs.
	for len(rest) > 0 {
		var pair extensionPair
		if pair, rest, err = unmarshalExtensionPair(rest); err != nil {
			return err
		}
		c.ext[pair.Name] = pair.Data
	}

	return nil
}
// HasExtension reports whether the server announced the named extension.
//
// When it did, the first return value is the extension data reported by the
// server (typically a version number); otherwise it is the empty string.
func (c *Client) HasExtension(name string) (string, bool) {
	if data, found := c.ext[name]; found {
		return data, true
	}
	return "", false
}
// Walk returns a new Walker rooted at root that uses this Client as its file system.
func (c *Client) Walk(root string) *sftpfs.Walker {
	walker := sftpfs.WalkFS(root, c)
	return walker
}
// ReadDir lists the directory named by p and returns its entries.
// It is ReadDirContext with a background context.
func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
	ctx := context.Background()
	return c.ReadDirContext(ctx, p)
}
// ReadDirContext lists the directory named by p and returns its entries,
// omitting "." and "..".
// The passed context can be used to cancel the operation,
// in which case all entries listed up to the cancellation are returned.
func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) {
	handle, err := c.opendir(ctx, p)
	if err != nil {
		return nil, err
	}
	// The directory handle is released on every return path; a close error is not reported.
	defer c.close(handle)

	var entries []os.FileInfo

	for {
		reqID := c.nextID()
		typ, data, sendErr := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{
			ID:     reqID,
			Handle: handle,
		})
		if sendErr != nil {
			err = sendErr
			break
		}

		if typ == sshFxpStatus {
			// A status reply ends the listing (io.EOF marks normal completion).
			err = normaliseError(unmarshalStatus(reqID, data))
			break
		}
		if typ != sshFxpName {
			return nil, unimplementedPacketErr(typ)
		}

		respID, rest := unmarshalUint32(data)
		if respID != reqID {
			return nil, &unexpectedIDErr{reqID, respID}
		}

		count, rest := unmarshalUint32(rest)
		for remaining := count; remaining > 0; remaining-- {
			var name string
			name, rest = unmarshalString(rest)
			_, rest = unmarshalString(rest) // the longname is not used

			var attr *FileStat
			attr, rest, err = unmarshalAttrs(rest)
			if err != nil {
				return nil, err
			}

			if name == "." || name == ".." {
				continue
			}
			entries = append(entries, fileInfoFromStat(attr, path.Base(name)))
		}
	}

	if err == io.EOF {
		return entries, nil
	}
	return entries, err
}
// opendir asks the server to open the directory at path and returns the
// handle from the reply. A status reply is converted into an error.
func (c *Client) opendir(ctx context.Context, path string) (string, error) {
	reqID := c.nextID()
	req := &sshFxpOpendirPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(ctx, nil, req)
	if err != nil {
		return "", err
	}

	if typ == sshFxpStatus {
		return "", normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpHandle {
		return "", unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return "", &unexpectedIDErr{reqID, respID}
	}
	handle, _ := unmarshalString(rest)
	return handle, nil
}
// Stat returns a FileInfo describing the file at path 'p'.
// If 'p' is a symbolic link, the returned FileInfo describes the referent file.
func (c *Client) Stat(p string) (os.FileInfo, error) {
	attrs, err := c.stat(p)
	if err == nil {
		return fileInfoFromStat(attrs, path.Base(p)), nil
	}
	return nil, err
}
// Lstat returns a FileInfo describing the file at path 'p'.
// If 'p' is a symbolic link, the returned FileInfo describes the link itself.
func (c *Client) Lstat(p string) (os.FileInfo, error) {
	reqID := c.nextID()
	req := &sshFxpLstatPacket{
		ID:   reqID,
		Path: p,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return nil, err
	}

	if typ == sshFxpStatus {
		return nil, normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpAttrs {
		return nil, unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return nil, &unexpectedIDErr{reqID, respID}
	}
	attr, _, err := unmarshalAttrs(rest)
	if err != nil {
		// Never hand back a FileInfo together with a non-nil error.
		return nil, err
	}
	return fileInfoFromStat(attr, path.Base(p)), nil
}
// ReadLink returns the target of the symbolic link at p.
func (c *Client) ReadLink(p string) (string, error) {
	reqID := c.nextID()
	req := &sshFxpReadlinkPacket{
		ID:   reqID,
		Path: p,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return "", err
	}

	if typ == sshFxpStatus {
		return "", normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpName {
		return "", unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return "", &unexpectedIDErr{reqID, respID}
	}
	count, rest := unmarshalUint32(rest)
	if count != 1 {
		return "", unexpectedCount(1, count)
	}
	// Only the name is of interest; the dummy attributes that follow are ignored.
	target, _ := unmarshalString(rest)
	return target, nil
}
// Link creates a hard link at 'newname', pointing at the same inode as 'oldname'.
func (c *Client) Link(oldname, newname string) error {
	reqID := c.nextID()
	req := &sshFxpHardlinkPacket{
		ID:      reqID,
		Oldpath: oldname,
		Newpath: newname,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'.
func (c *Client) Symlink(oldname, newname string) error {
	reqID := c.nextID()
	req := &sshFxpSymlinkPacket{
		ID:         reqID,
		Linkpath:   newname,
		Targetpath: oldname,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// fsetstat sends an fsetstat request for the open handle, carrying the given
// attribute flags and attribute payload, and returns the server's status as an error.
func (c *Client) fsetstat(handle string, flags uint32, attrs any) error {
	reqID := c.nextID()
	req := &sshFxpFsetstatPacket{
		ID:     reqID,
		Handle: handle,
		Flags:  flags,
		Attrs:  attrs,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// setstat is a convenience wrapper that sends a setstat request for path,
// carrying the given attribute flags and attribute payload, and returns the
// server's status as an error.
func (c *Client) setstat(path string, flags uint32, attrs any) error {
	reqID := c.nextID()
	req := &sshFxpSetstatPacket{
		ID:    reqID,
		Path:  path,
		Flags: flags,
		Attrs: attrs,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// Chtimes sets the access and modification times of the named file.
// Both times are sent as 32-bit Unix timestamps (seconds).
func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error {
	type times struct {
		Atime uint32
		Mtime uint32
	}
	stamps := times{
		Atime: uint32(atime.Unix()),
		Mtime: uint32(mtime.Unix()),
	}
	return c.setstat(path, sshFileXferAttrACmodTime, stamps)
}
// Chown sets the user and group owners of the named file.
// The ids are sent as 32-bit unsigned values.
func (c *Client) Chown(path string, uid, gid int) error {
	type owner struct {
		UID uint32
		GID uint32
	}
	ids := owner{
		UID: uint32(uid),
		GID: uint32(gid),
	}
	return c.setstat(path, sshFileXferAttrUIDGID, ids)
}
// Chmod sets the permissions of the named file.
//
// No umask is applied, because even retrieving the umask is not possible
// in a portable way without causing a race condition. Callers should mask
// off umask bits themselves, if desired.
func (c *Client) Chmod(path string, mode os.FileMode) error {
	perm := toChmodPerm(mode)
	return c.setstat(path, sshFileXferAttrPermissions, perm)
}
// Truncate sets the size of the named file. A size smaller than the current
// one can safely be assumed to truncate the file, but the SFTP protocol does
// not specify what a server should do when the requested size is larger than
// the current size.
func (c *Client) Truncate(path string, size int64) error {
	newSize := uint64(size)
	return c.setstat(path, sshFileXferAttrSize, newSize)
}
// Open opens the named file read-only. On success the returned File can be
// used for reading; the associated file descriptor has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
	pflags := toPflags(os.O_RDONLY)
	return c.open(path, pflags)
}
// OpenFile is the general-purpose open call; most callers want Open or Create
// instead. It opens the named file with the given flag (O_RDONLY etc.).
// On success the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
	pflags := toPflags(f)
	return c.open(path, pflags)
}
// open sends an open request for path with the given protocol flags and wraps
// the handle from the reply in a File. A status reply is converted into an error.
func (c *Client) open(path string, pflags uint32) (*File, error) {
	reqID := c.nextID()
	req := &sshFxpOpenPacket{
		ID:     reqID,
		Path:   path,
		Pflags: pflags,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return nil, err
	}

	if typ == sshFxpStatus {
		return nil, normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpHandle {
		return nil, unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return nil, &unexpectedIDErr{reqID, respID}
	}
	handle, _ := unmarshalString(rest)
	return &File{c: c, path: path, handle: handle}, nil
}
// close closes a handle previously returned in the response
// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid
// immediately after this request has been sent.
func (c *Client) close(handle string) error {
	reqID := c.nextID()
	req := &sshFxpClosePacket{
		ID:     reqID,
		Handle: handle,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// stat sends a stat request for path and returns the attributes from the
// reply. A status reply is converted into an error.
func (c *Client) stat(path string) (*FileStat, error) {
	reqID := c.nextID()
	req := &sshFxpStatPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return nil, err
	}

	if typ == sshFxpStatus {
		return nil, normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpAttrs {
		return nil, unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return nil, &unexpectedIDErr{reqID, respID}
	}
	attr, _, err := unmarshalAttrs(rest)
	return attr, err
}
// fstat sends an fstat request for the open handle and returns the attributes
// from the reply. A status reply is converted into an error.
func (c *Client) fstat(handle string) (*FileStat, error) {
	reqID := c.nextID()
	req := &sshFxpFstatPacket{
		ID:     reqID,
		Handle: handle,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return nil, err
	}

	if typ == sshFxpStatus {
		return nil, normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpAttrs {
		return nil, unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return nil, &unexpectedIDErr{reqID, respID}
	}
	attr, _, err := unmarshalAttrs(rest)
	return attr, err
}
// StatVFS retrieves VFS statistics from a remote host.
//
// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature
// from http://www.opensource.apple.com/source/OpenSSH/OpenSSH-175/openssh/PROTOCOL?txt.
func (c *Client) StatVFS(path string) (*StatVFS, error) {
	reqID := c.nextID()
	req := &sshFxpStatvfsPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return nil, err
	}

	// The request failed on the server side.
	if typ == sshFxpStatus {
		return nil, normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpExtendedReply {
		return nil, unimplementedPacketErr(typ)
	}

	// The reply payload is decoded as a big-endian StatVFS structure.
	var response StatVFS
	if err := binary.Read(bytes.NewReader(data), binary.BigEndian, &response); err != nil {
		return nil, errors.New("can not parse reply")
	}
	return &response, nil
}
// Join combines any number of path elements into one slash-separated path,
// inserting separators where needed. The result is Cleaned; in particular,
// all empty strings are ignored.
func (c *Client) Join(elem ...string) string {
	joined := path.Join(elem...)
	return joined
}
// Remove removes the specified file or directory. An error will be returned if no
// file or directory with the specified path exists, or if the specified directory
// is not empty.
//
// The path is first removed as a file; for the failures that servers are known
// to return when the path is a directory, the removal is retried with
// RemoveDirectory.
func (c *Client) Remove(path string) error {
	err := c.removeFile(path)

	// Some servers (osx, for instance) answer EPERM rather than ENODIR, and
	// serv-u answers ssh_FX_FILE_IS_A_DIRECTORY.
	if statusErr, ok := err.(*StatusError); ok {
		if statusErr.Code == sshFxFailure || statusErr.Code == sshFxFileIsADirectory {
			return c.RemoveDirectory(path)
		}
	}

	// EPERM is converted to os.ErrPermission, so it is not a StatusError.
	if os.IsPermission(err) {
		return c.RemoveDirectory(path)
	}
	return err
}
// removeFile sends a remove request for path and returns the server's status as an error.
func (c *Client) removeFile(path string) error {
	reqID := c.nextID()
	req := &sshFxpRemovePacket{
		ID:       reqID,
		Filename: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// RemoveDirectory removes the directory at path.
func (c *Client) RemoveDirectory(path string) error {
	reqID := c.nextID()
	req := &sshFxpRmdirPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// Rename renames oldname to newname.
func (c *Client) Rename(oldname, newname string) error {
	reqID := c.nextID()
	req := &sshFxpRenamePacket{
		ID:      reqID,
		Oldpath: oldname,
		Newpath: newname,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// PosixRename renames oldname to newname using the posix-rename@openssh.com
// extension, which replaces newname if it already exists.
func (c *Client) PosixRename(oldname, newname string) error {
	reqID := c.nextID()
	req := &sshFxpPosixRenamePacket{
		ID:      reqID,
		Oldpath: oldname,
		Newpath: newname,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// RealPath asks the server to canonicalize the given path name into an absolute path.
//
// This is useful for converting path names containing ".." components,
// or relative pathnames without a leading slash, into absolute paths.
func (c *Client) RealPath(path string) (string, error) {
	reqID := c.nextID()
	req := &sshFxpRealpathPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return "", err
	}

	if typ == sshFxpStatus {
		return "", normaliseError(unmarshalStatus(reqID, data))
	}
	if typ != sshFxpName {
		return "", unimplementedPacketErr(typ)
	}

	respID, rest := unmarshalUint32(data)
	if respID != reqID {
		return "", &unexpectedIDErr{reqID, respID}
	}
	count, rest := unmarshalUint32(rest)
	if count != 1 {
		return "", unexpectedCount(1, count)
	}
	// Only the name is of interest; the attributes that follow are ignored.
	resolved, _ := unmarshalString(rest)
	return resolved, nil
}
// Getwd returns the current working directory of the server, obtained by
// resolving "." with RealPath. Operations involving relative paths will be
// based at this location.
func (c *Client) Getwd() (string, error) {
	const currentDir = "."
	return c.RealPath(currentDir)
}
// Mkdir creates the specified directory. An error will be returned if a file or
// directory with the specified path already exists, or if the directory's
// parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error {
	reqID := c.nextID()
	req := &sshFxpMkdirPacket{
		ID:   reqID,
		Path: path,
	}
	typ, data, err := c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return unimplementedPacketErr(typ)
	}
	return normaliseError(unmarshalStatus(reqID, data))
}
// MkdirAll creates a directory named path, along with any necessary parents,
// and returns nil, or else returns an error.
// If path is already a directory, MkdirAll does nothing and returns nil.
// If path contains a regular file, an error is returned.
func (c *Client) MkdirAll(path string) error {
	// Most of this code mimics https://golang.org/src/os/path.go?s=514:561#L13

	// Fast path: when Stat succeeds we already know the answer.
	if info, err := c.Stat(path); err == nil {
		if !info.IsDir() {
			return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR}
		}
		return nil
	}

	// Slow path: make sure the parent exists, then create path itself.
	// end is the end of the last element, ignoring trailing separators.
	end := len(path)
	for end > 0 && path[end-1] == '/' {
		end--
	}

	// start is where the last element begins.
	start := end
	for start > 0 && path[start-1] != '/' {
		start--
	}

	if start > 1 {
		// Everything before the separator preceding the last element is the parent.
		if err := c.MkdirAll(path[:start-1]); err != nil {
			return err
		}
	}

	// The parent exists now; Mkdir's result decides the outcome.
	mkErr := c.Mkdir(path)
	if mkErr == nil {
		return nil
	}

	// Arguments like "foo/." make Mkdir fail even though the directory exists,
	// so double-check before reporting the failure.
	if info, lstatErr := c.Lstat(path); lstatErr == nil && info.IsDir() {
		return nil
	}
	return mkErr
}
// RemoveAll removes path and, if it is a directory, everything it contains.
// An error will be returned if nothing exists at the specified path.
//
// The path is inspected with Lstat rather than Stat, so a symbolic link is
// removed as a link and is never followed: the contents of a directory that a
// symlink points to are left untouched.
//
// Removal stops at the first error, which is returned; entries removed before
// that point stay removed.
func (c *Client) RemoveAll(path string) error {
	// Lstat describes the link itself, so a symlink to a directory reports
	// IsDir() == false and falls through to a plain Remove below.
	fi, err := c.Lstat(path)
	if err != nil {
		return err
	}

	if fi.IsDir() {
		// Empty the directory first; it can only be removed once it is empty.
		files, err := c.ReadDir(path)
		if err != nil {
			return err
		}

		for _, file := range files {
			child := path + "/" + file.Name()

			if file.IsDir() {
				// Recursively delete subdirectories.
				err = c.RemoveAll(child)
			} else {
				// Delete individual files.
				err = c.Remove(child)
			}
			if err != nil {
				return err
			}
		}
	}

	return c.Remove(path)
}
// File represents a remote file.
type File struct {
c *Client // client used to send every request for this file
path string // path as presented to Open or Create; returned by Name
// mu is held for reading by ReadAt and Stat, and for writing by
// Close, Read, Write and WriteTo.
mu sync.RWMutex
handle string // server-side handle; set to "" by Close
offset int64 // current offset within remote file
}
// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any. Closing an already closed File returns os.ErrClosed.
func (f *File) Close() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	handle := f.handle
	if handle == "" {
		return os.ErrClosed
	}

	// When `openssh-portable/sftp-server.c` performs `handle_close`, it marks the
	// handle as unused unconditionally, so the local copy is invalidated
	// unconditionally as well, before the request is even sent. This guarantees
	// that no use-after-close request can be issued once Close has been called.
	f.handle = ""

	return f.c.close(handle)
}
// Name returns the file's name exactly as it was presented to Open or Create.
func (f *File) Name() string {
	name := f.path
	return name
}
// Read reads up to len(b) bytes from the File at its current offset and
// advances the offset by the number of bytes read. It returns that number
// and an error, if any. Read follows io.Reader semantics, so when it hits an
// error or EOF after successfully reading n > 0 bytes, it still returns n.
//
// To maximise throughput for transferring the entire file (especially
// over high latency links) it is recommended to use WriteTo rather
// than calling Read multiple times. io.Copy will do this
// automatically.
func (f *File) Read(b []byte) (n int, err error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	n, err = f.readAt(b, f.offset)
	f.offset += int64(n)

	return n, err
}
// readChunkAt tries to fill the whole of b from the file, starting at off.
// It keeps issuing read requests for the part of b that is still empty until
// the buffer is full or a request ends in an error or a status reply.
// It returns the number of bytes placed into b.
func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
	for n < len(b) {
		reqID := f.c.nextID()
		req := &sshFxpReadPacket{
			ID:     reqID,
			Handle: f.handle,
			Offset: uint64(off) + uint64(n),
			Len:    uint32(len(b) - n),
		}
		typ, data, sendErr := f.c.sendPacket(context.Background(), ch, req)
		if sendErr != nil {
			return n, sendErr
		}

		if typ == sshFxpStatus {
			return n, normaliseError(unmarshalStatus(reqID, data))
		}
		if typ != sshFxpData {
			return n, unimplementedPacketErr(typ)
		}

		respID, rest := unmarshalUint32(data)
		if reqID != respID {
			return n, &unexpectedIDErr{reqID, respID}
		}

		length, rest := unmarshalUint32(rest)
		n += copy(b[n:], rest[:length])
	}

	return n, nil
}
// readAtSequential fills b from the file starting at off, one chunk of at
// most maxPacket bytes at a time, with no concurrency. It returns the number
// of bytes read, together with the first error encountered, if any.
func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
	for read < len(b) {
		chunk := b[read:]
		if limit := f.c.maxPacket; len(chunk) > limit {
			chunk = chunk[:limit]
		}

		n, chunkErr := f.readChunkAt(nil, chunk, off+int64(read))
		switch {
		case n < 0:
			panic("sftp.File: returned negative count from readChunkAt")
		case n > 0:
			read += n
		}

		if chunkErr != nil {
			return read, chunkErr
		}
	}

	return read, nil
}
// ReadAt reads up to len(b) bytes from the File starting at offset `off`.
// It returns the number of bytes read and an error, if any. ReadAt follows
// io.ReaderAt semantics, so the File's own offset is not altered by the read.
func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	n, err = f.readAt(b, off)
	return n, err
}
// readAt reads into b from the file starting at offset off and returns the
// number of bytes read and an error, if any. It does not modify f.offset.
//
// Buffers no larger than maxPacket are read directly; larger buffers are read
// sequentially when concurrent reads are disabled, and otherwise split into
// maxPacket-sized requests handled by a bounded set of worker goroutines.
//
// readAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
func (f *File) readAt(b []byte, off int64) (int, error) {
if f.handle == "" {
return 0, os.ErrClosed
}
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
return f.readChunkAt(nil, b, off)
}
if f.c.disableConcurrentReads {
return f.readAtSequential(b, off)
}
// Split the read into multiple maxPacket-sized concurrent reads bounded by maxConcurrentRequests.
// This allows reads with a suitably large buffer to transfer data at a much faster rate
// by overlapping round trip times.
// cancel is closed once the first error is seen, to stop handing out more work.
cancel := make(chan struct{})
// One worker per chunk (rounded up), capped at maxConcurrentRequests.
concurrency := len(b)/f.c.maxPacket + 1
if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
concurrency = f.c.maxConcurrentRequests
}
resPool := newResChanPool(concurrency)
// work describes one dispatched read request and where its data belongs.
type work struct {
id uint32 // request ID, checked against the ID in the response
res chan result // channel on which the response is received
b []byte // destination sub-slice of the caller's buffer
off int64 // file offset this chunk was requested from
}
workCh := make(chan work)
// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
go func() {
defer close(workCh)
b := b
offset := off
chunkSize := f.c.maxPacket
for len(b) > 0 {
rb := b
if len(rb) > chunkSize {
rb = rb[:chunkSize]
}
id := f.c.nextID()
res := resPool.Get()
// The request is dispatched before the work item is handed to a worker;
// the worker only waits for the response on res.
f.c.dispatchRequest(res, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(offset),
Len: uint32(chunkSize),
})
select {
case workCh <- work{id, res, rb, offset}:
case <-cancel:
return
}
offset += int64(len(rb))
b = b[len(rb):]
}
}()
// rErr pairs an error with the file offset at which it occurred.
type rErr struct {
off int64
err error
}
errCh := make(chan rErr)
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset.
go func() {
defer wg.Done()
for packet := range workCh {
var n int
s := <-packet.res
resPool.Put(packet.res)
err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(packet.id, s.data))
case sshFxpData:
sid, data := unmarshalUint32(s.data)
if packet.id != sid {
err = &unexpectedIDErr{packet.id, sid}
} else {
l, data := unmarshalUint32(data)
n = copy(packet.b, data[:l])
// For normal disk files, it is guaranteed that this will read
// the specified number of bytes, or up to end of file.
// This implies, if we have a short read, that means EOF.
if n < len(packet.b) {
err = io.EOF
}
}
default:
err = unimplementedPacketErr(s.typ)
}
}
if err != nil {
// return the offset as the start + how much we read before the error.
errCh <- rErr{packet.off + int64(n), err}
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
}
}
}()
}
// Wait for long tail, before closing results.
go func() {
wg.Wait()
close(errCh)
}()
// Reduce: collect all the results into a relevant return: the earliest offset to return an error.
firstErr := rErr{math.MaxInt64, nil}
for rErr := range errCh {
if rErr.off <= firstErr.off {
firstErr = rErr
}
// Close cancel exactly once: the receive succeeds only if it is already closed.
select {
case <-cancel:
default:
// stop any more work from being distributed. (Just in case.)
close(cancel)
}
}
if firstErr.err != nil {
// firstErr.off is a chunk offset plus the bytes copied for that chunk,
// so it is never below our starting offset and the count is non-negative.
return int(firstErr.off - off), firstErr.err
}
// As per spec for io.ReaderAt, we return nil error if and only if we read everything.
return len(b), nil
}
// writeToSequential implements WriteTo without any parallelism: it reads one
// maxPacket-sized chunk at a time from the current offset and writes it to w,
// advancing f.offset by the bytes read. Reaching io.EOF ends the copy with a
// nil error; any other read error, or any write error, is returned.
func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
	buf := make([]byte, f.c.maxPacket)
	resCh := make(chan result, 1) // reused for every request

	for {
		n, readErr := f.readChunkAt(resCh, buf, f.offset)

		switch {
		case n < 0:
			panic("sftp.File: returned negative count from readChunkAt")
		case n > 0:
			f.offset += int64(n)

			m, writeErr := w.Write(buf[:n])
			written += int64(m)
			if writeErr != nil {
				return written, writeErr
			}
		}

		if readErr == nil {
			continue
		}
		if readErr == io.EOF {
			return written, nil // EOF is the normal way for the copy to end.
		}
		return written, readErr
	}
}
// WriteTo writes the file to the given Writer, starting at the File's current offset.
// The return value is the number of bytes written.
// Any error encountered during the write is also returned.
//
// When concurrent reads are disabled, the file is not larger than maxPacket,
// or the file is not a regular file, the copy is done sequentially.
// Otherwise read requests are pipelined across several workers and their
// results are written to w strictly in offset order.
//
// This method is preferred over calling Read multiple times
// to maximise throughput for transferring the entire file,
// especially over high latency links.
func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
// For concurrency, we want to guess how many concurrent workers we should use.
var fileStat *FileStat
if f.c.useFstat {
fileStat, err = f.c.fstat(f.handle)
} else {
fileStat, err = f.c.stat(f.path)
}
if err != nil {
return 0, err
}
fileSize := fileStat.Size
if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) {
// only regular files are guaranteed to return (full read) xor (partial read, next error)
return f.writeToSequential(w)
}
concurrency64 := fileSize/uint64(f.c.maxPacket) + 1 // a bad guess, but better than no guess
if concurrency64 > uint64(f.c.maxConcurrentRequests) || concurrency64 < 1 {
concurrency64 = uint64(f.c.maxConcurrentRequests)
}
// Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow.
concurrency := int(concurrency64)
chunkSize := f.c.maxPacket
pool := newBufPool(concurrency, chunkSize)
resPool := newResChanPool(concurrency)
cancel := make(chan struct{})
var wg sync.WaitGroup
defer func() {
// Once the writing Reduce phase has ended, all the feed work needs to unconditionally stop.
close(cancel)
// We want to wait until all outstanding goroutines with an `f` or `f.c` reference have completed.
// Just to be sure we don’t orphan any goroutines or leave any hanging references.
wg.Wait()
}()
// writeWork is the outcome of one read, passed to the Reduce loop below.
type writeWork struct {
b []byte // data read for this chunk; may be empty
off int64 // file offset the chunk was requested from
err error // error for this chunk, if any
next chan writeWork // channel on which the following chunk will arrive
}
writeCh := make(chan writeWork)
// readWork describes one dispatched read request handed to a worker.
type readWork struct {
id uint32 // request ID, checked against the ID in the response
res chan result // channel on which the response is received
off int64 // file offset the chunk was requested from
cur, next chan writeWork // cur receives this chunk's result; next is for the following chunk
}
readCh := make(chan readWork)
// Slice: hand out chunks of work on demand, with a `cur` and `next` channel built-in for sequencing.
go func() {
defer close(readCh)
off := f.offset
cur := writeCh
for {
id := f.c.nextID()
res := resPool.Get()
next := make(chan writeWork)
readWork := readWork{
id: id,
res: res,
off: off,
cur: cur,
next: next,
}
// The request is dispatched before the work item is handed to a worker;
// the worker only waits for the response on res.
f.c.dispatchRequest(res, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(off),
Len: uint32(chunkSize),
})
select {
case readCh <- readWork:
case <-cancel:
return
}
off += int64(chunkSize)
// The next chunk reports on the channel this chunk advertised as `next`.
cur = next
}
}()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets readWork, and does the Read into a buffer at the given offset.
go func() {
defer wg.Done()
for readWork := range readCh {
var b []byte
var n int
s := <-readWork.res
resPool.Put(readWork.res)
err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(readWork.id, s.data))
case sshFxpData:
sid, data := unmarshalUint32(s.data)
if readWork.id != sid {
err = &unexpectedIDErr{readWork.id, sid}
} else {
l, data := unmarshalUint32(data)
// Copy the payload into a pooled buffer trimmed to the bytes actually copied.
b = pool.Get()[:l]
n = copy(b, data[:l])
b = b[:n]
}
default:
err = unimplementedPacketErr(s.typ)
}
}
writeWork := writeWork{
b: b,
off: readWork.off,
err: err,
next: readWork.next,
}
// Deliver the result to the Reduce loop, unless the transfer has been cancelled.
select {
case readWork.cur <- writeWork:
case <-cancel:
}
// DO NOT return.
// We want to ensure that readCh is drained before wg.Wait returns.
}
}()
}
// Reduce: serialize the results from the reads into sequential writes.
cur := writeCh
for {
packet, ok := <-cur
if !ok {
return written, errors.New("sftp.File.WriteTo: unexpectedly closed channel")
}
// Because writes are serialized, this will always be the last successfully read byte.
f.offset = packet.off + int64(len(packet.b))
if len(packet.b) > 0 {
n, err := w.Write(packet.b)
written += int64(n)
if err != nil {
return written, err
}
}
if packet.err != nil {
// io.EOF marks the normal end of the file and is not reported to the caller.
if packet.err == io.EOF {
return written, nil
}
return written, packet.err
}
pool.Put(packet.b)
// Follow the chain to the channel carrying the next chunk in offset order.
cur = packet.next
}
}
// Stat returns the FileInfo structure describing the file.
// It returns os.ErrClosed when the File no longer has a handle;
// otherwise any error comes from the underlying stat request.
func (f *File) Stat() (os.FileInfo, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.handle != "" {
		return f.stat()
	}
	return nil, os.ErrClosed
}
// stat issues an fstat request for the file handle and converts the reply
// into an os.FileInfo named after the base element of f.path.
// It does not take f.mu itself.
func (f *File) stat() (os.FileInfo, error) {
	attrs, err := f.c.fstat(f.handle)
	if err == nil {
		name := path.Base(f.path)
		return fileInfoFromStat(attrs, name), nil
	}
	return nil, err
}
// Write writes len(b) bytes to the File at the current offset and advances
// the offset by the number of bytes written. It returns that count and an
// error, if any; the error is non-nil whenever the count is less than len(b).
// It returns os.ErrClosed when the File no longer has a handle.
//
// To maximise throughput when transferring an entire file (especially over
// high-latency links), prefer ReadFrom over repeated calls to Write.
// io.Copy will do this automatically.
func (f *File) Write(b []byte) (int, error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.handle == "" {
		return 0, os.ErrClosed
	}

	written, err := f.writeAt(b, f.offset)
	f.offset += int64(written)
	return written, err
}
// writeChunkAt sends b in a single sshFxpWritePacket targeting offset off and
// waits for the reply on ch. It returns len(b) when the reply is a status
// that normalises to nil, and 0 with an error otherwise.
func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
	req := &sshFxpWritePacket{
		ID:     f.c.nextID(),
		Handle: f.handle,
		Offset: uint64(off),
		Length: uint32(len(b)),
		Data:   b,
	}

	typ, data, err := f.c.sendPacket(context.Background(), ch, req)
	if err != nil {
		return 0, err
	}
	if typ != sshFxpStatus {
		// Only a status reply is expected in response to a write.
		return 0, unimplementedPacketErr(typ)
	}

	id, _ := unmarshalUint32(data)
	if statusErr := normaliseError(unmarshalStatus(id, data)); statusErr != nil {
		return 0, statusErr
	}
	return len(b), nil
}
// writeAtConcurrent implements WriterAt, but works concurrently rather than sequentially.
//
// It is structured as three stages:
//   - a single "slice" goroutine that cuts b into chunks of at most f.c.maxPacket bytes
//     and dispatches one write request per chunk,
//   - `concurrency` worker goroutines that wait for the reply to each request,
//   - a reduce loop in the calling goroutine that keeps the error with the lowest offset.
//
// On error it returns the distance from off to the earliest failing chunk together with
// that chunk's error; otherwise it returns len(b) and nil.
func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// Split the write into multiple maxPacket sized concurrent writes
// bounded by maxConcurrentRequests. This allows writes with a suitably
// large buffer to transfer data at a much faster rate due to
// overlapping round trip times.
// cancel is closed by the reduce loop on the first error to stop further dispatching.
cancel := make(chan struct{})
// work describes one dispatched request: its id, the channel its reply arrives on,
// and the file offset the chunk was written at.
type work struct {
id uint32
res chan result
off int64
}
workCh := make(chan work)
// One worker per chunk, capped at (or defaulting to) the client's maxConcurrentRequests.
concurrency := len(b)/f.c.maxPacket + 1
if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
concurrency = f.c.maxConcurrentRequests
}
pool := newResChanPool(concurrency)
// Slice: cut up the write into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
go func() {
// Closing workCh is what lets the workers' range loops terminate.
defer close(workCh)
var read int
chunkSize := f.c.maxPacket
for read < len(b) {
wb := b[read:]
if len(wb) > chunkSize {
wb = wb[:chunkSize]
}
id := f.c.nextID()
res := pool.Get()
// Shadows the outer off with this chunk's absolute offset.
off := off + int64(read)
f.c.dispatchRequest(res, &sshFxpWritePacket{
ID: id,
Handle: f.handle,
Offset: uint64(off),
Length: uint32(len(wb)),
Data: wb,
})
// The request has already been dispatched at this point; a worker must still
// be handed the work item unless the reduce loop has cancelled.
select {
case workCh <- work{id, res, off}:
case <-cancel:
return
}
read += len(wb)
}
}()
// wErr pairs an error with the offset of the chunk that produced it.
type wErr struct {
off int64
err error
}
errCh := make(chan wErr)
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets work, and does the Write from each buffer to its respective offset.
go func() {
defer wg.Done()
for work := range workCh {
// Wait for the reply, then hand the result channel back to the pool.
s := <-work.res
pool.Put(work.res)
err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(work.id, s.data))
default:
err = unimplementedPacketErr(s.typ)
}
}
if err != nil {
// Keep ranging after reporting, so that workCh is fully drained.
errCh <- wErr{work.off, err}
}
}
}()
}
// Wait for long tail, before closing results.
go func() {
wg.Wait()
close(errCh)
}()
// Reduce: collect all the results into a relevant return: the earliest offset to return an error.
firstErr := wErr{math.MaxInt64, nil}
for wErr := range errCh {
if wErr.off <= firstErr.off {
firstErr = wErr
}
// Close cancel exactly once: the receive case succeeds only after it has been closed.
select {
case <-cancel:
default:
// stop any more work from being distributed. (Just in case.)
close(cancel)
}
}
if firstErr.err != nil {
// firstErr.err != nil if and only if firstErr.off >= our starting offset.
return int(firstErr.off - off), firstErr.err
}
return len(b), nil
}
// WriteAt writes up to len(b) bytes to the File starting at offset off.
// It returns the number of bytes written and an error, if any.
// WriteAt follows io.WriterAt semantics: the File's own offset is not altered.
// It returns os.ErrClosed when the File no longer has a handle.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.handle != "" {
		return f.writeAt(b, off)
	}
	return 0, os.ErrClosed
}
// writeAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
//
// A buffer that fits in one packet is sent as a single chunk. Larger buffers
// go through writeAtConcurrent when the client has concurrent writes enabled,
// and are otherwise sent chunk by chunk, stopping at the first error.
func (f *File) writeAt(b []byte, off int64) (written int, err error) {
	maxPacket := f.c.maxPacket

	switch {
	case len(b) <= maxPacket:
		// We can do this in one write.
		return f.writeChunkAt(nil, b, off)
	case f.c.useConcurrentWrites:
		return f.writeAtConcurrent(b, off)
	}

	// One buffered channel is reused for every sequential request.
	resCh := make(chan result, 1)
	for written < len(b) {
		chunk := b[written:]
		if len(chunk) > maxPacket {
			chunk = chunk[:maxPacket]
		}

		n, chunkErr := f.writeChunkAt(resCh, chunk, off+int64(written))
		if n > 0 {
			written += n
		}
		if chunkErr != nil {
			return written, chunkErr
		}
	}
	return len(b), nil
}
// ReadFromWithConcurrency implements ReaderFrom,
// but issues up to the given number of requests at the same time.
//
// A concurrency of less than one defaults to the Client's max concurrency,
// and any larger value is capped by the Client's max concurrency.
//
// The File's write lock is held for the duration of the call.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	read, err = f.readFromWithConcurrency(r, concurrency)
	return read, err
}
// readFromWithConcurrency reads r until it returns an error, writing each chunk read
// to the file at consecutive offsets starting from f.offset, with several write
// requests in flight at once. Both callers visible in this file hold f.mu while calling it.
//
// It returns the number of bytes read from r, and the error belonging to the lowest
// offset at which either a write or the reader failed (io.EOF from r is not an error).
// It returns os.ErrClosed when the File no longer has a handle.
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
if f.handle == "" {
return 0, os.ErrClosed
}
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
// cancel is closed by the reduce loop on the first error to stop further dispatching.
cancel := make(chan struct{})
// work describes one dispatched request: its id, the channel its reply arrives on,
// and the file offset the chunk was written at.
type work struct {
id uint32
res chan result
off int64
}
workCh := make(chan work)
// rwErr pairs an error (from reading r or from a write reply) with a file offset.
type rwErr struct {
off int64
err error
}
errCh := make(chan rwErr)
// Out-of-range concurrency values fall back to the client's maximum.
if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
concurrency = f.c.maxConcurrentRequests
}
pool := newResChanPool(concurrency)
// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
go func() {
// Closing workCh is what lets the workers' range loops terminate.
defer close(workCh)
// A single buffer is reused for every Read from r.
b := make([]byte, f.c.maxPacket)
off := f.offset
for {
n, err := r.Read(b)
if n > 0 {
// read is the named result of the enclosing function.
read += int64(n)
id := f.c.nextID()
res := pool.Get()
f.c.dispatchRequest(res, &sshFxpWritePacket{
ID: id,
Handle: f.handle,
Offset: uint64(off),
Length: uint32(n),
Data: b[:n],
})
select {
case workCh <- work{id, res, off}:
case <-cancel:
return
}
off += int64(n)
}
if err != nil {
// io.EOF just ends the loop; any other reader error is reported
// at the offset reached so far.
if err != io.EOF {
errCh <- rwErr{off, err}
}
return
}
}
}()
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets work, and does the Write from each buffer to its respective offset.
go func() {
defer wg.Done()
for work := range workCh {
// Wait for the reply, then hand the result channel back to the pool.
s := <-work.res
pool.Put(work.res)
err := s.err
if err == nil {
switch s.typ {
case sshFxpStatus:
err = normaliseError(unmarshalStatus(work.id, s.data))
default:
err = unimplementedPacketErr(s.typ)
}
}
if err != nil {
errCh <- rwErr{work.off, err}
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
}
}
}()
}
// Wait for long tail, before closing results.
go func() {
wg.Wait()
close(errCh)
}()
// Reduce: Collect all the results into a relevant return: the earliest offset to return an error.
firstErr := rwErr{math.MaxInt64, nil}
for rwErr := range errCh {
if rwErr.off <= firstErr.off {
firstErr = rwErr
}
// Close cancel exactly once: the receive case succeeds only after it has been closed.
select {
case <-cancel:
default:
// stop any more work from being distributed.
close(cancel)
}
}
if firstErr.err != nil {
// firstErr.err != nil if and only if firstErr.off is a valid offset.
//
// firstErr.off will then be the lesser of:
// * the offset of the first error from writing,
// * the last successfully read offset.
//
// This could be less than the last successfully written offset,
// which is the whole reason for the UseConcurrentWrites() ClientOption.
//
// Callers are responsible for truncating any SFTP files to a safe length.
f.offset = firstErr.off
// ReadFrom is defined to return the read bytes, regardless of any writer errors.
return read, firstErr.err
}
f.offset += read
return read, nil
}
// ReadFrom reads data from r until EOF and writes it to the file. The return
// value is the number of bytes read. Any error except io.EOF encountered
// during the read is also returned.
//
// This method is preferred over calling Write multiple times
// to maximise throughput for transferring the entire file,
// especially over high-latency links.
//
// When the client has concurrent writes enabled, the size of r is estimated
// from the optional methods it implements, and the concurrent path is used
// if that estimate is negative or larger than one packet. In every other case
// the data is written sequentially, one packet at a time.
// It returns os.ErrClosed when the File no longer has a handle.
func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.useConcurrentWrites {
// remain stays 0 when r matches none of the cases below (or Stat fails),
// which selects the sequential path.
var remain int64
switch r := r.(type) {
case interface{ Len() int }:
remain = int64(r.Len())
case interface{ Size() int64 }:
remain = r.Size()
case *io.LimitedReader:
remain = r.N
case interface{ Stat() (os.FileInfo, error) }:
info, err := r.Stat()
if err == nil {
remain = info.Size()
}
}
if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}
if remain > int64(f.c.maxPacket) {
// Otherwise, only use concurrency, if it would be at least two packets.
// This is the best reasonable guess we can make.
concurrency64 := remain/int64(f.c.maxPacket) + 1
// We need to cap this value to an `int` size value to avoid overflow on 32-bit machines.
// So, we may as well pre-cap it to `f.c.maxConcurrentRequests`.
if concurrency64 > int64(f.c.maxConcurrentRequests) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}
return f.readFromWithConcurrency(r, int(concurrency64))
}
}
// Sequential path: one read buffer and one result channel are reused for every chunk.
ch := make(chan result, 1) // reusable channel
b := make([]byte, f.c.maxPacket)
var read int64
for {
n, err := r.Read(b)
if n < 0 {
panic("sftp.File: reader returned negative count from Read")
}
if n > 0 {
read += int64(n)
m, err2 := f.writeChunkAt(ch, b[:n], f.offset)
f.offset += int64(m)
// A reader error takes precedence; the write error is only used when the read succeeded.
if err == nil {
err = err2
}
}
if err != nil {
if err == io.EOF {
return read, nil // return nil explicitly.
}
return read, err
}
}
}
// Seek implements io.Seeker by setting the client offset for the next Read or
// Write. It returns the new offset. Seeking before or after the end of
// the file is undefined. Seeking relative to the end calls stat on the file.
//
// On failure the current (unchanged) offset is returned with the error,
// except for a closed File, which yields 0 and os.ErrClosed.
// A resulting negative offset is rejected with os.ErrInvalid.
func (f *File) Seek(offset int64, whence int) (int64, error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.handle == "" {
		return 0, os.ErrClosed
	}

	// base is the position that offset is relative to.
	var base int64
	switch whence {
	case io.SeekStart:
		// base stays at zero.
	case io.SeekCurrent:
		base = f.offset
	case io.SeekEnd:
		info, err := f.stat()
		if err != nil {
			return f.offset, err
		}
		base = info.Size()
	default:
		return f.offset, unimplementedSeekWhence(whence)
	}

	target := base + offset
	if target < 0 {
		return f.offset, os.ErrInvalid
	}

	f.offset = target
	return target, nil
}
// Chown changes the uid/gid of the current file by sending an fsetstat
// request carrying both ids converted to uint32.
// It returns os.ErrClosed when the File no longer has a handle.
func (f *File) Chown(uid, gid int) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.handle == "" {
		return os.ErrClosed
	}

	owner := &FileStat{
		UID: uint32(uid),
		GID: uint32(gid),
	}
	return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, owner)
}
// Chmod changes the permissions of the current file.
// The mode is converted with toChmodPerm before being sent in an fsetstat request.
// It returns os.ErrClosed when the File no longer has a handle.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.handle == "" {
		return os.ErrClosed
	}

	perm := toChmodPerm(mode)
	return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, perm)
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that a size smaller than the current one truncates the file to fit,
// the SFTP protocol does not specify what the server should do when the
// size is greater than the current size.
//
// Because a file handle is available, the size is sent in an fsetstat request.
// The size is converted to uint64 as-is before sending.
// It returns os.ErrClosed when the File no longer has a handle.
func (f *File) Truncate(size int64) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.handle == "" {
		return os.ErrClosed
	}

	newSize := uint64(size)
	return f.c.fsetstat(f.handle, sshFileXferAttrSize, newSize)
}
// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the fsync@openssh.com extension.
//
// It returns os.ErrClosed when the File no longer has a handle, the transport
// error if sending fails, the normalised status of the reply, or an
// unexpectedPacketErr when the reply is not a status packet.
func (f *File) Sync() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.handle == "" {
		return os.ErrClosed
	}

	id := f.c.nextID()
	req := &sshFxpFsyncPacket{
		ID:     id,
		Handle: f.handle,
	}

	typ, data, err := f.c.sendPacket(context.Background(), nil, req)
	if err != nil {
		return err
	}
	if typ != sshFxpStatus {
		return &unexpectedPacketErr{want: sshFxpStatus, got: typ}
	}
	return normaliseError(unmarshalStatus(id, data))
}
// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
//
// Only *StatusError values are translated, according to their Code;
// a status of sshFxOk becomes nil. Every other error, and every
// StatusError with a code not listed here, is returned unchanged.
func normaliseError(err error) error {
	statusErr, isStatus := err.(*StatusError)
	if !isStatus {
		return err
	}

	switch statusErr.Code {
	case sshFxOk:
		return nil
	case sshFxEOF:
		return io.EOF
	case sshFxNoSuchFile:
		return os.ErrNotExist
	case sshFxPermissionDenied:
		return os.ErrPermission
	}
	return statusErr
}
// toPflags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func toPflags(f int) uint32 {
	var out uint32

	// Access mode: the os.O_WRONLY bit picks write-only versus read-only,
	// and os.O_RDWR additionally enables both directions.
	if mode := f & os.O_WRONLY; mode == os.O_WRONLY {
		out |= sshFxfWrite
	} else if mode == os.O_RDONLY {
		out |= sshFxfRead
	}
	if f&os.O_RDWR == os.O_RDWR {
		out |= sshFxfRead | sshFxfWrite
	}

	// Remaining flags map one-to-one onto their ssh counterparts.
	mappings := [...]struct {
		osFlag  int
		sshFlag uint32
	}{
		{os.O_APPEND, sshFxfAppend},
		{os.O_CREATE, sshFxfCreat},
		{os.O_TRUNC, sshFxfTrunc},
		{os.O_EXCL, sshFxfExcl},
	}
	for _, m := range mappings {
		if f&m.osFlag == m.osFlag {
			out |= m.sshFlag
		}
	}
	return out
}
// toChmodPerm converts Go permission bits to POSIX permission bits.
//
// This differs from fromFileMode in that the POSIX versions of setuid,
// setgid and sticky already present in m are preserved, because those bits
// have historically been supported, and any non-permission bits are masked off.
// The Go-specific ModeSetuid, ModeSetgid and ModeSticky bits are translated
// to their POSIX equivalents as well.
func toChmodPerm(m os.FileMode) (perm uint32) {
	const posixBits = os.ModePerm | s_ISUID | s_ISGID | s_ISVTX
	perm = uint32(m & posixBits)

	special := [...]struct {
		goBit    os.FileMode
		posixBit uint32
	}{
		{os.ModeSetuid, s_ISUID},
		{os.ModeSetgid, s_ISGID},
		{os.ModeSticky, s_ISVTX},
	}
	for _, s := range special {
		if m&s.goBit != 0 {
			perm |= s.posixBit
		}
	}
	return perm
}
================================================
FILE: gossh/sftp/conn.go
================================================
package sftp
import (
"context"
"encoding"
"fmt"
"io"
"sync"
)
// conn implements a bidirectional channel on which client and server
// connections are multiplexed.
//
// Packets are read through the embedded Reader and written through the
// embedded WriteCloser; the embedded Mutex is taken by sendPacket and Close.
type conn struct {
io.Reader
io.WriteCloser
// this is the same allocator used in packet manager
// (it is passed to recvPacket together with the order ID)
alloc *allocator
sync.Mutex // used to serialise writes to sendPacket
}
// recvPacket reads the next packet from the connection and returns its
// type, its payload, and any error.
//
// The orderID is used in server mode if the allocator is enabled.
// For the client mode just pass 0.
// It returns io.EOF if the connection is closed and
// there are no more packets to read.
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
	typ, payload, err := recvPacket(c, c.alloc, orderID)
	return typ, payload, err
}
// sendPacket marshals m and writes it to the connection.
// The conn mutex is held for the whole write so that packets
// from concurrent callers are not interleaved.
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
	c.Lock()
	defer c.Unlock()

	err := sendPacket(c, m)
	return err
}
// Close closes the write side of the connection.
// It takes the same mutex as sendPacket, so it waits for any write in progress.
func (c *conn) Close() error {
	c.Lock()
	defer c.Unlock()

	err := c.WriteCloser.Close()
	return err
}
// clientConn is the client side of a conn. It tracks which result channel
// each outstanding request id should have its response delivered to.
type clientConn struct {
conn
// wg is waited on by Close before it returns.
wg sync.WaitGroup
sync.Mutex // protects inflight
inflight map[uint32]chan<- result // outstanding requests
// closed is closed by broadcastErr once the connection has shut down.
closed chan struct{}
// err is the error recorded by broadcastErr and returned from Wait.
err error
}
// Wait blocks until the conn has shut down, and returns the error
// that caused the shutdown. It can be called concurrently from multiple
// goroutines.
func (c *clientConn) Wait() error {
	// The closed channel is closed only after c.err has been recorded.
	<-c.closed

	shutdownErr := c.err
	return shutdownErr
}
// Close closes the SFTP session.
// After closing the underlying conn it waits on the clientConn's WaitGroup
// before returning the close error.
func (c *clientConn) Close() error {
	defer c.wg.Wait()

	closeErr := c.conn.Close()
	return closeErr
}
// recv continuously reads from the server and forwards responses to the
// appropriate channel. It returns on the first receive or decode error,
// or when a response carries an id with no registered channel;
// the underlying conn is closed on the way out.
func (c *clientConn) recv() error {
	defer c.conn.Close()

	for {
		typ, payload, recvErr := c.recvPacket(0)
		if recvErr != nil {
			return recvErr
		}

		sid, _, idErr := unmarshalUint32Safe(payload)
		if idErr != nil {
			return idErr
		}

		if ch, found := c.getChannel(sid); found {
			ch <- result{typ: typ, data: payload}
			continue
		}

		// This is an unexpected occurrence. Returning the error lets it be
		// sent back to all listeners so that they terminate gracefully.
		return fmt.Errorf("sid not found: %d", sid)
	}
}
// putChannel registers ch as the destination for the response to request sid.
// It reports false, after delivering ErrSSHFxConnectionLost on ch,
// when the connection has already been closed.
func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
	c.Lock()
	defer c.Unlock()

	select {
	case <-c.closed:
		// already closed with broadcastErr, return error on chan.
		ch <- result{err: ErrSSHFxConnectionLost}
		return false
	default:
		c.inflight[sid] = ch
		return true
	}
}
// getChannel removes and returns the channel registered for request sid.
// The boolean reports whether such a channel was registered.
func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
	c.Lock()
	defer c.Unlock()

	ch, found := c.inflight[sid]
	if found {
		delete(c.inflight, sid)
	}
	return ch, found
}
// result captures the result of receiving a packet from the server:
// either the packet's type and data, or an error.
type result struct {
typ byte
data []byte
err error
}
// idmarshaler is a request packet that can report its request id
// and marshal itself to binary.
type idmarshaler interface {
id() uint32
encoding.BinaryMarshaler
}
// sendPacket dispatches p and waits for either its response or ctx to be done.
// The response is received on ch; when ch has no buffer capacity (including
// a nil ch) a fresh channel with capacity one is used instead.
// It returns the response's type and data, or an error.
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
	if cap(ch) == 0 {
		ch = make(chan result, 1)
	}

	c.dispatchRequest(ch, p)

	select {
	case res := <-ch:
		return res.typ, res.data, res.err
	case <-ctx.Done():
		return 0, nil, ctx.Err()
	}
}
// dispatchRequest registers ch for the response to p and then sends p.
// If sending fails, and the channel is still registered, the send error is
// delivered on it instead of a response.
//
// dispatchRequest should ideally only be called by race-detection tests outside of this file,
// where you have to ensure two packets are in flight sequentially after each other.
func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
	sid := p.id()

	if registered := c.putChannel(ch, sid); !registered {
		// already closed.
		return
	}

	sendErr := c.conn.sendPacket(p)
	if sendErr == nil {
		return
	}

	if pending, ok := c.getChannel(sid); ok {
		pending <- result{err: sendErr}
	}
}
// broadcastErr sends an error to all goroutines waiting for a response.
// Every waiter receives ErrSSHFxConnectionLost; err itself is recorded
// as the connection error before the closed channel is closed.
func (c *clientConn) broadcastErr(err error) {
	c.Lock()
	defer c.Unlock()

	lost := result{err: ErrSSHFxConnectionLost}
	for sid, waiter := range c.inflight {
		waiter <- lost

		// Replace the chan in inflight,
		// we have hijacked this chan,
		// and this guarantees always-only-once sending.
		c.inflight[sid] = make(chan<- result, 1)
	}

	c.err = err
	close(c.closed)
}
// serverConn is the server side of a conn.
type serverConn struct {
conn
}
// sendError sends the status packet built by statusFromError for request id and err.
func (s *serverConn) sendError(id uint32, err error) error {
	status := statusFromError(id, err)
	return s.sendPacket(status)
}
================================================
FILE: gossh/sftp/debug.go
================================================
//go:build debug
// +build debug
package sftp
import "log"
// debug logs a formatted message through the standard logger.
// This file is only compiled when the "debug" build tag is set.
func debug(fmt string, args ...any) {
log.Printf(fmt, args...)
}
================================================
FILE: gossh/sftp/fuzz.go
================================================
//go:build gofuzz
// +build gofuzz
package sftp
import "bytes"
// sinkfuzz is a write sink that discards everything written to it.
type sinkfuzz struct{}

// Close does nothing and always succeeds.
func (*sinkfuzz) Close() error {
	return nil
}

// Write discards p and reports all of it as written.
func (*sinkfuzz) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// devnull is the shared sink handed to the fuzzed client as its writer.
var devnull = &sinkfuzz{}
// Fuzz is the go-fuzz entry point. It feeds data to a client as the server's
// side of the conversation, with all client output discarded.
// It returns 1 when the client could be created and 0 otherwise.
//
// To run: go-fuzz-build && go-fuzz
func Fuzz(data []byte) int {
	client, err := NewClientPipe(bytes.NewReader(data), devnull)
	if err == nil {
		client.Close()
		return 1
	}
	return 0
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/attrs.go
================================================
package sshfx
// Attributes related flags.
// Each constant is a single bit in Attributes.Flags that marks the
// corresponding field(s) as present.
const (
AttrSize = 1 << iota // SSH_FILEXFER_ATTR_SIZE
AttrUIDGID // SSH_FILEXFER_ATTR_UIDGID
AttrPermissions // SSH_FILEXFER_ATTR_PERMISSIONS
AttrACModTime // SSH_FILEXFER_ATTR_ACMODTIME
AttrExtended = 1 << 31 // SSH_FILEXFER_ATTR_EXTENDED
)
// Attributes defines the file attributes type defined in draft-ietf-secsh-filexfer-02
//
// Defined in: https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5
//
// A field is only meaningful when its corresponding bit is set in Flags.
type Attributes struct {
Flags uint32
// AttrSize
Size uint64
// AttrUIDGID
UID uint32
GID uint32
// AttrPermissions
Permissions FileMode
// AttrACModTime
ATime uint32
MTime uint32
// AttrExtended
ExtendedAttributes []ExtendedAttribute
}
// GetSize returns the Size field, and a bool that is true if and only if
// the AttrSize bit is set in Flags.
func (a *Attributes) GetSize() (size uint64, ok bool) {
	ok = a.Flags&AttrSize != 0
	return a.Size, ok
}
// SetSize stores size in the Size field
// and sets the AttrSize bit in Flags to mark it as defined.
func (a *Attributes) SetSize(size uint64) {
	a.Size = size
	a.Flags |= AttrSize
}
// GetUIDGID returns the UID and GID fields, and a bool that is true
// if and only if the AttrUIDGID bit is set in Flags.
func (a *Attributes) GetUIDGID() (uid, gid uint32, ok bool) {
	ok = a.Flags&AttrUIDGID != 0
	return a.UID, a.GID, ok
}
// SetUIDGID stores uid and gid in the UID and GID fields
// and sets the AttrUIDGID bit in Flags to mark them as defined.
func (a *Attributes) SetUIDGID(uid, gid uint32) {
	a.UID, a.GID = uid, gid
	a.Flags |= AttrUIDGID
}
// GetPermissions returns the Permissions field, and a bool that is true
// if and only if the AttrPermissions bit is set in Flags.
func (a *Attributes) GetPermissions() (perms FileMode, ok bool) {
	ok = a.Flags&AttrPermissions != 0
	return a.Permissions, ok
}
// SetPermissions stores perms in the Permissions field
// and sets the AttrPermissions bit in Flags to mark it as defined.
func (a *Attributes) SetPermissions(perms FileMode) {
	a.Permissions = perms
	a.Flags |= AttrPermissions
}
// GetACModTime returns the ATime and MTime fields, and a bool that is true
// if and only if the AttrACModTime bit is set in Flags.
func (a *Attributes) GetACModTime() (atime, mtime uint32, ok bool) {
	ok = a.Flags&AttrACModTime != 0
	return a.ATime, a.MTime, ok
}
// SetACModTime stores atime and mtime in the ATime and MTime fields
// and sets the AttrACModTime bit in Flags to mark them as defined.
func (a *Attributes) SetACModTime(atime, mtime uint32) {
	a.ATime, a.MTime = atime, mtime
	a.Flags |= AttrACModTime
}
// Len returns the number of bytes a would marshal into:
// four bytes for Flags plus the size of every field whose flag bit is set.
func (a *Attributes) Len() int {
	has := func(flag uint32) bool { return a.Flags&flag != 0 }

	n := 4 // Flags
	if has(AttrSize) {
		n += 8
	}
	if has(AttrUIDGID) {
		n += 4 + 4
	}
	if has(AttrPermissions) {
		n += 4
	}
	if has(AttrACModTime) {
		n += 4 + 4
	}
	if has(AttrExtended) {
		n += 4 // attribute count
		for i := range a.ExtendedAttributes {
			n += a.ExtendedAttributes[i].Len()
		}
	}
	return n
}
// MarshalInto marshals a onto the end of the given Buffer:
// Flags first, followed by every field whose flag bit is set.
func (a *Attributes) MarshalInto(buf *Buffer) {
	has := func(flag uint32) bool { return a.Flags&flag != 0 }

	buf.AppendUint32(a.Flags)
	if has(AttrSize) {
		buf.AppendUint64(a.Size)
	}
	if has(AttrUIDGID) {
		buf.AppendUint32(a.UID)
		buf.AppendUint32(a.GID)
	}
	if has(AttrPermissions) {
		buf.AppendUint32(uint32(a.Permissions))
	}
	if has(AttrACModTime) {
		buf.AppendUint32(a.ATime)
		buf.AppendUint32(a.MTime)
	}
	if has(AttrExtended) {
		// The extended attributes are preceded by their count.
		buf.AppendUint32(uint32(len(a.ExtendedAttributes)))
		for i := range a.ExtendedAttributes {
			a.ExtendedAttributes[i].MarshalInto(buf)
		}
	}
}
// MarshalBinary returns the binary encoding of a.
// The returned error is always nil.
func (a *Attributes) MarshalBinary() ([]byte, error) {
	size := a.Len()
	buf := NewBuffer(make([]byte, 0, size))
	a.MarshalInto(buf)
	return buf.Bytes(), nil
}
// UnmarshalFrom unmarshals an Attributes from the given Buffer into a.
// The flags word is consumed first and determines which fields follow.
//
// NOTE: The values of fields not covered in the a.Flags are explicitly undefined.
func (a *Attributes) UnmarshalFrom(buf *Buffer) (err error) {
	return a.XXX_UnmarshalByFlags(buf.ConsumeUint32(), buf)
}
// XXX_UnmarshalByFlags stores flags in a.Flags and uses them to determine which fields to decode from buf.
// DO NOT USE THIS: it is an anti-corruption function to implement existing internal usage in pkg/sftp.
// This function is not a part of any compatibility promise.
//
// It returns buf.Err, which is ErrShortPacket if buf ran out of data.
// The extended attribute count is validated against the bytes remaining in buf
// before anything is allocated, so a peer cannot force a huge allocation
// with a tiny packet.
func (a *Attributes) XXX_UnmarshalByFlags(flags uint32, buf *Buffer) (err error) {
	a.Flags = flags

	// Short-circuit dummy attributes.
	if a.Flags == 0 {
		return buf.Err
	}

	if a.Flags&AttrSize != 0 {
		a.Size = buf.ConsumeUint64()
	}

	if a.Flags&AttrUIDGID != 0 {
		a.UID = buf.ConsumeUint32()
		a.GID = buf.ConsumeUint32()
	}

	if a.Flags&AttrPermissions != 0 {
		a.Permissions = FileMode(buf.ConsumeUint32())
	}

	if a.Flags&AttrACModTime != 0 {
		a.ATime = buf.ConsumeUint32()
		a.MTime = buf.ConsumeUint32()
	}

	if a.Flags&AttrExtended != 0 {
		count := buf.ConsumeCount()

		// Each extended attribute marshals into at least two 4-byte string
		// length prefixes (see ExtendedAttribute.Len), so a count larger than
		// the remaining data could never be decoded. Reject it up front.
		// count can also be negative where int is 32 bits wide.
		const minExtendedAttrLen = 4 + 4
		if count < 0 || count > buf.Len()/minExtendedAttrLen {
			// Leave buf in the same state a failed Consume call would.
			buf.off = len(buf.b)
			buf.Err = ErrShortPacket
			return buf.Err
		}

		a.ExtendedAttributes = make([]ExtendedAttribute, count)
		for i := range a.ExtendedAttributes {
			a.ExtendedAttributes[i].UnmarshalFrom(buf)
		}
	}

	return buf.Err
}
// UnmarshalBinary decodes the binary encoding of Attributes in data into a.
func (a *Attributes) UnmarshalBinary(data []byte) error {
	buf := NewBuffer(data)
	return a.UnmarshalFrom(buf)
}
// ExtendedAttribute defines the extended file attribute type defined in draft-ietf-secsh-filexfer-02
//
// Defined in: https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5
//
// Both fields are marshalled as strings, Type first.
type ExtendedAttribute struct {
Type string
Data string
}
// Len returns the number of bytes e would marshal into:
// a four-byte length prefix plus the contents for each of Type and Data.
func (e *ExtendedAttribute) Len() int {
	typeLen := 4 + len(e.Type)
	dataLen := 4 + len(e.Data)
	return typeLen + dataLen
}
// MarshalInto marshals e onto the end of the given Buffer,
// as the Type string followed by the Data string.
func (e *ExtendedAttribute) MarshalInto(buf *Buffer) {
	for _, field := range [...]string{e.Type, e.Data} {
		buf.AppendString(field)
	}
}
// MarshalBinary returns the binary encoding of e.
// The returned error is always nil.
func (e *ExtendedAttribute) MarshalBinary() ([]byte, error) {
	size := e.Len()
	buf := NewBuffer(make([]byte, 0, size))
	e.MarshalInto(buf)
	return buf.Bytes(), nil
}
// UnmarshalFrom unmarshals an ExtendedAttribute from the given Buffer into e,
// consuming the Type string and then the Data string.
// It returns buf.Err.
func (e *ExtendedAttribute) UnmarshalFrom(buf *Buffer) (err error) {
	e.Type = buf.ConsumeString()
	e.Data = buf.ConsumeString()
	return buf.Err
}
// UnmarshalBinary decodes the binary encoding of ExtendedAttribute in data into e.
func (e *ExtendedAttribute) UnmarshalBinary(data []byte) error {
	buf := NewBuffer(data)
	return e.UnmarshalFrom(buf)
}
// NameEntry implements the SSH_FXP_NAME repeated data type from draft-ietf-secsh-filexfer-02
//
// This type is incompatible with versions 4 or higher.
//
// It is marshalled as the Filename string, the Longname string, and then Attrs.
type NameEntry struct {
Filename string
Longname string
Attrs Attributes
}
// Len returns the number of bytes e would marshal into:
// two length-prefixed strings followed by the attributes.
func (e *NameEntry) Len() int {
	filenameLen := 4 + len(e.Filename)
	longnameLen := 4 + len(e.Longname)
	return filenameLen + longnameLen + e.Attrs.Len()
}
// MarshalInto marshals e onto the end of the given Buffer,
// as Filename, Longname, and then Attrs.
func (e *NameEntry) MarshalInto(buf *Buffer) {
	for _, name := range [...]string{e.Filename, e.Longname} {
		buf.AppendString(name)
	}
	e.Attrs.MarshalInto(buf)
}
// MarshalBinary returns the binary encoding of e.
// The returned error is always nil.
func (e *NameEntry) MarshalBinary() ([]byte, error) {
	size := e.Len()
	buf := NewBuffer(make([]byte, 0, size))
	e.MarshalInto(buf)
	return buf.Bytes(), nil
}
// UnmarshalFrom unmarshals a NameEntry from the given Buffer into e.
// e is fully overwritten: the two names are consumed first and
// the attributes are then decoded into a zeroed Attrs.
//
// NOTE: The values of fields not covered in e.Attrs.Flags are explicitly undefined.
func (e *NameEntry) UnmarshalFrom(buf *Buffer) (err error) {
	filename := buf.ConsumeString()
	longname := buf.ConsumeString()

	*e = NameEntry{Filename: filename, Longname: longname}
	return e.Attrs.UnmarshalFrom(buf)
}
// UnmarshalBinary decodes the binary encoding of NameEntry in data into e.
func (e *NameEntry) UnmarshalBinary(data []byte) error {
	buf := NewBuffer(data)
	return e.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/buffer.go
================================================
package sshfx
import (
"encoding/binary"
"errors"
)
// Various encoding errors.
var (
// ErrShortPacket is set on a Buffer when a Consume method needs more data than remains.
ErrShortPacket = errors.New("packet too short")
// ErrLongPacket reports a packet that is too long.
ErrLongPacket = errors.New("packet too long")
)
// Buffer wraps up the various encoding details of the SSH format.
//
// Data types are encoded as per section 4 from https://tools.ietf.org/html/draft-ietf-secsh-architecture-09#page-8
type Buffer struct {
// b is the underlying data; Append methods add to its end.
b []byte
// off is the index in b of the next byte to be consumed.
off int
// Err is the first error hit by a Consume method; once set, later Consume calls return zero values.
Err error
}
// NewBuffer creates and initializes a new buffer using buf as its initial contents.
// The new buffer takes ownership of buf, and the caller should not use buf after this call.
//
// In most cases, new(Buffer) (or just declaring a Buffer variable) is sufficient to initialize a Buffer.
func NewBuffer(buf []byte) *Buffer {
	b := new(Buffer)
	b.b = buf
	return b
}
// NewMarshalBuffer creates a new Buffer ready to start marshaling a Packet into.
// It allocates a slice long enough for uint32(length), uint8(type),
// uint32(request-id) and size more bytes.
func NewMarshalBuffer(size int) *Buffer {
	const headerLen = 4 + 1 + 4 // length, type, request-id
	storage := make([]byte, headerLen+size)
	return NewBuffer(storage)
}
// Bytes returns a slice of length b.Len() holding the unconsumed bytes in the Buffer.
// The slice is valid for use only until the next buffer modification
// (that is, only until the next call to an Append or Consume method).
func (b *Buffer) Bytes() []byte {
	unread := b.b[b.off:]
	return unread
}
// Len returns the number of unconsumed bytes in the buffer.
func (b *Buffer) Len() int {
	total := len(b.b)
	return total - b.off
}
// Cap returns the capacity of the buffer's underlying byte slice,
// that is, the total space allocated for the buffer's data.
func (b *Buffer) Cap() int {
	return cap(b.b)
}
// Reset resets the buffer to be empty, but it retains the underlying storage for use by future Appends.
// The read offset and any recorded error are cleared as well.
func (b *Buffer) Reset() {
	b.b = b.b[:0]
	b.off = 0
	b.Err = nil
}
// StartPacket resets and initializes the buffer to be ready to start marshaling a packet into.
// It truncates the buffer, reserves four zero bytes for uint32(length),
// then appends the given packetType and requestID.
func (b *Buffer) StartPacket(packetType PacketType, requestID uint32) {
	// Truncate, clear the read state, and reserve the length field in one go.
	b.b = append(b.b[:0], 0, 0, 0, 0)
	b.off = 0
	b.Err = nil

	b.AppendUint8(uint8(packetType))
	b.AppendUint32(requestID)
}
// Packet finalizes the packet started from StartPacket.
// It is expected that this will end the ownership of the underlying byte-slice,
// and so the returned byte-slices may be reused the same as any other byte-slice,
// the caller should not use this buffer after this call.
//
// It passes the packet body length to PutLength: the length of this buffer
// less the 4-byte length field itself, plus the length of payload.
//
// It is assumed that no Consume methods have been called on this buffer,
// and so it returns the whole underlying slice, along with payload unchanged.
// The returned error is always nil.
func (b *Buffer) Packet(payload []byte) (header, payloadPassThru []byte, err error) {
	bodyLen := len(b.b) - 4 + len(payload)
	b.PutLength(bodyLen)

	header, payloadPassThru = b.b, payload
	return header, payloadPassThru, nil
}
// ConsumeUint8 consumes a single byte from the buffer.
// If the buffer does not have enough data, it will set Err to ErrShortPacket.
// Once Err is set, it returns 0 without consuming anything.
func (b *Buffer) ConsumeUint8() uint8 {
	switch {
	case b.Err != nil:
		return 0
	case b.Len() < 1:
		// Mark everything as consumed so later calls see an empty buffer.
		b.off = len(b.b)
		b.Err = ErrShortPacket
		return 0
	}

	v := b.b[b.off]
	b.off++
	return v
}
// AppendUint8 appends a single byte onto the end of the buffer.
func (b *Buffer) AppendUint8(v uint8) {
	grown := append(b.b, v)
	b.b = grown
}
// ConsumeBool consumes a single byte from the buffer, and returns true if that byte is non-zero.
// If the buffer does not have enough data, it will set Err to ErrShortPacket.
func (b *Buffer) ConsumeBool() bool {
	raw := b.ConsumeUint8()
	return raw != 0
}
// AppendBool appends a single bool into the buffer.
// It encodes it as a single byte, with false as 0, and true as 1.
func (b *Buffer) AppendBool(v bool) {
	var encoded uint8
	if v {
		encoded = 1
	}
	b.AppendUint8(encoded)
}
// ConsumeUint16 consumes a single uint16 from the buffer, in network byte order (big-endian).
// If the buffer does not have enough data, it will set Err to ErrShortPacket.
// Once Err is set, it returns 0 without consuming anything.
func (b *Buffer) ConsumeUint16() uint16 {
	const width = 2

	switch {
	case b.Err != nil:
		return 0
	case b.Len() < width:
		b.off = len(b.b)
		b.Err = ErrShortPacket
		return 0
	}

	start := b.off
	b.off += width
	return binary.BigEndian.Uint16(b.b[start:])
}
// AppendUint16 appends a single uint16 into the buffer, in network byte order (big-endian).
func (b *Buffer) AppendUint16(v uint16) {
	hi, lo := byte(v>>8), byte(v)
	b.b = append(b.b, hi, lo)
}
// unmarshalUint32 is used internally to read the packet length.
// It decodes the first four bytes of b as a big-endian uint32.
// It is unsafe (it panics when b cannot be sliced to four bytes),
// and so not exported.
// Even within this package, its use should be avoided.
func unmarshalUint32(b []byte) uint32 {
	word := b[:4]
	return binary.BigEndian.Uint32(word)
}
// ConsumeUint32 consumes a single uint32 from the buffer, in network byte order (big-endian).
// If the buffer does not have enough data, it will set Err to ErrShortPacket.
// Once Err is set, it returns 0 without consuming anything.
func (b *Buffer) ConsumeUint32() uint32 {
	const width = 4

	switch {
	case b.Err != nil:
		return 0
	case b.Len() < width:
		b.off = len(b.b)
		b.Err = ErrShortPacket
		return 0
	}

	start := b.off
	b.off += width
	return binary.BigEndian.Uint32(b.b[start:])
}
// AppendUint32 appends v to the buffer as four bytes in big-endian (network) byte order.
func (b *Buffer) AppendUint32(v uint32) {
	var word [4]byte
	binary.BigEndian.PutUint32(word[:], v)

	// A single append keeps the growth behaviour of the underlying slice unchanged.
	b.b = append(b.b, word[:]...)
}
// ConsumeCount reads a big-endian uint32 via ConsumeUint32 and returns it converted to int.
// Any failure is recorded in Err by ConsumeUint32, in which case the result is 0.
func (b *Buffer) ConsumeCount() int {
	count := b.ConsumeUint32()
	return int(count)
}
// AppendCount converts the int v to a uint32 and appends it in big-endian (network) byte order.
func (b *Buffer) AppendCount(v int) {
	count := uint32(v)
	b.AppendUint32(count)
}
// ConsumeUint64 reads a big-endian (network byte order) uint64 from the buffer and advances eight bytes.
//
// If Err is already set, it returns 0 and leaves the buffer untouched.
// If fewer than eight bytes remain, it moves the read offset to the end of the data,
// sets Err to ErrShortPacket, and returns 0.
func (b *Buffer) ConsumeUint64() uint64 {
	switch {
	case b.Err != nil:
		return 0
	case b.Len() < 8:
		b.off = len(b.b)
		b.Err = ErrShortPacket
		return 0
	}

	value := binary.BigEndian.Uint64(b.b[b.off:])
	b.off += 8
	return value
}
// AppendUint64 appends v to the buffer as eight bytes in big-endian (network) byte order.
func (b *Buffer) AppendUint64(v uint64) {
	var word [8]byte
	binary.BigEndian.PutUint64(word[:], v)

	// A single append keeps the growth behaviour of the underlying slice unchanged.
	b.b = append(b.b, word[:]...)
}
// ConsumeInt64 reads a big-endian uint64 via ConsumeUint64 and reinterprets it as a two's complement int64.
// Any failure is recorded in Err by ConsumeUint64, in which case the result is 0.
func (b *Buffer) ConsumeInt64() int64 {
	raw := b.ConsumeUint64()
	return int64(raw)
}
// AppendInt64 reinterprets v as a uint64 (two's complement) and appends it in big-endian (network) byte order.
func (b *Buffer) AppendInt64(v int64) {
	raw := uint64(v)
	b.AppendUint64(raw)
}
// ConsumeByteSlice reads one length-prefixed string of raw bytes from the buffer:
// a uint32 length followed by that many bytes.
//
// If the length cannot be read it returns nil with Err set by ConsumeUint32.
// If the declared length is negative after conversion to int, or exceeds the remaining data,
// it moves the read offset to the end of the data, sets Err to ErrShortPacket, and returns nil.
//
// The returned slice aliases the buffer contents, and is valid only as long as the buffer is not reused
// (that is, only until the next call to Reset, PutLength, StartPacket, or UnmarshalBinary).
//
// The result is capped so that its capacity equals its length;
// appending to it therefore cannot overwrite bytes that follow it in the buffer.
func (b *Buffer) ConsumeByteSlice() []byte {
	n := int(b.ConsumeUint32())

	switch {
	case b.Err != nil:
		return nil
	case b.Len() < n || n < 0:
		b.off = len(b.b)
		b.Err = ErrShortPacket
		return nil
	}

	data := b.b[b.off:]
	if len(data) > n || cap(data) > n {
		// Trim both length and capacity to exactly the declared size.
		data = data[:n:n]
	}

	b.off += n
	return data
}
// ConsumeByteSliceCopy reads one length-prefixed string of raw bytes (see ConsumeByteSlice)
// and returns it as a copy.
//
// The bytes are copied into hint, which is first extended with append when it is shorter than the data;
// append reuses hint's backing array when its capacity suffices and allocates a new one otherwise.
// The result is hint truncated to the number of bytes copied.
//
// When ConsumeByteSlice fails, Err is set and the result is hint truncated to length 0.
func (b *Buffer) ConsumeByteSliceCopy(hint []byte) []byte {
	data := b.ConsumeByteSlice()

	if len(hint) < len(data) {
		hint = append(hint, make([]byte, len(data)-len(hint))...)
	}

	return hint[:copy(hint, data)]
}
// AppendByteSlice appends v to the buffer as a length-prefixed string:
// len(v) as a big-endian uint32, followed by the bytes of v.
func (b *Buffer) AppendByteSlice(v []byte) {
	length := uint32(len(v))
	b.AppendUint32(length)
	b.b = append(b.b, v...)
}
// ConsumeString reads one length-prefixed string (see ConsumeByteSlice) and returns it as a Go string.
// On failure Err is set by ConsumeByteSlice and the result is the empty string.
//
// NOTE: Go implicitly assumes that strings contain UTF-8 encoded data.
// All caveats on using arbitrary binary data in Go strings apply.
func (b *Buffer) ConsumeString() string {
	raw := b.ConsumeByteSlice()
	return string(raw)
}
// AppendString appends v to the buffer as a length-prefixed string:
// len(v) as a big-endian uint32, followed by the bytes of v.
func (b *Buffer) AppendString(v string) {
	b.AppendUint32(uint32(len(v)))
	// append accepts a string operand directly when the destination is a byte slice.
	b.b = append(b.b, v...)
}
// PutLength writes size, converted to uint32, into the first four bytes of the buffer
// in big-endian (network) byte order.
// A buffer holding fewer than four bytes is first extended with zero bytes to a length of four.
func (b *Buffer) PutLength(size int) {
	if missing := 4 - len(b.b); missing > 0 {
		b.b = append(b.b, make([]byte, missing)...)
	}

	binary.BigEndian.PutUint32(b.b, uint32(size))
}
// MarshalBinary returns a freshly allocated copy of the full internal byte slice,
// regardless of the current read offset. The returned error is always nil.
func (b *Buffer) MarshalBinary() ([]byte, error) {
	out := make([]byte, len(b.b))
	copy(out, b.b)
	return out, nil
}
// UnmarshalBinary copies data into the internal byte slice, growing it when it is shorter than data,
// truncates the slice to the number of bytes copied, and resets the read offset to zero.
// The returned error is always nil.
func (b *Buffer) UnmarshalBinary(data []byte) error {
	if short := len(data) - len(b.b); short > 0 {
		b.b = append(b.b, make([]byte, short)...)
	}

	b.b = b.b[:copy(b.b, data)]
	b.off = 0

	return nil
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/extended_packets.go
================================================
package sshfx
import (
"encoding"
"sync"
)
// ExtendedData is an alias for the unnamed interface that combines
// encoding.BinaryMarshaler and encoding.BinaryUnmarshaler.
// ExtendedPacket and ExtendedReplyPacket hold their request-specific data as a value of this type.
type ExtendedData = interface {
encoding.BinaryMarshaler
encoding.BinaryUnmarshaler
}
// ExtendedDataConstructor defines a function that returns a new ExtendedData value.
// Constructors are associated with an extension string through RegisterExtendedPacketType.
type ExtendedDataConstructor func() ExtendedData
// extendedPacketTypes is the package-level registry that maps an extension string
// to the constructor registered for it. mu guards access to constructors.
var extendedPacketTypes = struct {
mu sync.RWMutex
constructors map[string]ExtendedDataConstructor
}{
constructors: make(map[string]ExtendedDataConstructor),
}
// RegisterExtendedPacketType associates constructor with the given extension string.
//
// It panics if a constructor has already been registered for extension.
func RegisterExtendedPacketType(extension string, constructor ExtendedDataConstructor) {
	registry := &extendedPacketTypes

	registry.mu.Lock()
	defer registry.mu.Unlock()

	_, duplicate := registry.constructors[extension]
	if duplicate {
		panic("encoding/ssh/filexfer: multiple registration of extended packet type " + extension)
	}

	registry.constructors[extension] = constructor
}
// newExtendedPacket returns the value produced by the constructor registered for extension.
// When no constructor (or a nil one) is registered, it returns a new Buffer instead.
// The registry's read lock is held while the constructor runs.
func newExtendedPacket(extension string) ExtendedData {
	extendedPacketTypes.mu.RLock()
	defer extendedPacketTypes.mu.RUnlock()

	construct := extendedPacketTypes.constructors[extension]
	if construct == nil {
		return new(Buffer)
	}

	return construct()
}
// ExtendedPacket defines the SSH_FXP_EXTENDED packet.
//
// ExtendedRequest is the extension name, and Data holds the request-specific data that follows it.
type ExtendedPacket struct {
ExtendedRequest string
Data ExtendedData
}
// Type returns PacketTypeExtended, the SSH_FXP_xy value associated with this packet type.
func (p *ExtendedPacket) Type() PacketType {
return PacketTypeExtended
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The header holds the packet start for reqid followed by the extended-request string.
// When p.Data is non-nil it is marshaled with MarshalBinary and returned as the payload;
// a marshaling error is returned with nil header and payload.
//
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the extended-request string is used instead.
func (p *ExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(extended-request)
		buf = NewMarshalBuffer(4 + len(p.ExtendedRequest))
	}

	buf.StartPacket(PacketTypeExtended, reqid)
	buf.AppendString(p.ExtendedRequest)

	if p.Data == nil {
		return buf.Packet(nil)
	}

	payload, err = p.Data.MarshalBinary()
	if err != nil {
		return nil, nil, err
	}

	return buf.Packet(payload)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// The extended-request string is read first; if that fails, the Buffer's error is returned.
// If p.Data is nil, it is set from newExtendedPacket: the registered type for the extension,
// or a new Buffer when the extension has not been registered.
// The rest of the buffer is then passed to p.Data.UnmarshalBinary as the request-specific data.
func (p *ExtendedPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	name := buf.ConsumeString()
	p.ExtendedRequest = name

	if buf.Err != nil {
		return buf.Err
	}

	if p.Data == nil {
		p.Data = newExtendedPacket(name)
	}

	rest := buf.Bytes()
	return p.Data.UnmarshalBinary(rest)
}
// ExtendedReplyPacket defines the SSH_FXP_EXTENDED_REPLY packet.
//
// Data holds the request-specific data carried by the reply.
type ExtendedReplyPacket struct {
Data ExtendedData
}
// Type returns PacketTypeExtendedReply, the SSH_FXP_xy value associated with this packet type.
func (p *ExtendedReplyPacket) Type() PacketType {
return PacketTypeExtendedReply
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The header holds only the packet start for reqid.
// When p.Data is non-nil it is marshaled with MarshalBinary and returned as the payload;
// a marshaling error is returned with nil header and payload.
//
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer with no extra space is used instead.
func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		buf = NewMarshalBuffer(0)
	}

	buf.StartPacket(PacketTypeExtendedReply, reqid)

	if p.Data == nil {
		return buf.Packet(nil)
	}

	payload, err = p.Data.MarshalBinary()
	if err != nil {
		return nil, nil, err
	}

	return buf.Packet(payload)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// If p.Data is nil, a new Buffer is assigned to it.
// The rest of buf is then passed to p.Data.UnmarshalBinary as the request-specific data.
func (p *ExtendedReplyPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	if p.Data == nil {
		p.Data = new(Buffer)
	}

	rest := buf.Bytes()
	return p.Data.UnmarshalBinary(rest)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/extensions.go
================================================
package sshfx
// ExtensionPair defines the extension-pair type defined in draft-ietf-secsh-filexfer-13.
// This type is backwards-compatible with how draft-ietf-secsh-filexfer-02 defines extensions.
//
// On the wire it is encoded as two length-prefixed strings: Name followed by Data.
//
// Defined in: https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-4.2
type ExtensionPair struct {
Name string
Data string
}
// Len returns the number of bytes e would marshal into:
// a four-byte length prefix plus the content for each of Name and Data.
func (e *ExtensionPair) Len() int {
	const lengthPrefix = 4 // size of the uint32 length before each string

	return lengthPrefix + len(e.Name) + lengthPrefix + len(e.Data)
}
// MarshalInto appends e to the end of the given Buffer as two length-prefixed strings:
// first Name, then Data.
func (e *ExtensionPair) MarshalInto(buf *Buffer) {
	for _, field := range [...]string{e.Name, e.Data} {
		buf.AppendString(field)
	}
}
// MarshalBinary returns the binary encoding of e,
// marshaled into a Buffer whose backing slice is pre-sized to e.Len() bytes.
// The returned error is always nil.
func (e *ExtensionPair) MarshalBinary() ([]byte, error) {
	scratch := make([]byte, 0, e.Len())
	buf := NewBuffer(scratch)

	e.MarshalInto(buf)

	return buf.Bytes(), nil
}
// UnmarshalFrom reads two length-prefixed strings from buf, Name then Data,
// and overwrites e with them.
// It returns buf.Err, so a short read is reported after e has been assigned.
func (e *ExtensionPair) UnmarshalFrom(buf *Buffer) (err error) {
	name := buf.ConsumeString()
	data := buf.ConsumeString()

	*e = ExtensionPair{Name: name, Data: data}

	return buf.Err
}
// UnmarshalBinary decodes the binary encoding of an ExtensionPair from data into e,
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (e *ExtensionPair) UnmarshalBinary(data []byte) error {
	buf := NewBuffer(data)
	return e.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/filexfer.go
================================================
// Package sshfx implements the wire encoding for secsh-filexfer as described in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
package sshfx
// PacketMarshaller narrowly defines packets that will only be transmitted.
//
// ExtendedPacket types will often only implement this interface,
// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
type PacketMarshaller interface {
// MarshalPacket is the primary intended way to encode a packet.
// The request-id for the packet is set from reqid.
//
// An optional buffer may be given in b.
// If the buffer has at least a minimum capacity, it shall be truncated and used to marshal the header into.
// The minimum capacity for the packet must be a constant expression, and should be at least 9.
//
// It shall return the main body of the encoded packet in header,
// and may optionally return an additional payload to be written immediately after the header.
//
// It shall encode in the first 4 bytes of the header the proper length of the rest of the header+payload.
MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error)
}
// Packet defines the behavior of a full generic SFTP packet.
//
// InitPacket and VersionPacket are not generic SFTP packets, and instead implement (Un)MarshalBinary.
//
// ExtendedPacket types should not implement this interface,
// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
type Packet interface {
PacketMarshaller
// Type returns the SSH_FXP_xy value associated with the specific packet.
Type() PacketType
// UnmarshalPacketBody decodes a packet body from the given Buffer.
// It is assumed that the common header values of the length, type and request-id have already been consumed.
//
// Implementations should not alias the given Buffer;
// instead they can consider prepopulating an internal buffer as a hint,
// and copying into that buffer if it has sufficient length.
UnmarshalPacketBody(buf *Buffer) error
}
// ComposePacket converts the three return values of MarshalPacket into the two-value form of MarshalBinary.
//
// It appends payload to header and returns the result together with err unchanged.
// The concatenation is performed even when err is non-nil, and, as with any append,
// the result may share header's backing array.
func ComposePacket(header, payload []byte, err error) ([]byte, error) {
	packet := append(header, payload...)
	return packet, err
}
// Default length values, as defined in draft-ietf-secsh-filexfer-02 section 3.
const (
DefaultMaxPacketLength = 34000
DefaultMaxDataLength = 32768
)
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/fx.go
================================================
package sshfx
import (
"fmt"
)
// Status defines the SFTP error codes used in SSH_FXP_STATUS response packets.
// It implements the error interface through its Error method.
type Status uint32
// Defines the various SSH_FX_* values.
// The values are assigned with iota: StatusOK is 0, and each following name is one greater than the previous.
const (
// see draft-ietf-secsh-filexfer-02
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-7
StatusOK = Status(iota)
StatusEOF
StatusNoSuchFile
StatusPermissionDenied
StatusFailure
StatusBadMessage
StatusNoConnection
StatusConnectionLost
StatusOPUnsupported
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-03.txt#section-7
StatusV4InvalidHandle
StatusV4NoSuchPath
StatusV4FileAlreadyExists
StatusV4WriteProtect
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-04.txt#section-7
StatusV4NoMedia
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-05.txt#section-7
StatusV5NoSpaceOnFilesystem
StatusV5QuotaExceeded
StatusV5UnknownPrincipal
StatusV5LockConflict
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-06.txt#section-8
StatusV6DirNotEmpty
StatusV6NotADirectory
StatusV6InvalidFilename
StatusV6LinkLoop
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-07.txt#section-8
StatusV6CannotDelete
StatusV6InvalidParameter
StatusV6FileIsADirectory
StatusV6ByteRangeLockConflict
StatusV6ByteRangeLockRefused
StatusV6DeletePending
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-08.txt#section-8.1
StatusV6FileCorrupt
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-10.txt#section-9.1
StatusV6OwnerInvalid
StatusV6GroupInvalid
// https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1
StatusV6NoMatchingByteRangeLock
)
// Error implements the error interface; the message is the same text that String returns.
func (s Status) Error() string {
return s.String()
}
// Is reports whether target represents the same status code as s.
//
// A *StatusPacket target matches when its StatusCode equals s;
// any other target matches only when it compares equal to s itself.
func (s Status) Is(target error) bool {
	switch tgt := target.(type) {
	case *StatusPacket:
		return tgt.StatusCode == s
	default:
		return s == tgt
	}
}
// String returns the SSH_FX_* name of s.
// A value without a known name is rendered as "SSH_FX_UNKNOWN(n)", where n is the decimal value of s.
func (s Status) String() string {
	var name string

	switch s {
	case StatusOK:
		name = "SSH_FX_OK"
	case StatusEOF:
		name = "SSH_FX_EOF"
	case StatusNoSuchFile:
		name = "SSH_FX_NO_SUCH_FILE"
	case StatusPermissionDenied:
		name = "SSH_FX_PERMISSION_DENIED"
	case StatusFailure:
		name = "SSH_FX_FAILURE"
	case StatusBadMessage:
		name = "SSH_FX_BAD_MESSAGE"
	case StatusNoConnection:
		name = "SSH_FX_NO_CONNECTION"
	case StatusConnectionLost:
		name = "SSH_FX_CONNECTION_LOST"
	case StatusOPUnsupported:
		name = "SSH_FX_OP_UNSUPPORTED"
	case StatusV4InvalidHandle:
		name = "SSH_FX_INVALID_HANDLE"
	case StatusV4NoSuchPath:
		name = "SSH_FX_NO_SUCH_PATH"
	case StatusV4FileAlreadyExists:
		name = "SSH_FX_FILE_ALREADY_EXISTS"
	case StatusV4WriteProtect:
		name = "SSH_FX_WRITE_PROTECT"
	case StatusV4NoMedia:
		name = "SSH_FX_NO_MEDIA"
	case StatusV5NoSpaceOnFilesystem:
		name = "SSH_FX_NO_SPACE_ON_FILESYSTEM"
	case StatusV5QuotaExceeded:
		name = "SSH_FX_QUOTA_EXCEEDED"
	case StatusV5UnknownPrincipal:
		name = "SSH_FX_UNKNOWN_PRINCIPAL"
	case StatusV5LockConflict:
		name = "SSH_FX_LOCK_CONFLICT"
	case StatusV6DirNotEmpty:
		name = "SSH_FX_DIR_NOT_EMPTY"
	case StatusV6NotADirectory:
		name = "SSH_FX_NOT_A_DIRECTORY"
	case StatusV6InvalidFilename:
		name = "SSH_FX_INVALID_FILENAME"
	case StatusV6LinkLoop:
		name = "SSH_FX_LINK_LOOP"
	case StatusV6CannotDelete:
		name = "SSH_FX_CANNOT_DELETE"
	case StatusV6InvalidParameter:
		name = "SSH_FX_INVALID_PARAMETER"
	case StatusV6FileIsADirectory:
		name = "SSH_FX_FILE_IS_A_DIRECTORY"
	case StatusV6ByteRangeLockConflict:
		name = "SSH_FX_BYTE_RANGE_LOCK_CONFLICT"
	case StatusV6ByteRangeLockRefused:
		name = "SSH_FX_BYTE_RANGE_LOCK_REFUSED"
	case StatusV6DeletePending:
		name = "SSH_FX_DELETE_PENDING"
	case StatusV6FileCorrupt:
		name = "SSH_FX_FILE_CORRUPT"
	case StatusV6OwnerInvalid:
		name = "SSH_FX_OWNER_INVALID"
	case StatusV6GroupInvalid:
		name = "SSH_FX_GROUP_INVALID"
	case StatusV6NoMatchingByteRangeLock:
		name = "SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK"
	default:
		name = fmt.Sprintf("SSH_FX_UNKNOWN(%d)", s)
	}

	return name
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/fxp.go
================================================
package sshfx
import (
"fmt"
)
// PacketType defines the various SFTP packet types (the SSH_FXP_* values) as a single byte.
type PacketType uint8
// Request packet types.
// The values are assigned with iota: PacketTypeInit is 1, and each following name is one greater than the previous.
const (
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3
PacketTypeInit = PacketType(iota + 1)
PacketTypeVersion
PacketTypeOpen
PacketTypeClose
PacketTypeRead
PacketTypeWrite
PacketTypeLStat
PacketTypeFStat
PacketTypeSetstat
PacketTypeFSetstat
PacketTypeOpenDir
PacketTypeReadDir
PacketTypeRemove
PacketTypeMkdir
PacketTypeRmdir
PacketTypeRealPath
PacketTypeStat
PacketTypeRename
PacketTypeReadLink
PacketTypeSymlink
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-07.txt#section-3.3
PacketTypeV6Link
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-08.txt#section-3.3
PacketTypeV6Block
PacketTypeV6Unblock
)
// Response packet types.
// The values are assigned with iota: PacketTypeStatus is 101, and each following name is one greater than the previous.
const (
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3
PacketTypeStatus = PacketType(iota + 101)
PacketTypeHandle
PacketTypeData
PacketTypeName
PacketTypeAttrs
)
// Extended packet types.
// The values are assigned with iota: PacketTypeExtended is 200 and PacketTypeExtendedReply is 201.
const (
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3
PacketTypeExtended = PacketType(iota + 200)
PacketTypeExtendedReply
)
// String returns the SSH_FXP_* name of f.
// A value without a known name is rendered as "SSH_FXP_UNKNOWN(n)", where n is the decimal value of f.
func (f PacketType) String() string {
	var name string

	switch f {
	case PacketTypeInit:
		name = "SSH_FXP_INIT"
	case PacketTypeVersion:
		name = "SSH_FXP_VERSION"
	case PacketTypeOpen:
		name = "SSH_FXP_OPEN"
	case PacketTypeClose:
		name = "SSH_FXP_CLOSE"
	case PacketTypeRead:
		name = "SSH_FXP_READ"
	case PacketTypeWrite:
		name = "SSH_FXP_WRITE"
	case PacketTypeLStat:
		name = "SSH_FXP_LSTAT"
	case PacketTypeFStat:
		name = "SSH_FXP_FSTAT"
	case PacketTypeSetstat:
		name = "SSH_FXP_SETSTAT"
	case PacketTypeFSetstat:
		name = "SSH_FXP_FSETSTAT"
	case PacketTypeOpenDir:
		name = "SSH_FXP_OPENDIR"
	case PacketTypeReadDir:
		name = "SSH_FXP_READDIR"
	case PacketTypeRemove:
		name = "SSH_FXP_REMOVE"
	case PacketTypeMkdir:
		name = "SSH_FXP_MKDIR"
	case PacketTypeRmdir:
		name = "SSH_FXP_RMDIR"
	case PacketTypeRealPath:
		name = "SSH_FXP_REALPATH"
	case PacketTypeStat:
		name = "SSH_FXP_STAT"
	case PacketTypeRename:
		name = "SSH_FXP_RENAME"
	case PacketTypeReadLink:
		name = "SSH_FXP_READLINK"
	case PacketTypeSymlink:
		name = "SSH_FXP_SYMLINK"
	case PacketTypeV6Link:
		name = "SSH_FXP_LINK"
	case PacketTypeV6Block:
		name = "SSH_FXP_BLOCK"
	case PacketTypeV6Unblock:
		name = "SSH_FXP_UNBLOCK"
	case PacketTypeStatus:
		name = "SSH_FXP_STATUS"
	case PacketTypeHandle:
		name = "SSH_FXP_HANDLE"
	case PacketTypeData:
		name = "SSH_FXP_DATA"
	case PacketTypeName:
		name = "SSH_FXP_NAME"
	case PacketTypeAttrs:
		name = "SSH_FXP_ATTRS"
	case PacketTypeExtended:
		name = "SSH_FXP_EXTENDED"
	case PacketTypeExtendedReply:
		name = "SSH_FXP_EXTENDED_REPLY"
	default:
		name = fmt.Sprintf("SSH_FXP_UNKNOWN(%d)", f)
	}

	return name
}
// newPacketFromType returns a new, zero-valued request packet for the given packet type.
//
// Only the request types listed in the switch are supported;
// any other type yields a nil Packet and an "unexpected request packet type" error.
func newPacketFromType(typ PacketType) (Packet, error) {
	var pkt Packet

	switch typ {
	case PacketTypeOpen:
		pkt = new(OpenPacket)
	case PacketTypeClose:
		pkt = new(ClosePacket)
	case PacketTypeRead:
		pkt = new(ReadPacket)
	case PacketTypeWrite:
		pkt = new(WritePacket)
	case PacketTypeLStat:
		pkt = new(LStatPacket)
	case PacketTypeFStat:
		pkt = new(FStatPacket)
	case PacketTypeSetstat:
		pkt = new(SetstatPacket)
	case PacketTypeFSetstat:
		pkt = new(FSetstatPacket)
	case PacketTypeOpenDir:
		pkt = new(OpenDirPacket)
	case PacketTypeReadDir:
		pkt = new(ReadDirPacket)
	case PacketTypeRemove:
		pkt = new(RemovePacket)
	case PacketTypeMkdir:
		pkt = new(MkdirPacket)
	case PacketTypeRmdir:
		pkt = new(RmdirPacket)
	case PacketTypeRealPath:
		pkt = new(RealPathPacket)
	case PacketTypeStat:
		pkt = new(StatPacket)
	case PacketTypeRename:
		pkt = new(RenamePacket)
	case PacketTypeReadLink:
		pkt = new(ReadLinkPacket)
	case PacketTypeSymlink:
		pkt = new(SymlinkPacket)
	case PacketTypeExtended:
		pkt = new(ExtendedPacket)
	default:
		return nil, fmt.Errorf("unexpected request packet type: %v", typ)
	}

	return pkt, nil
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/handle_packets.go
================================================
package sshfx
// ClosePacket defines the SSH_FXP_CLOSE packet.
// Its body is a single length-prefixed string, the handle.
type ClosePacket struct {
Handle string
}
// Type returns PacketTypeClose, the SSH_FXP_xy value associated with this packet type.
func (p *ClosePacket) Type() PacketType {
return PacketTypeClose
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then the handle) is returned in header; payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the handle is used instead.
func (p *ClosePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(handle)
		buf = NewMarshalBuffer(4 + len(p.Handle))
	}

	buf.StartPacket(PacketTypeClose, reqid)
	buf.AppendString(p.Handle)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the handle read from buf, and buf.Err is returned.
func (p *ClosePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()

	*p = ClosePacket{Handle: handle}

	return buf.Err
}
// ReadPacket defines the SSH_FXP_READ packet.
// Its body is the handle string, followed by a uint64 offset and a uint32 length.
type ReadPacket struct {
Handle string
Offset uint64
Length uint32
}
// Type returns PacketTypeRead, the SSH_FXP_xy value associated with this packet type.
func (p *ReadPacket) Type() PacketType {
return PacketTypeRead
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then handle, offset and length) is returned in header;
// payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the body is used instead.
func (p *ReadPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// uint32(handle length) + uint64(offset) + uint32(len)
		const fixed = 4 + 8 + 4
		buf = NewMarshalBuffer(fixed + len(p.Handle))
	}

	buf.StartPacket(PacketTypeRead, reqid)
	buf.AppendString(p.Handle)
	buf.AppendUint64(p.Offset)
	buf.AppendUint32(p.Length)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the handle, offset and length read from buf (in that order),
// and buf.Err is returned.
func (p *ReadPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()
	offset := buf.ConsumeUint64()
	length := buf.ConsumeUint32()

	*p = ReadPacket{Handle: handle, Offset: offset, Length: length}

	return buf.Err
}
// WritePacket defines the SSH_FXP_WRITE packet.
// Its body is the handle string, followed by a uint64 offset and the length-prefixed data.
type WritePacket struct {
Handle string
Offset uint64
Data []byte
}
// Type returns PacketTypeWrite, the SSH_FXP_xy value associated with this packet type.
func (p *WritePacket) Type() PacketType {
return PacketTypeWrite
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The header holds the packet start for reqid, the handle, the offset, and the uint32 length of p.Data.
// The data bytes themselves are not copied: p.Data is returned as the payload.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the header fields is used instead.
func (p *WritePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// uint32(handle length) + uint64(offset) + uint32(len(data)); data content in payload
		const fixed = 4 + 8 + 4
		buf = NewMarshalBuffer(fixed + len(p.Handle))
	}

	buf.StartPacket(PacketTypeWrite, reqid)
	buf.AppendString(p.Handle)
	buf.AppendUint64(p.Offset)

	dataLen := uint32(len(p.Data))
	buf.AppendUint32(dataLen)

	return buf.Packet(p.Data)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// The existing p.Data is passed to ConsumeByteSliceCopy as the destination hint,
// so the data is copied either into that slice or into a newly grown one.
//
// This means the resulting p.Data _does not_ alias any of the data buffer that is passed in.
func (p *WritePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()
	offset := buf.ConsumeUint64()
	// p.Data is read here, before p is overwritten below.
	data := buf.ConsumeByteSliceCopy(p.Data)

	*p = WritePacket{Handle: handle, Offset: offset, Data: data}

	return buf.Err
}
// FStatPacket defines the SSH_FXP_FSTAT packet.
// Its body is a single length-prefixed string, the handle.
type FStatPacket struct {
Handle string
}
// Type returns PacketTypeFStat, the SSH_FXP_xy value associated with this packet type.
func (p *FStatPacket) Type() PacketType {
return PacketTypeFStat
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then the handle) is returned in header; payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the handle is used instead.
func (p *FStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(handle)
		buf = NewMarshalBuffer(4 + len(p.Handle))
	}

	buf.StartPacket(PacketTypeFStat, reqid)
	buf.AppendString(p.Handle)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the handle read from buf, and buf.Err is returned.
func (p *FStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()

	*p = FStatPacket{Handle: handle}

	return buf.Err
}
// FSetstatPacket defines the SSH_FXP_FSETSTAT packet.
// Its body is the handle string followed by the marshaled attributes.
type FSetstatPacket struct {
Handle string
Attrs Attributes
}
// Type returns PacketTypeFSetstat, the SSH_FXP_xy value associated with this packet type.
func (p *FSetstatPacket) Type() PacketType {
return PacketTypeFSetstat
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then the handle and the attributes) is returned in header;
// payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the handle and attributes is used instead.
func (p *FSetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(handle) + ATTRS(attrs)
		buf = NewMarshalBuffer(4 + len(p.Handle) + p.Attrs.Len())
	}

	buf.StartPacket(PacketTypeFSetstat, reqid)
	buf.AppendString(p.Handle)
	p.Attrs.MarshalInto(buf)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the handle read from buf (resetting Attrs to its zero value),
// then the attributes are unmarshaled from buf and that call's error is returned.
func (p *FSetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()

	*p = FSetstatPacket{Handle: handle}

	return p.Attrs.UnmarshalFrom(buf)
}
// ReadDirPacket defines the SSH_FXP_READDIR packet.
// Its body is a single length-prefixed string, the handle.
type ReadDirPacket struct {
Handle string
}
// Type returns PacketTypeReadDir, the SSH_FXP_xy value associated with this packet type.
func (p *ReadDirPacket) Type() PacketType {
return PacketTypeReadDir
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then the handle) is returned in header; payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the handle is used instead.
func (p *ReadDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(handle)
		buf = NewMarshalBuffer(4 + len(p.Handle))
	}

	buf.StartPacket(PacketTypeReadDir, reqid)
	buf.AppendString(p.Handle)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the handle read from buf, and buf.Err is returned.
func (p *ReadDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	handle := buf.ConsumeString()

	*p = ReadDirPacket{Handle: handle}

	return buf.Err
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/init_packets.go
================================================
package sshfx
// InitPacket defines the SSH_FXP_INIT packet.
// It carries a uint32 version followed by zero or more extension pairs,
// and has no request-id, so it implements (Un)MarshalBinary rather than the Packet interface.
type InitPacket struct {
Version uint32
Extensions []*ExtensionPair
}
// MarshalBinary returns the binary encoding of p:
// a four-byte length, the packet type byte, the version, and then every extension pair in order.
// The length field covers everything after itself. The returned error is always nil.
func (p *InitPacket) MarshalBinary() ([]byte, error) {
	// byte(type) + uint32(version), plus the encoded size of every extension.
	size := 1 + 4
	for i := range p.Extensions {
		size += p.Extensions[i].Len()
	}

	// Reserve the four length bytes up front; they are filled in by PutLength below.
	backing := make([]byte, 4, 4+size)
	b := NewBuffer(backing)

	b.AppendUint8(uint8(PacketTypeInit))
	b.AppendUint32(p.Version)
	for i := range p.Extensions {
		p.Extensions[i].MarshalInto(b)
	}

	b.PutLength(size)

	return b.Bytes(), nil
}
// UnmarshalBinary unmarshals a full raw packet out of the given data.
// It is assumed that the uint32(length) has already been consumed to receive the data.
// It is also assumed that the uint8(type) has already been consumed to decide which packet to unmarshal into.
//
// p is overwritten with the version, then extension pairs are read until data is exhausted.
// The first extension that fails to unmarshal aborts with its error;
// otherwise the Buffer's recorded error (if any) is returned.
func (p *InitPacket) UnmarshalBinary(data []byte) (err error) {
	buf := NewBuffer(data)

	version := buf.ConsumeUint32()
	*p = InitPacket{Version: version}

	for buf.Len() > 0 {
		ext := new(ExtensionPair)
		if err := ext.UnmarshalFrom(buf); err != nil {
			return err
		}

		p.Extensions = append(p.Extensions, ext)
	}

	return buf.Err
}
// VersionPacket defines the SSH_FXP_VERSION packet.
// It carries a uint32 version followed by zero or more extension pairs,
// and has no request-id, so it implements (Un)MarshalBinary rather than the Packet interface.
type VersionPacket struct {
Version uint32
Extensions []*ExtensionPair
}
// MarshalBinary returns the binary encoding of p:
// a four-byte length, the packet type byte, the version, and then every extension pair in order.
// The length field covers everything after itself. The returned error is always nil.
func (p *VersionPacket) MarshalBinary() ([]byte, error) {
	// byte(type) + uint32(version), plus the encoded size of every extension.
	size := 1 + 4
	for i := range p.Extensions {
		size += p.Extensions[i].Len()
	}

	// Reserve the four length bytes up front; they are filled in by PutLength below.
	backing := make([]byte, 4, 4+size)
	b := NewBuffer(backing)

	b.AppendUint8(uint8(PacketTypeVersion))
	b.AppendUint32(p.Version)
	for i := range p.Extensions {
		p.Extensions[i].MarshalInto(b)
	}

	b.PutLength(size)

	return b.Bytes(), nil
}
// UnmarshalBinary unmarshals a full raw packet out of the given data.
// It is assumed that the uint32(length) has already been consumed to receive the data.
// It is also assumed that the uint8(type) has already been consumed to decide which packet to unmarshal into.
//
// p is overwritten with the version, then extension pairs are read until data is exhausted.
// It returns an error if data is too short to hold the uint32(version),
// or if any extension pair fails to unmarshal.
func (p *VersionPacket) UnmarshalBinary(data []byte) (err error) {
	buf := NewBuffer(data)

	*p = VersionPacket{
		Version: buf.ConsumeUint32(),
	}

	for buf.Len() > 0 {
		var ext ExtensionPair
		if err := ext.UnmarshalFrom(buf); err != nil {
			return err
		}

		p.Extensions = append(p.Extensions, &ext)
	}

	// A failed version read records ErrShortPacket and moves the read offset to the end of the data,
	// so the loop above never runs. Report that error instead of silently accepting
	// a truncated packet as version 0; this matches InitPacket.UnmarshalBinary.
	return buf.Err
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/open_packets.go
================================================
package sshfx
// SSH_FXF_* flags.
// Each constant is a distinct single bit: FlagRead is 1, and each following flag is double the previous.
const (
FlagRead = 1 << iota // SSH_FXF_READ
FlagWrite // SSH_FXF_WRITE
FlagAppend // SSH_FXF_APPEND
FlagCreate // SSH_FXF_CREAT
FlagTruncate // SSH_FXF_TRUNC
FlagExclusive // SSH_FXF_EXCL
)
// OpenPacket defines the SSH_FXP_OPEN packet.
// Its body is the filename string, followed by the uint32 pflags and the marshaled attributes.
type OpenPacket struct {
Filename string
PFlags uint32
Attrs Attributes
}
// Type returns PacketTypeOpen, the SSH_FXP_xy value associated with this packet type.
func (p *OpenPacket) Type() PacketType {
return PacketTypeOpen
}
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The whole packet (packet start for reqid, then filename, pflags and attributes) is returned in header;
// payload is nil.
// b is wrapped in a Buffer for the header; when that Buffer reports a capacity below 9,
// a new marshal buffer sized for the body is used instead.
func (p *OpenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	buf := NewBuffer(b)
	if buf.Cap() < 9 {
		// string(filename) + uint32(pflags) + ATTRS(attrs)
		buf = NewMarshalBuffer(4 + len(p.Filename) + 4 + p.Attrs.Len())
	}

	buf.StartPacket(PacketTypeOpen, reqid)
	buf.AppendString(p.Filename)
	buf.AppendUint32(p.PFlags)
	p.Attrs.MarshalInto(buf)

	return buf.Packet(nil)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// p is overwritten with the filename and pflags read from buf (resetting Attrs to its zero value),
// then the attributes are unmarshaled from buf and that call's error is returned.
func (p *OpenPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	filename := buf.ConsumeString()
	pflags := buf.ConsumeUint32()

	*p = OpenPacket{Filename: filename, PFlags: pflags}

	return p.Attrs.UnmarshalFrom(buf)
}
// OpenDirPacket defines the SSH_FXP_OPENDIR packet.
// Its body is a single length-prefixed string, the path.
type OpenDirPacket struct {
Path string
}
// Type returns PacketTypeOpenDir, the SSH_FXP_xy value associated with this packet type.
func (p *OpenDirPacket) Type() PacketType {
return PacketTypeOpenDir
}
// MarshalPacket encodes p as an SSH_FXP_OPENDIR packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the path.
func (p *OpenDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path)
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeOpenDir, reqid)
	out.AppendString(p.Path)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_OPENDIR packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *OpenDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = OpenDirPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/openssh/fsync.go
================================================
package openssh
import (
sshfx "gossh/sftp/internal/encoding/ssh/filexfer"
)
// extensionFSync is the extension name used for the fsync@openssh.com request.
const extensionFSync = "fsync@openssh.com"
// RegisterExtensionFSync registers the "fsync@openssh.com" extended packet
// with the encoding/ssh/filexfer package, using a constructor that yields
// a new FSyncExtendedPacket on each call.
func RegisterExtensionFSync() {
	newPacket := func() sshfx.ExtendedData {
		return &FSyncExtendedPacket{}
	}
	sshfx.RegisterExtendedPacketType(extensionFSync, newPacket)
}
// ExtensionFSync returns a new ExtensionPair advertising "fsync@openssh.com" with data "1",
// suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
func ExtensionFSync() *sshfx.ExtensionPair {
	pair := new(sshfx.ExtensionPair)
	pair.Name = extensionFSync
	pair.Data = "1"
	return pair
}
// FSyncExtendedPacket defines the fsync@openssh.com extended packet.
type FSyncExtendedPacket struct {
Handle string // handle the request applies to; encoded as a single string
}
// Type returns the SSH_FXP_EXTENDED packet type (sshfx.PacketTypeExtended).
func (ep *FSyncExtendedPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtended
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
//
// ep is wrapped in an sshfx.ExtendedPacket named "fsync@openssh.com",
// and the wrapper performs the encoding.
func (ep *FSyncExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedPacket{
		ExtendedRequest: extensionFSync,
		Data:            ep,
	}
	return wrapper.MarshalPacket(reqid, b)
}
// MarshalInto encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data,
// appending string(handle) to buf.
func (ep *FSyncExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendString(ep.Handle)
}
// MarshalBinary encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
func (ep *FSyncExtendedPacket) MarshalBinary() ([]byte, error) {
	// string(handle): uint32 length prefix + handle bytes
	backing := make([]byte, 0, 4+len(ep.Handle))

	out := sshfx.NewBuffer(backing)
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the fsync@openssh.com extended packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *FSyncExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	handle := buf.ConsumeString()
	*ep = FSyncExtendedPacket{Handle: handle}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// UnmarshalBinary decodes the fsync@openssh.com extended packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *FSyncExtendedPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/openssh/hardlink.go
================================================
package openssh
import (
sshfx "gossh/sftp/internal/encoding/ssh/filexfer"
)
// extensionHardlink is the extension name used for the hardlink@openssh.com request.
const extensionHardlink = "hardlink@openssh.com"
// RegisterExtensionHardlink registers the "hardlink@openssh.com" extended packet
// with the encoding/ssh/filexfer package, using a constructor that yields
// a new HardlinkExtendedPacket on each call.
func RegisterExtensionHardlink() {
	newPacket := func() sshfx.ExtendedData {
		return &HardlinkExtendedPacket{}
	}
	sshfx.RegisterExtendedPacketType(extensionHardlink, newPacket)
}
// ExtensionHardlink returns a new ExtensionPair advertising "hardlink@openssh.com" with data "1",
// suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
func ExtensionHardlink() *sshfx.ExtensionPair {
	pair := new(sshfx.ExtensionPair)
	pair.Name = extensionHardlink
	pair.Data = "1"
	return pair
}
// HardlinkExtendedPacket defines the hardlink@openssh.com extended packet.
//
// The fields are listed in their wire order: string(oldpath), string(newpath).
type HardlinkExtendedPacket struct {
OldPath string // existing path
NewPath string // path of the link to create
}
// Type returns the SSH_FXP_EXTENDED packet type (sshfx.PacketTypeExtended).
func (ep *HardlinkExtendedPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtended
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
//
// ep is wrapped in an sshfx.ExtendedPacket named "hardlink@openssh.com",
// and the wrapper performs the encoding.
func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedPacket{
		ExtendedRequest: extensionHardlink,
		Data:            ep,
	}
	return wrapper.MarshalPacket(reqid, b)
}
// MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data,
// appending string(oldpath) followed by string(newpath) to buf.
func (ep *HardlinkExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendString(ep.OldPath)
buf.AppendString(ep.NewPath)
}
// MarshalBinary encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
func (ep *HardlinkExtendedPacket) MarshalBinary() ([]byte, error) {
	// string(oldpath) + string(newpath), each with a uint32 length prefix
	backing := make([]byte, 0, 4+len(ep.OldPath)+4+len(ep.NewPath))

	out := sshfx.NewBuffer(backing)
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the hardlink@openssh.com extended packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *HardlinkExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	// Wire order: old path first, then new path.
	oldPath := buf.ConsumeString()
	newPath := buf.ConsumeString()

	*ep = HardlinkExtendedPacket{OldPath: oldPath, NewPath: newPath}

	return buf.Err
}
// UnmarshalBinary decodes the hardlink@openssh.com extended packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *HardlinkExtendedPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/openssh/openssh.go
================================================
// Package openssh implements the openssh secsh-filexfer extensions as described in https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
package openssh
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/openssh/posix-rename.go
================================================
package openssh
import (
sshfx "gossh/sftp/internal/encoding/ssh/filexfer"
)
// extensionPOSIXRename is the extension name used for the posix-rename@openssh.com request.
const extensionPOSIXRename = "posix-rename@openssh.com"
// RegisterExtensionPOSIXRename registers the "posix-rename@openssh.com" extended packet
// with the encoding/ssh/filexfer package, using a constructor that yields
// a new POSIXRenameExtendedPacket on each call.
func RegisterExtensionPOSIXRename() {
	newPacket := func() sshfx.ExtendedData {
		return &POSIXRenameExtendedPacket{}
	}
	sshfx.RegisterExtendedPacketType(extensionPOSIXRename, newPacket)
}
// ExtensionPOSIXRename returns a new ExtensionPair advertising "posix-rename@openssh.com" with data "1",
// suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
func ExtensionPOSIXRename() *sshfx.ExtensionPair {
	pair := new(sshfx.ExtensionPair)
	pair.Name = extensionPOSIXRename
	pair.Data = "1"
	return pair
}
// POSIXRenameExtendedPacket defines the posix-rename@openssh.com extended packet.
//
// The fields are listed in their wire order: string(oldpath), string(newpath).
type POSIXRenameExtendedPacket struct {
OldPath string // path being renamed
NewPath string // path to rename to
}
// Type returns the SSH_FXP_EXTENDED packet type (sshfx.PacketTypeExtended).
func (ep *POSIXRenameExtendedPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtended
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
//
// ep is wrapped in an sshfx.ExtendedPacket named "posix-rename@openssh.com",
// and the wrapper performs the encoding.
func (ep *POSIXRenameExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedPacket{
		ExtendedRequest: extensionPOSIXRename,
		Data:            ep,
	}
	return wrapper.MarshalPacket(reqid, b)
}
// MarshalInto encodes ep into the binary encoding of the posix-rename@openssh.com extended packet-specific data,
// appending string(oldpath) followed by string(newpath) to buf.
func (ep *POSIXRenameExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendString(ep.OldPath)
buf.AppendString(ep.NewPath)
}
// MarshalBinary encodes ep into the binary encoding of the posix-rename@openssh.com extended packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
func (ep *POSIXRenameExtendedPacket) MarshalBinary() ([]byte, error) {
	// string(oldpath) + string(newpath), each with a uint32 length prefix
	backing := make([]byte, 0, 4+len(ep.OldPath)+4+len(ep.NewPath))

	out := sshfx.NewBuffer(backing)
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the posix-rename@openssh.com extended packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *POSIXRenameExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	// Wire order: old path first, then new path.
	oldPath := buf.ConsumeString()
	newPath := buf.ConsumeString()

	*ep = POSIXRenameExtendedPacket{OldPath: oldPath, NewPath: newPath}

	return buf.Err
}
// UnmarshalBinary decodes the posix-rename@openssh.com extended packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *POSIXRenameExtendedPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/openssh/statvfs.go
================================================
package openssh
import (
sshfx "gossh/sftp/internal/encoding/ssh/filexfer"
)
// extensionStatVFS is the extension name used for the statvfs@openssh.com request.
const extensionStatVFS = "statvfs@openssh.com"
// RegisterExtensionStatVFS registers the "statvfs@openssh.com" extended packet
// with the encoding/ssh/filexfer package, using a constructor that yields
// a new StatVFSExtendedPacket on each call.
func RegisterExtensionStatVFS() {
	newPacket := func() sshfx.ExtendedData {
		return &StatVFSExtendedPacket{}
	}
	sshfx.RegisterExtendedPacketType(extensionStatVFS, newPacket)
}
// ExtensionStatVFS returns a new ExtensionPair advertising "statvfs@openssh.com" with data "2",
// suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
func ExtensionStatVFS() *sshfx.ExtensionPair {
	pair := new(sshfx.ExtensionPair)
	pair.Name = extensionStatVFS
	pair.Data = "2"
	return pair
}
// StatVFSExtendedPacket defines the statvfs@openssh.com extended packet.
type StatVFSExtendedPacket struct {
Path string // path the request applies to; encoded as a single string
}
// Type returns the SSH_FXP_EXTENDED packet type (sshfx.PacketTypeExtended).
func (ep *StatVFSExtendedPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtended
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
//
// ep is wrapped in an sshfx.ExtendedPacket named "statvfs@openssh.com",
// and the wrapper performs the encoding.
func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedPacket{
		ExtendedRequest: extensionStatVFS,
		Data:            ep,
	}
	return wrapper.MarshalPacket(reqid, b)
}
// MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data,
// appending string(path) to buf.
func (ep *StatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendString(ep.Path)
}
// MarshalBinary encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
func (ep *StatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
	// string(path): uint32 length prefix + path bytes
	backing := make([]byte, 0, 4+len(ep.Path))

	out := sshfx.NewBuffer(backing)
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the statvfs@openssh.com extended packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *StatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	path := buf.ConsumeString()
	*ep = StatVFSExtendedPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// UnmarshalBinary decodes the statvfs@openssh.com extended packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *StatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
// extensionFStatVFS is the extension name used for the fstatvfs@openssh.com request.
const extensionFStatVFS = "fstatvfs@openssh.com"
// RegisterExtensionFStatVFS registers the "fstatvfs@openssh.com" extended packet
// with the encoding/ssh/filexfer package, using a constructor that yields
// a new FStatVFSExtendedPacket on each call.
func RegisterExtensionFStatVFS() {
	newPacket := func() sshfx.ExtendedData {
		return &FStatVFSExtendedPacket{}
	}
	sshfx.RegisterExtendedPacketType(extensionFStatVFS, newPacket)
}
// ExtensionFStatVFS returns a new ExtensionPair advertising "fstatvfs@openssh.com" with data "2",
// suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
func ExtensionFStatVFS() *sshfx.ExtensionPair {
	pair := new(sshfx.ExtensionPair)
	pair.Name = extensionFStatVFS
	pair.Data = "2"
	return pair
}
// FStatVFSExtendedPacket defines the fstatvfs@openssh.com extended packet.
type FStatVFSExtendedPacket struct {
// Path is the single string carried in the packet body.
// NOTE(review): fstatvfs@openssh.com presumably operates on an open handle
// rather than a path name; confirm what callers store here.
Path string
}
// Type returns the SSH_FXP_EXTENDED packet type (sshfx.PacketTypeExtended).
func (ep *FStatVFSExtendedPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtended
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
//
// ep is wrapped in an sshfx.ExtendedPacket named "fstatvfs@openssh.com",
// and the wrapper performs the encoding.
func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedPacket{
		ExtendedRequest: extensionFStatVFS,
		Data:            ep,
	}
	return wrapper.MarshalPacket(reqid, b)
}
// MarshalInto encodes ep into the binary encoding of the fstatvfs@openssh.com extended packet-specific data,
// appending the Path field to buf as a string.
func (ep *FStatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendString(ep.Path)
}
// MarshalBinary encodes ep into the binary encoding of the fstatvfs@openssh.com extended packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
func (ep *FStatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
	// string(path): uint32 length prefix + path bytes
	backing := make([]byte, 0, 4+len(ep.Path))

	out := sshfx.NewBuffer(backing)
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the fstatvfs@openssh.com extended packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *FStatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	path := buf.ConsumeString()
	*ep = FStatVFSExtendedPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// UnmarshalBinary decodes the fstatvfs@openssh.com extended packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *FStatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
// Bit values that may be set in StatVFSExtendedReplyPacket.MountFlags.
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
const (
	MountFlagsReadOnly = 1 << 0 // SSH_FXE_STATVFS_ST_RDONLY
	MountFlagsNoSUID   = 1 << 1 // SSH_FXE_STATVFS_ST_NOSUID
)
// StatVFSExtendedReplyPacket defines the extended reply packet for statvfs@openssh.com and fstatvfs@openssh.com requests.
//
// The fields are listed in their wire order; each one is encoded as a uint64.
type StatVFSExtendedReplyPacket struct {
BlockSize uint64 /* f_bsize: file system block size */
FragmentSize uint64 /* f_frsize: fundamental file system block size / fragment size */
Blocks uint64 /* f_blocks: number of blocks (unit f_frsize) */
BlocksFree uint64 /* f_bfree: free blocks in file system */
BlocksAvail uint64 /* f_bavail: free blocks for non-root */
Files uint64 /* f_files: total file inodes */
FilesFree uint64 /* f_ffree: free file inodes */
FilesAvail uint64 /* f_favail: free file inodes for non-root */
FilesystemID uint64 /* f_fsid: file system id */
MountFlags uint64 /* f_flag: bit mask of mount flag values (MountFlags* constants) */
MaxNameLength uint64 /* f_namemax: maximum filename length */
}
// Type returns the SSH_FXP_EXTENDED_REPLY packet type (sshfx.PacketTypeExtendedReply).
func (ep *StatVFSExtendedReplyPacket) Type() sshfx.PacketType {
return sshfx.PacketTypeExtendedReply
}
// MarshalPacket returns ep as a two-part binary encoding of the full extended reply packet.
//
// ep is wrapped as the Data of an sshfx.ExtendedReplyPacket, and the wrapper performs the encoding.
func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	wrapper := sshfx.ExtendedReplyPacket{Data: ep}
	return wrapper.MarshalPacket(reqid, b)
}
// UnmarshalPacketBody decodes the body of a full extended reply packet from buf into ep.
//
// ep is installed as the Data of an sshfx.ExtendedReplyPacket, and the wrapper performs the decoding.
func (ep *StatVFSExtendedReplyPacket) UnmarshalPacketBody(buf *sshfx.Buffer) (err error) {
	wrapper := sshfx.ExtendedReplyPacket{Data: ep}
	return wrapper.UnmarshalPacketBody(buf)
}
// MarshalInto encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data:
// eleven uint64 values appended to buf in struct field order.
func (ep *StatVFSExtendedReplyPacket) MarshalInto(buf *sshfx.Buffer) {
	fields := [...]uint64{
		ep.BlockSize,
		ep.FragmentSize,
		ep.Blocks,
		ep.BlocksFree,
		ep.BlocksAvail,
		ep.Files,
		ep.FilesFree,
		ep.FilesAvail,
		ep.FilesystemID,
		ep.MountFlags,
		ep.MaxNameLength,
	}

	for _, v := range fields {
		buf.AppendUint64(v)
	}
}
// MarshalBinary encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data.
//
// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended reply packet.
func (ep *StatVFSExtendedReplyPacket) MarshalBinary() ([]byte, error) {
	const (
		fieldCount = 11 // number of uint64 fields in the reply
		fieldSize  = 8  // bytes per uint64
	)

	out := sshfx.NewBuffer(make([]byte, 0, fieldCount*fieldSize))
	ep.MarshalInto(out)

	return out.Bytes(), nil
}
// UnmarshalFrom decodes the (f)statvfs@openssh.com extended reply packet-specific data from buf,
// replacing the previous contents of ep.
func (ep *StatVFSExtendedReplyPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
	var decoded StatVFSExtendedReplyPacket

	// Destinations in wire order; each one receives the next uint64 from buf.
	targets := [...]*uint64{
		&decoded.BlockSize,
		&decoded.FragmentSize,
		&decoded.Blocks,
		&decoded.BlocksFree,
		&decoded.BlocksAvail,
		&decoded.Files,
		&decoded.FilesFree,
		&decoded.FilesAvail,
		&decoded.FilesystemID,
		&decoded.MountFlags,
		&decoded.MaxNameLength,
	}
	for _, dst := range targets {
		*dst = buf.ConsumeUint64()
	}

	*ep = decoded

	return buf.Err
}
// UnmarshalBinary decodes the (f)statvfs@openssh.com extended reply packet-specific data into ep
// by wrapping data in a Buffer and delegating to UnmarshalFrom.
func (ep *StatVFSExtendedReplyPacket) UnmarshalBinary(data []byte) (err error) {
	buf := sshfx.NewBuffer(data)
	return ep.UnmarshalFrom(buf)
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/packets.go
================================================
package sshfx
import (
"errors"
"io"
)
// smallBufferSize is the size, in bytes, of the buffer readPacket allocates
// when the caller-supplied slice has capacity for fewer than 4 bytes.
const smallBufferSize = 64
// RawPacket implements the general packet format from draft-ietf-secsh-filexfer-02
//
// RawPacket is intended for use in clients receiving responses,
// where a response will be expected to be of a limited number of types,
// and unmarshaling unknown/unexpected response packets is unnecessary.
//
// For servers expecting to receive arbitrary request packet types,
// use RequestPacket.
//
// Defined in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3
type RawPacket struct {
PacketType PacketType // the SSH_FXP_xy type of the packet
RequestID uint32 // the request-id decoded from the packet
Data Buffer // the undecoded remainder of the packet, after type and request-id
}
// Type returns the PacketType field, which defines the SSH_FXP_xy type for this packet.
func (p *RawPacket) Type() PacketType {
return p.PacketType
}
// Reset clears the pointers and reference-semantic variables of RawPacket,
// releasing underlying resources, and making them and the RawPacket suitable to be reused,
// so long as no other references have been kept.
//
// Only the Data field is cleared; PacketType and RequestID are left as they are.
func (p *RawPacket) Reset() {
p.Data = Buffer{}
}
// MarshalPacket returns p as a two-part binary encoding of p:
// a header started with p.PacketType and reqid, and the bytes of p.Data as the payload argument.
//
// The internal p.RequestID is overridden by the reqid argument.
func (p *RawPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// b cannot hold 9 bytes; request a marshal buffer with no extra body space,
		// since the body is handed over separately below.
		out = NewMarshalBuffer(0)
	}

	out.StartPacket(p.PacketType, reqid)

	body := p.Data.Bytes()
	return out.Packet(body)
}
// MarshalBinary returns p as the binary encoding of p,
// marshalled with its own RequestID and no caller-supplied buffer.
//
// This is a convenience implementation primarily intended for tests,
// because it is inefficient with allocations.
func (p *RawPacket) MarshalBinary() ([]byte, error) {
	header, payload, err := p.MarshalPacket(p.RequestID, nil)
	return ComposePacket(header, payload, err)
}
// UnmarshalFrom decodes a RawPacket from the given Buffer into p:
// the uint8(type) and uint32(request-id) are consumed, and what is left of buf becomes Data.
//
// The Data field will alias the passed in Buffer,
// so the buffer passed in should not be reused before RawPacket.Reset().
func (p *RawPacket) UnmarshalFrom(buf *Buffer) error {
	typ := PacketType(buf.ConsumeUint8())
	reqid := buf.ConsumeUint32()

	*p = RawPacket{PacketType: typ, RequestID: reqid}

	// Copy the Buffer value only after the header fields have been consumed.
	p.Data = *buf

	return buf.Err
}
// UnmarshalBinary decodes a full raw packet out of the given data.
// It is assumed that the uint32(length) has already been consumed to receive the data.
//
// This is a convenience implementation primarily intended for tests,
// because this must clone the given data byte slice,
// as Data is not allowed to alias any part of the data byte slice.
func (p *RawPacket) UnmarshalBinary(data []byte) error {
	// Private copy with length and capacity both equal to len(data).
	clone := append(make([]byte, 0, len(data)), data...)

	return p.UnmarshalFrom(NewBuffer(clone))
}
// readPacket reads a uint32 length-prefixed binary data packet from r,
// using the given byte slice as a backing array.
//
// On success, the returned slice holds the packet body that follows the length prefix,
// i.e. the uint8(type), the uint32(request-id), and the rest of the packet.
//
// If the packet length read from r is less than 5, then ErrShortPacket is returned.
//
// If the packet length read from r is bigger than maxPacketLength,
// or greater than math.MaxInt32 on a 32-bit implementation,
// then ErrLongPacket is returned.
//
// If the given byte slice is insufficient to hold the packet,
// then a new slice of exactly the packet size is allocated instead.
//
// Errors from reading r are returned as-is.
func readPacket(r io.Reader, b []byte, maxPacketLength uint32) ([]byte, error) {
	if cap(b) < 4 {
		// We will need allocate our own buffer just for reading the packet length.

		// However, we don’t really want to allocate an extremely narrow buffer (4-bytes),
		// and cause unnecessary allocation churn from both length reads and small packet reads,
		// so we use smallBufferSize from the bytes package as a reasonable guess.

		// But if callers really do want to force narrow throw-away allocation of every packet body,
		// they can do so with a buffer of capacity 4.
		b = make([]byte, smallBufferSize)
	}

	// Only cap(b) >= 4 is guaranteed here; len(b) may still be smaller (even zero).
	// Reslice once and use the same four bytes for both the read and the decode,
	// so that the decode never depends on len(b).
	lengthBytes := b[:4]

	if _, err := io.ReadFull(r, lengthBytes); err != nil {
		return nil, err
	}

	length := unmarshalUint32(lengthBytes)
	if int(length) < 5 {
		// Must have at least uint8(type) and uint32(request-id)

		if int(length) < 0 {
			// Only possible when strconv.IntSize == 32,
			// the packet length is longer than math.MaxInt32,
			// and thus longer than any possible slice.
			return nil, ErrLongPacket
		}

		return nil, ErrShortPacket
	}
	if length > maxPacketLength {
		return nil, ErrLongPacket
	}

	if int(length) > cap(b) {
		// We know int(length) must be positive, because of tests above.
		b = make([]byte, length)
	}

	n, err := io.ReadFull(r, b[:length])
	return b[:n], err
}
// ReadFrom provides a simple functional packet reader,
// using the given byte slice as a backing array.
//
// To protect against potential denial of service attacks,
// if the read packet length is longer than maxPacketLength,
// then no packet data will be read, and ErrLongPacket will be returned.
// (On 32-bit int architectures, all packets >= 2^31 in length
// will return ErrLongPacket regardless of maxPacketLength.)
//
// If the read packet length is longer than cap(b),
// then a throw-away slice will be allocated to meet the exact packet length.
// This can be used to limit the length of reused buffers,
// while still allowing reception of occasional large packets.
//
// The Data field may alias the passed in byte slice,
// so the byte slice passed in should not be reused before RawPacket.Reset().
func (p *RawPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error {
	body, err := readPacket(r, b, maxPacketLength)
	if err == nil {
		err = p.UnmarshalFrom(NewBuffer(body))
	}
	return err
}
// RequestPacket implements the general packet format from draft-ietf-secsh-filexfer-02
// but also automatically decodes/encodes valid request packets (2 < type < 100 || type == 200).
//
// RequestPacket is intended for use in servers receiving requests,
// where any arbitrary request may be received, and so decoding them automatically
// is useful.
//
// For clients expecting to receive specific response packet types,
// where automatic unmarshaling of the packet body does not make sense,
// use RawPacket.
//
// Defined in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3
type RequestPacket struct {
RequestID uint32 // the request-id decoded from the packet
Request Packet // the decoded request; nil after Reset
}
// Type returns the SSH_FXP_xy value associated with the underlying packet.
//
// It calls through to p.Request without a nil check, so Request must be set.
func (p *RequestPacket) Type() PacketType {
return p.Request.Type()
}
// Reset clears the pointers and reference-semantic variables in RequestPacket,
// releasing underlying resources, and making them and the RequestPacket suitable to be reused,
// so long as no other references have been kept.
//
// Only the Request field is cleared; RequestID is left as it is.
func (p *RequestPacket) Reset() {
p.Request = nil
}
// MarshalPacket returns p as a two-part binary encoding of p, as produced by the wrapped Request.
// An error is returned when no Request is set.
//
// The internal p.RequestID is overridden by the reqid argument.
func (p *RequestPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	if req := p.Request; req != nil {
		return req.MarshalPacket(reqid, b)
	}

	return nil, nil, errors.New("empty request packet")
}
// MarshalBinary returns p as the binary encoding of p,
// marshalled with its own RequestID and no caller-supplied buffer.
//
// This is a convenience implementation primarily intended for tests,
// because it is inefficient with allocations.
func (p *RequestPacket) MarshalBinary() ([]byte, error) {
	header, payload, err := p.MarshalPacket(p.RequestID, nil)
	return ComposePacket(header, payload, err)
}
// UnmarshalFrom decodes a RequestPacket from the given Buffer into p.
//
// The uint8(type) selects the concrete request via newPacketFromType;
// p is left untouched if the type cannot be read or is rejected there.
//
// The Request field may alias the passed in Buffer, (e.g. SSH_FXP_WRITE),
// so the buffer passed in should not be reused before RequestPacket.Reset().
func (p *RequestPacket) UnmarshalFrom(buf *Buffer) error {
	typ := PacketType(buf.ConsumeUint8())
	if err := buf.Err; err != nil {
		return err
	}

	body, err := newPacketFromType(typ)
	if err != nil {
		return err
	}

	reqid := buf.ConsumeUint32()
	*p = RequestPacket{RequestID: reqid, Request: body}

	// The concrete request decodes the remainder of the packet.
	return body.UnmarshalPacketBody(buf)
}
// UnmarshalBinary decodes a full request packet out of the given data.
// It is assumed that the uint32(length) has already been consumed to receive the data.
//
// This is a convenience implementation primarily intended for tests,
// because this must clone the given data byte slice,
// as Request is not allowed to alias any part of the data byte slice.
func (p *RequestPacket) UnmarshalBinary(data []byte) error {
	// Private copy with length and capacity both equal to len(data).
	clone := append(make([]byte, 0, len(data)), data...)

	return p.UnmarshalFrom(NewBuffer(clone))
}
// ReadFrom provides a simple functional packet reader,
// using the given byte slice as a backing array.
//
// To protect against potential denial of service attacks,
// if the read packet length is longer than maxPacketLength,
// then no packet data will be read, and ErrLongPacket will be returned.
// (On 32-bit int architectures, all packets >= 2^31 in length
// will return ErrLongPacket regardless of maxPacketLength.)
//
// If the read packet length is longer than cap(b),
// then a throw-away slice will be allocated to meet the exact packet length.
// This can be used to limit the length of reused buffers,
// while still allowing reception of occasional large packets.
//
// The Request field may alias the passed in byte slice,
// so the byte slice passed in should not be reused before RequestPacket.Reset().
func (p *RequestPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error {
	body, err := readPacket(r, b, maxPacketLength)
	if err == nil {
		err = p.UnmarshalFrom(NewBuffer(body))
	}
	return err
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/path_packets.go
================================================
package sshfx
// LStatPacket defines the SSH_FXP_LSTAT packet.
type LStatPacket struct {
Path string // path the request applies to; the only field of the packet body
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for LStatPacket is always PacketTypeLStat.
func (p *LStatPacket) Type() PacketType {
return PacketTypeLStat
}
// MarshalPacket encodes p as an SSH_FXP_LSTAT packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the path.
func (p *LStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path)
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeLStat, reqid)
	out.AppendString(p.Path)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_LSTAT packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *LStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = LStatPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// SetstatPacket defines the SSH_FXP_SETSTAT packet.
//
// The fields are listed in their wire order: string(path), ATTRS(attrs).
type SetstatPacket struct {
Path string // path the request applies to
Attrs Attributes // attributes sent with the request
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for SetstatPacket is always PacketTypeSetstat.
func (p *SetstatPacket) Type() PacketType {
return PacketTypeSetstat
}
// MarshalPacket encodes p as an SSH_FXP_SETSTAT packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the fields.
func (p *SetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path) + ATTRS(attrs)
		out = NewMarshalBuffer(4 + len(p.Path) + p.Attrs.Len())
	}

	out.StartPacket(PacketTypeSetstat, reqid)
	out.AppendString(p.Path)
	p.Attrs.MarshalInto(out)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_SETSTAT packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *SetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = SetstatPacket{Path: path}

	// The attributes are the last field; their decoder supplies the result.
	return p.Attrs.UnmarshalFrom(buf)
}
// RemovePacket defines the SSH_FXP_REMOVE packet.
type RemovePacket struct {
Path string // path the request applies to; the only field of the packet body
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for RemovePacket is always PacketTypeRemove.
func (p *RemovePacket) Type() PacketType {
return PacketTypeRemove
}
// MarshalPacket encodes p as an SSH_FXP_REMOVE packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the path.
func (p *RemovePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path)
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeRemove, reqid)
	out.AppendString(p.Path)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_REMOVE packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *RemovePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = RemovePacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// MkdirPacket defines the SSH_FXP_MKDIR packet.
//
// The fields are listed in their wire order: string(path), ATTRS(attrs).
type MkdirPacket struct {
Path string // path of the directory the request applies to
Attrs Attributes // attributes sent with the request
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for MkdirPacket is always PacketTypeMkdir.
func (p *MkdirPacket) Type() PacketType {
return PacketTypeMkdir
}
// MarshalPacket encodes p as an SSH_FXP_MKDIR packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the fields.
func (p *MkdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path) + ATTRS(attrs)
		out = NewMarshalBuffer(4 + len(p.Path) + p.Attrs.Len())
	}

	out.StartPacket(PacketTypeMkdir, reqid)
	out.AppendString(p.Path)
	p.Attrs.MarshalInto(out)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_MKDIR packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *MkdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = MkdirPacket{Path: path}

	// The attributes are the last field; their decoder supplies the result.
	return p.Attrs.UnmarshalFrom(buf)
}
// RmdirPacket defines the SSH_FXP_RMDIR packet.
type RmdirPacket struct {
Path string // path of the directory the request applies to; the only field of the packet body
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for RmdirPacket is always PacketTypeRmdir.
func (p *RmdirPacket) Type() PacketType {
return PacketTypeRmdir
}
// MarshalPacket encodes p as an SSH_FXP_RMDIR packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the path.
func (p *RmdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path)
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeRmdir, reqid)
	out.AppendString(p.Path)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_RMDIR packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *RmdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = RmdirPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// RealPathPacket defines the SSH_FXP_REALPATH packet.
type RealPathPacket struct {
Path string // path the request applies to; the only field of the packet body
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for RealPathPacket is always PacketTypeRealPath.
func (p *RealPathPacket) Type() PacketType {
return PacketTypeRealPath
}
// MarshalPacket encodes p as an SSH_FXP_REALPATH packet carrying the request id reqid,
// returned as a two-part (header, payload) binary encoding.
//
// b is used as the backing storage when it has capacity for at least 9 bytes;
// otherwise a fresh marshal buffer is requested with the encoded size of the path.
func (p *RealPathPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// string(path)
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeRealPath, reqid)
	out.AppendString(p.Path)

	// payload is still its zero value (nil) at this point.
	return out.Packet(payload)
}
// UnmarshalPacketBody decodes the body of an SSH_FXP_REALPATH packet from buf into p.
// The uint32(request-id) is expected to have been consumed already.
func (p *RealPathPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = RealPathPacket{Path: path}

	// Any decoding failure is recorded on the buffer.
	return buf.Err
}
// StatPacket defines the SSH_FXP_STAT packet.
type StatPacket struct {
Path string // path the request applies to; the only field of the packet body
}
// Type returns the SSH_FXP_xy value associated with this packet type,
// which for StatPacket is always PacketTypeStat.
func (p *StatPacket) Type() PacketType {
return PacketTypeStat
}
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *StatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
if buf.Cap() < 9 {
size := 4 + len(p.Path) // string(path)
buf = NewMarshalBuffer(size)
}
buf.StartPacket(PacketTypeStat, reqid)
buf.AppendString(p.Path)
return buf.Packet(payload)
}
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
func (p *StatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
*p = StatPacket{
Path: buf.ConsumeString(),
}
return buf.Err
}
// RenamePacket defines the SSH_FXP_RENAME packet.
type RenamePacket struct {
	OldPath string
	NewPath string
}

// Type returns the SSH_FXP_xy value associated with this packet type.
func (p *RenamePacket) Type() PacketType {
	return PacketTypeRename
}

// MarshalPacket returns p as a two-part binary encoding of p.
//
// The Buffer built from b is used when its capacity is at least 9 bytes;
// otherwise a new marshal buffer sized for the packet body is allocated.
func (p *RenamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// Body layout: string(oldpath) + string(newpath).
		oldLen := 4 + len(p.OldPath)
		newLen := 4 + len(p.NewPath)
		out = NewMarshalBuffer(oldLen + newLen)
	}

	out.StartPacket(PacketTypeRename, reqid)
	out.AppendString(p.OldPath)
	out.AppendString(p.NewPath)

	return out.Packet(payload)
}

// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// The old path is consumed first, then the new path. The returned error is
// whatever error the Buffer has recorded.
func (p *RenamePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	oldPath := buf.ConsumeString()
	newPath := buf.ConsumeString()
	*p = RenamePacket{OldPath: oldPath, NewPath: newPath}
	return buf.Err
}
// ReadLinkPacket defines the SSH_FXP_READLINK packet.
type ReadLinkPacket struct {
	Path string
}

// Type returns the SSH_FXP_xy value associated with this packet type.
func (p *ReadLinkPacket) Type() PacketType {
	return PacketTypeReadLink
}

// MarshalPacket returns p as a two-part binary encoding of p.
//
// The Buffer built from b is used when its capacity is at least 9 bytes;
// otherwise a new marshal buffer sized for the packet body is allocated.
func (p *ReadLinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// Body layout: string(path).
		out = NewMarshalBuffer(4 + len(p.Path))
	}

	out.StartPacket(PacketTypeReadLink, reqid)
	out.AppendString(p.Path)

	return out.Packet(payload)
}

// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// The returned error is whatever error the Buffer has recorded.
func (p *ReadLinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	path := buf.ConsumeString()
	*p = ReadLinkPacket{Path: path}
	return buf.Err
}
// SymlinkPacket defines the SSH_FXP_SYMLINK packet.
//
// The order of the arguments to the SSH_FXP_SYMLINK method was inadvertently reversed.
// Unfortunately, the reversal was not noticed until the server was widely deployed.
// Covered in Section 4.1 of https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
type SymlinkPacket struct {
	LinkPath   string
	TargetPath string
}

// Type returns the SSH_FXP_xy value associated with this packet type.
func (p *SymlinkPacket) Type() PacketType {
	return PacketTypeSymlink
}

// MarshalPacket returns p as a two-part binary encoding of p.
//
// The Buffer built from b is used when its capacity is at least 9 bytes;
// otherwise a new marshal buffer sized for the packet body is allocated.
// The target path is written before the link path.
func (p *SymlinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
	out := NewBuffer(b)
	if out.Cap() < 9 {
		// Body layout: string(targetpath) + string(linkpath).
		targetLen := 4 + len(p.TargetPath)
		linkLen := 4 + len(p.LinkPath)
		out = NewMarshalBuffer(targetLen + linkLen)
	}

	out.StartPacket(PacketTypeSymlink, reqid)

	// Arguments were inadvertently reversed.
	out.AppendString(p.TargetPath)
	out.AppendString(p.LinkPath)

	return out.Packet(payload)
}

// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
//
// The target path is consumed before the link path. The returned error is
// whatever error the Buffer has recorded.
func (p *SymlinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
	// Arguments were inadvertently reversed.
	target := buf.ConsumeString()
	link := buf.ConsumeString()
	*p = SymlinkPacket{LinkPath: link, TargetPath: target}
	return buf.Err
}
================================================
FILE: gossh/sftp/internal/encoding/ssh/filexfer/permissions.go
================================================
package sshfx
// FileMode represents a file's mode and permission bits.
// The bits are defined according to POSIX standards,
// and may not apply to the OS being built for.
type FileMode uint32

// Permission flags, defined here to avoid potential inconsistencies in individual OS implementations.
const (
	// Mask covering every read/write/execute permission bit.
	ModePerm FileMode = 0o0777 // S_IRWXU | S_IRWXG | S_IRWXO

	// Owner permission bits.
	ModeUserRead  FileMode = 0o0400 // S_IRUSR
	ModeUserWrite FileMode = 0o0200 // S_IWUSR
	ModeUserExec  FileMode = 0o0100 // S_IXUSR

	// Group permission bits.
	ModeGroupRead  FileMode = 0o0040 // S_IRGRP
	ModeGroupWrite FileMode = 0o0020 // S_IWGRP
	ModeGroupExec  FileMode = 0o0010 // S_IXGRP

	// Permission bits for everyone else.
	ModeOtherRead  FileMode = 0o0004 // S_IROTH
	ModeOtherWrite FileMode = 0o0002 // S_IWOTH
	ModeOtherExec  FileMode = 0o0001 // S_IXOTH

	// Special bits.
	ModeSetUID FileMode = 0o4000 // S_ISUID
	ModeSetGID FileMode = 0o2000 // S_ISGID
	ModeSticky FileMode = 0o1000 // S_ISVTX

	// File type mask and the type values found under it.
	ModeType       FileMode = 0xF000 // S_IFMT
	ModeNamedPipe  FileMode = 0x1000 // S_IFIFO
	ModeCharDevice FileMode = 0x2000 // S_IFCHR
	ModeDir        FileMode = 0x4000 // S_IFDIR
	ModeDevice     FileMode = 0x6000 // S_IFBLK
	ModeRegular    FileMode = 0x8000 // S_IFREG
	ModeSymlink    FileMode = 0xA000 // S_IFLNK
	ModeSocket     FileMode = 0xC000 // S_IFSOCK
)

// IsDir reports whether m describes a directory.
// That is, it tests for m.Type() == ModeDir.
func (m FileMode) IsDir() bool {
	return m.Type() == ModeDir
}

// IsRegular reports whether m describes a regular file.
// That is, it tests for m.Type() == ModeRegular.
func (m FileMode) IsRegular() bool {
	return m.Type() == ModeRegular
}

// Perm returns the POSIX permission bits in m (m & ModePerm).
func (m FileMode) Perm() FileMode {
	perm := m & ModePerm
	return perm
}

// Type returns the type bits in m (m & ModeType).
func (m FileMode) Type() FileMode {
	typ := m & ModeType
	return typ
}
// String returns a `-rwxrwxrwx` style string representing the `ls -l` POSIX permissions string.
func (m FileMode) String() string {
var buf [10]byte
switch m.Type() {
case ModeRegular:
buf[0] = '-'
case ModeDir:
buf[0] = 'd'
case ModeSymlink:
buf[0] = 'l'
case ModeDevice:
buf[0] = 'b'
case ModeCharDevice:
buf[0] = 'c'
case ModeNamedPipe:
buf[0] = 'p'
case ModeSocket:
buf[0] = 's'
default:
buf[0] = '?'
}
const rwx = "rwxrwxrwx"
for i, c := range rwx {
if m&(1<= 0; i-- {
path := w.fs.Join(w.cur.path, list[i].Name())
w.stack = append(w.stack, item{path, list[i], nil})
}
}
}
if len(w.stack) == 0 {
return false
}
i := len(w.stack) - 1
w.cur = w.stack[i]
w.stack = w.stack[:i]
w.descend = true
return true
}
// Path returns the path to the most recent file or directory
// visited by a call to Step. It contains the argument to Walk
// as a prefix; that is, if Walk is called with "dir", which is
// a directory containing the file "a", Path will return "dir/a".
//
// The value is read from the walker's current item without modification.
func (w *Walker) Path() string {
return w.cur.path
}
// Stat returns info for the most recent file or directory
// visited by a call to Step.
//
// The value is read from the walker's current item without modification.
func (w *Walker) Stat() os.FileInfo {
return w.cur.info
}
// Err returns the error, if any, for the most recent attempt
// by Step to visit a file or directory. If a directory has
// an error, w will not descend into that directory.
//
// The value is read from the walker's current item without modification.
func (w *Walker) Err() error {
return w.cur.err
}
// SkipDir causes the currently visited directory to be skipped.
// If w is not on a directory, SkipDir has no effect.
//
// It works by clearing the walker's descend flag.
func (w *Walker) SkipDir() {
w.descend = false
}
================================================
FILE: gossh/sftp/ls_formatting.go
================================================
package sftp
import (
"errors"
"fmt"
"os"
"os/user"
"strconv"
"time"
sshfx "gossh/sftp/internal/encoding/ssh/filexfer"
)
// lsFormatID renders a numeric user or group ID as a base-10 string.
func lsFormatID(id uint32) string {
	const base = 10
	wide := uint64(id)
	return strconv.FormatUint(wide, base)
}
// osIDLookup is an empty type whose LookupUserName and LookupGroupName
// methods resolve numeric IDs through the os/user package.
type osIDLookup struct{}
// Filelist is a placeholder: it always returns a nil ListerAt together
// with an "unimplemented stub" error.
func (osIDLookup) Filelist(*Request) (ListerAt, error) {
return nil, errors.New("unimplemented stub")
}
// LookupUserName returns the username registered for the given numeric uid
// string. When the lookup fails, uid itself is returned unchanged.
func (osIDLookup) LookupUserName(uid string) string {
	if account, err := user.LookupId(uid); err == nil {
		return account.Username
	}
	return uid
}
// LookupGroupName returns the group name registered for the given numeric gid
// string. When the lookup fails, gid itself is returned unchanged.
func (osIDLookup) LookupGroupName(gid string) string {
	if group, err := user.LookupGroupId(gid); err == nil {
		return group.Name
	}
	return gid
}
// runLs formats the FileInfo as per `ls -l` style, which is in the 'longname' field of a SSH_FXP_NAME entry.
// This is a fairly simple implementation, just enough to look close to openssh in simple cases.
//
// Owner and group default to "0" and the link count to 1. They are refined
// from dirent.Sys(), from the FileInfoUidGid extension, or from the
// platform-specific lsLinksUIDGID. When idLookup is non-nil the numeric IDs
// are passed through it to obtain names.
func runLs(idLookup NameLookupFileLister, dirent os.FileInfo) string {
	// example from openssh sftp server:
	// crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd
	// format:
	// {directory / char device / etc}{rwxrwxrwx} {number of links} owner group size month day [time (this year) | year (otherwise)] name
	perms := sshfx.FileMode(fromFileMode(dirent.Mode())).String()

	links := uint64(1)
	owner, group := "0", "0"

	switch sys := dirent.Sys().(type) {
	case *sshfx.Attributes:
		owner = lsFormatID(sys.UID)
		group = lsFormatID(sys.GID)
	case *FileStat:
		owner = lsFormatID(sys.UID)
		group = lsFormatID(sys.GID)
	default:
		if ext, ok := dirent.(FileInfoUidGid); ok {
			owner = lsFormatID(ext.Uid())
			group = lsFormatID(ext.Gid())
		} else {
			links, owner, group = lsLinksUIDGID(dirent)
		}
	}

	if idLookup != nil {
		owner, group = idLookup.LookupUserName(owner), idLookup.LookupGroupName(group)
	}

	modified := dirent.ModTime()
	monthDay := modified.Format("Jan 2")

	// Entries older than six months show the year, newer ones the time of day.
	layout := "15:04"
	if modified.Before(time.Now().AddDate(0, -6, 0)) {
		layout = "2006"
	}
	yearOrTime := modified.Format(layout)

	return fmt.Sprintf("%s %4d %-8s %-8s %8d %s %5s %s", perms, links, owner, group, dirent.Size(), monthDay, yearOrTime, dirent.Name())
}
================================================
FILE: gossh/sftp/ls_plan9.go
================================================
//go:build plan9
// +build plan9
package sftp
import (
"os"
"syscall"
)
// lsLinksUIDGID reports the link count, owner and group used by runLs.
// On plan9 the link count is always 1; owner and group come from the
// *syscall.Dir behind fi.Sys() when available and default to "0" otherwise.
func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
	dir, ok := fi.Sys().(*syscall.Dir)
	if !ok {
		return 1, "0", "0"
	}
	return 1, dir.Uid, dir.Gid
}
================================================
FILE: gossh/sftp/ls_stub.go
================================================
//go:build windows || android
// +build windows android
package sftp
import (
"os"
)
// lsLinksUIDGID is the fallback for the platforms selected by this file's
// build constraint: it ignores fi and always reports a single link owned
// by uid "0" and gid "0".
func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
	numLinks = 1
	uid, gid = "0", "0"
	return numLinks, uid, gid
}
================================================
FILE: gossh/sftp/ls_unix.go
================================================
//go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js || zos
// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js zos
package sftp
import (
"os"
"syscall"
)
// lsLinksUIDGID reports the link count, owner and group used by runLs.
// When fi.Sys() is a *syscall.Stat_t the values are taken from it (IDs are
// formatted with lsFormatID); otherwise it reports 1 link, uid "0", gid "0".
func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
	st, ok := fi.Sys().(*syscall.Stat_t)
	if !ok {
		return 1, "0", "0"
	}
	return uint64(st.Nlink), lsFormatID(st.Uid), lsFormatID(st.Gid)
}
================================================
FILE: gossh/sftp/match.go
================================================
package sftp
import (
"path"
"strings"
)
// ErrBadPattern indicates a globbing pattern was malformed.
// It is the same error value as path.ErrBadPattern.
var ErrBadPattern = path.ErrBadPattern
// Match reports whether name matches the shell pattern.
//
// This is an alias for path.Match from the standard library,
// offered so that callers need not import the path package.
// For details, see https://golang.org/pkg/path/#Match.
func Match(pattern, name string) (matched bool, err error) {
	matched, err = path.Match(pattern, name)
	return matched, err
}
// isPathSeparator reports whether c is the forward slash used as the path
// separator.
func isPathSeparator(c byte) bool {
	const separator byte = '/'
	return c == separator
}
// Split splits the path p immediately following the final slash,
// separating it into a directory and file name component.
//
// This is an alias for path.Split from the standard library,
// offered so that callers need not import the path package.
// For details, see https://golang.org/pkg/path/#Split.
func Split(p string) (dir, file string) {
	dir, file = path.Split(p)
	return dir, file
}
// Glob returns the names of all files matching pattern or nil
// if there is no matching file. The syntax of patterns is the same
// as in Match. The pattern may describe hierarchical names such as
// /usr/*/bin/ed.
//
// Glob ignores file system errors such as I/O errors reading directories.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
func (c *Client) Glob(pattern string) (matches []string, err error) {
	if !hasMeta(pattern) {
		// A literal pattern names at most one file: report it if Lstat succeeds.
		info, statErr := c.Lstat(pattern)
		if statErr != nil {
			return nil, nil
		}
		parent, _ := Split(pattern)
		parent = cleanGlobPath(parent)
		return []string{Join(parent, info.Name())}, nil
	}

	dir, file := Split(pattern)
	dir = cleanGlobPath(dir)

	if !hasMeta(dir) {
		return c.glob(dir, file, nil)
	}

	// Prevent infinite recursion. See issue 15879.
	if dir == pattern {
		return nil, ErrBadPattern
	}

	// The directory part is itself a pattern: expand it first, then match
	// the final element inside every directory it produced.
	var parents []string
	if parents, err = c.Glob(dir); err != nil {
		return nil, err
	}
	for _, d := range parents {
		if matches, err = c.glob(d, file, matches); err != nil {
			return matches, err
		}
	}
	return matches, nil
}
// cleanGlobPath prepares path for glob matching: the empty string becomes
// ".", the root "/" is kept as is, and anything else loses its last byte
// (the trailing separator left behind by Split).
func cleanGlobPath(path string) string {
	if path == "" {
		return "."
	}
	if path == "/" {
		return path
	}
	return path[:len(path)-1] // chop off trailing separator
}
// glob searches for files matching pattern in the directory dir
// and appends them to matches. If dir cannot be stat'ed, is not a
// directory, or cannot be read, it returns the existing matches with a
// nil error. New matches are appended in the order ReadDir returns the
// entries (no additional sorting is applied here). A malformed pattern
// yields the error reported by Match.
func (c *Client) glob(dir, pattern string, matches []string) (m []string, e error) {
	info, err := c.Stat(dir)
	if err != nil || !info.IsDir() {
		return matches, nil
	}

	entries, err := c.ReadDir(dir)
	if err != nil {
		return matches, nil
	}

	//sort.Strings(names)
	for _, entry := range entries {
		ok, err := Match(pattern, entry.Name())
		if err != nil {
			return matches, err
		}
		if ok {
			matches = append(matches, Join(dir, entry.Name()))
		}
	}
	return matches, nil
}
// Join joins any number of path elements into a single path, separating
// them with slashes.
//
// This is an alias for path.Join from the standard library,
// offered so that callers need not import the path package.
// For details, see https://golang.org/pkg/path/#Join.
func Join(elem ...string) string {
	joined := path.Join(elem...)
	return joined
}
// hasMeta reports whether path contains any of the magic characters
// recognized by Match: backslash, '*', '?' or '['.
func hasMeta(path string) bool {
	return strings.IndexAny(path, "\\*?[") >= 0
}
================================================
FILE: gossh/sftp/packet-manager.go
================================================
package sftp
import (
"encoding"
"sort"
"sync"
)
// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming ones, as is required by section 7 of the RFC.
type packetManager struct {
// requests carries packets registered through incomingPacket to the controller.
requests chan orderedPacket
// responses carries packets registered through readyPacket to the controller.
responses chan orderedPacket
// fini is closed by close() to make the controller loop return.
fini chan struct{}
// incoming and outgoing are the controller's pending queues; the controller
// re-sorts each by order ID after every append.
incoming orderedPackets
outgoing orderedPackets
sender packetSender // connection object
// working is incremented by incomingPacket and decremented by readyPacket.
working *sync.WaitGroup
// packetCount is the most recent order ID handed out by newOrderID.
packetCount uint32
// it is not nil if the allocator is enabled
alloc *allocator
}
// packetSender is the sending side used by the packetManager: anything
// able to transmit a binary-marshalable packet.
type packetSender interface {
sendPacket(encoding.BinaryMarshaler) error
}
// newPktMgr builds a packetManager that sends through sender and starts its
// controller goroutine. Channels and queues are sized by
// SftpServerWorkerCount. The allocator is left nil.
func newPktMgr(sender packetSender) *packetManager {
	mgr := new(packetManager)

	mgr.requests = make(chan orderedPacket, SftpServerWorkerCount)
	mgr.responses = make(chan orderedPacket, SftpServerWorkerCount)
	mgr.fini = make(chan struct{})

	mgr.incoming = make([]orderedPacket, 0, SftpServerWorkerCount)
	mgr.outgoing = make([]orderedPacket, 0, SftpServerWorkerCount)

	mgr.sender = sender
	mgr.working = new(sync.WaitGroup)

	go mgr.controller()

	return mgr
}
// // packet ordering
// newOrderID increments the packet counter and returns the new value.
// The counter is updated without any locking here.
func (s *packetManager) newOrderID() uint32 {
s.packetCount++
return s.packetCount
}
// getNextOrderID returns the next orderID without incrementing it.
// This is used before receiving a new packet, with the allocator enabled, to associate
// the slice allocated for the received packet with the orderID that will be used to mark
// the allocated slices for reuse once the request is served.
func (s *packetManager) getNextOrderID() uint32 {
return s.packetCount + 1
}
// orderedRequest pairs an incoming request packet with the order ID the
// packetManager assigned to it.
type orderedRequest struct {
	requestPacket
	orderid uint32
}

// newOrderedRequest wraps p and stamps it with a freshly allocated order ID.
func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
	return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
}

// orderID returns the order ID assigned to the request.
func (p orderedRequest) orderID() uint32 { return p.orderid }

// setOrderID overwrites the order ID of the request.
//
// It uses a pointer receiver: with a value receiver the assignment would
// only modify a copy and be silently discarded.
func (p *orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
// orderedResponse pairs an outgoing response packet with the order ID of
// the request it answers.
type orderedResponse struct {
	responsePacket
	orderid uint32
}

// newOrderedResponse wraps p and tags it with the given order ID.
func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
) orderedResponse {
	return orderedResponse{responsePacket: p, orderid: id}
}

// orderID returns the order ID attached to the response.
func (p orderedResponse) orderID() uint32 { return p.orderid }

// setOrderID overwrites the order ID of the response.
//
// It uses a pointer receiver: with a value receiver the assignment would
// only modify a copy and be silently discarded.
func (p *orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
// orderedPacket is a packet that exposes both its protocol request ID and
// the order ID assigned by the packetManager.
type orderedPacket interface {
	id() uint32
	orderID() uint32
}

// orderedPackets is a queue of packets that can be sorted by order ID.
type orderedPackets []orderedPacket

// Sort reorders o in place so that order IDs are ascending.
func (o orderedPackets) Sort() {
	byOrderID := func(a, b int) bool {
		return o[a].orderID() < o[b].orderID()
	}
	sort.Slice(o, byOrderID)
}
// // packet registry
// incomingPacket registers an incoming packet to be handled.
// The WaitGroup is incremented before the packet is queued on the requests
// channel; the matching Done is issued by readyPacket.
func (s *packetManager) incomingPacket(pkt orderedRequest) {
s.working.Add(1)
s.requests <- pkt
}
// readyPacket registers an outgoing packet as being ready.
// The response is queued on the responses channel first, and only then is
// the WaitGroup entry added by incomingPacket released.
func (s *packetManager) readyPacket(pkt orderedResponse) {
s.responses <- pkt
s.working.Done()
}
// close shuts down the packetManager controller.
// It blocks until the WaitGroup reports no outstanding packets and then
// closes fini, which makes the controller loop return.
func (s *packetManager) close() {
// pause until current packets are processed
s.working.Wait()
close(s.fini)
}
// workerChan is passed a worker function and returns a channel for incoming
// packets. It keeps packet responses in the order the requests were received
// while maximizing throughput of file transfers.
//
// runWorker is invoked once per worker with the channel that worker should
// consume: SftpServerWorkerCount times with the shared read/write channel and
// once with the sequential command channel. NOTE(review): runWorker is called
// synchronously here, so it is presumably expected to start its own goroutine
// and return — confirm against the callers.
//
// Closing the returned channel closes both worker channels and then shuts
// the packetManager down via close().
func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
) chan orderedRequest {
// multiple workers for faster read/writes
rwChan := make(chan orderedRequest, SftpServerWorkerCount)
for i := 0; i < SftpServerWorkerCount; i++ {
runWorker(rwChan)
}
// single worker to enforce sequential processing of everything else
cmdChan := make(chan orderedRequest)
runWorker(cmdChan)
pktChan := make(chan orderedRequest, SftpServerWorkerCount)
go func() {
// Dispatcher: route each packet to the read/write pool or to the
// sequential command worker, registering it with the manager first.
for pkt := range pktChan {
switch pkt.requestPacket.(type) {
case *sshFxpReadPacket, *sshFxpWritePacket:
s.incomingPacket(pkt)
rwChan <- pkt
continue
case *sshFxpClosePacket:
// wait for reads/writes to finish when file is closed
// incomingPacket() call must occur after this
s.working.Wait()
}
s.incomingPacket(pkt)
// all non-RW use sequential cmdChan
cmdChan <- pkt
}
// pktChan was closed by the producer: stop the workers, then the manager.
close(rwChan)
close(cmdChan)
s.close()
}()
return pktChan
}
// controller processes packets: it is the loop started by newPktMgr that
// owns the incoming and outgoing queues. Each request or response received
// is appended to its queue, the queue is re-sorted by order ID, and
// maybeSendPackets is given a chance to flush matching pairs. The loop
// returns once fini is closed.
func (s *packetManager) controller() {
for {
select {
case pkt := <-s.requests:
debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
s.incoming = append(s.incoming, pkt)
s.incoming.Sort()
case pkt := <-s.responses:
debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
s.outgoing = append(s.outgoing, pkt)
s.outgoing.Sort()
case <-s.fini:
return
}
s.maybeSendPackets()
}
}
// maybeSendPackets sends as many packets as are ready: while the head of the
// outgoing queue carries the same order ID as the head of the incoming
// queue, the response is sent and both heads are removed. It stops as soon
// as either queue is empty or the heads do not match.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}

		out, in := s.outgoing[0], s.incoming[0]
		// debug("incoming: %v", ids(s.incoming))
		// debug("outgoing: %v", ids(s.outgoing))
		if in.orderID() != out.orderID() {
			// The response for the oldest pending request is not ready yet.
			return
		}

		debug("Sending packet: %v", out.id())
		// The result of sendPacket is not inspected here.
		s.sender.sendPacket(out.(encoding.BinaryMarshaler))
		if s.alloc != nil {
			// mark for reuse the slices allocated for this request
			s.alloc.ReleasePages(in.orderID())
		}

		// Pop the head of each queue: shift left, clear the vacated tail
		// slot, then shrink the slice by one.
		lastIn := len(s.incoming) - 1
		copy(s.incoming, s.incoming[1:])
		s.incoming[lastIn] = nil
		s.incoming = s.incoming[:lastIn]

		lastOut := len(s.outgoing) - 1
		copy(s.outgoing, s.outgoing[1:])
		s.outgoing[lastOut] = nil
		s.outgoing = s.outgoing[:lastOut]
	}
}
// func oids(o []orderedPacket) []uint32 {
// res := make([]uint32, 0, len(o))
// for _, v := range o {
// res = append(res, v.orderId())
// }
// return res
// }
// func ids(o []orderedPacket) []uint32 {
// res := make([]uint32, 0, len(o))
// for _, v := range o {
// res = append(res, v.id())
// }
// return res
// }
================================================
FILE: gossh/sftp/packet-typing.go
================================================
package sftp
import (
"encoding"
"fmt"
)
// requestPacket is implemented by all incoming packets: they can be decoded
// from bytes and expose their request ID.
type requestPacket interface {
encoding.BinaryUnmarshaler
id() uint32
}
// responsePacket is implemented by outgoing packets: they can be encoded to
// bytes and expose the ID of the request they answer.
type responsePacket interface {
encoding.BinaryMarshaler
id() uint32
}
// interfaces to group types
// hasPath groups the request packets that carry a path.
type hasPath interface {
requestPacket
getPath() string
}
// hasHandle groups the request packets that carry a handle.
type hasHandle interface {
requestPacket
getHandle() string
}
// notReadOnly is a marker interface implemented by the packet types listed
// under "notReadOnly" in this file.
type notReadOnly interface {
notReadOnly()
}
// // define types by adding methods
// hasPath
// Each getPath returns the path field of its packet. For packets carrying
// two paths, the field returned is the one named in the method body.
func (p *sshFxpLstatPacket) getPath() string { return p.Path }
func (p *sshFxpStatPacket) getPath() string { return p.Path }
func (p *sshFxpRmdirPacket) getPath() string { return p.Path }
func (p *sshFxpReadlinkPacket) getPath() string { return p.Path }
func (p *sshFxpRealpathPacket) getPath() string { return p.Path }
func (p *sshFxpMkdirPacket) getPath() string { return p.Path }
func (p *sshFxpSetstatPacket) getPath() string { return p.Path }
func (p *sshFxpStatvfsPacket) getPath() string { return p.Path }
func (p *sshFxpRemovePacket) getPath() string { return p.Filename }
func (p *sshFxpRenamePacket) getPath() string { return p.Oldpath }
func (p *sshFxpSymlinkPacket) getPath() string { return p.Targetpath }
func (p *sshFxpOpendirPacket) getPath() string { return p.Path }
func (p *sshFxpOpenPacket) getPath() string { return p.Path }
func (p *sshFxpExtendedPacketPosixRename) getPath() string { return p.Oldpath }
func (p *sshFxpExtendedPacketHardlink) getPath() string { return p.Oldpath }
// getHandle
// Each getHandle returns the Handle field of its packet.
func (p *sshFxpFstatPacket) getHandle() string { return p.Handle }
func (p *sshFxpFsetstatPacket) getHandle() string { return p.Handle }
func (p *sshFxpReadPacket) getHandle() string { return p.Handle }
func (p *sshFxpWritePacket) getHandle() string { return p.Handle }
func (p *sshFxpReaddirPacket) getHandle() string { return p.Handle }
func (p *sshFxpClosePacket) getHandle() string { return p.Handle }
// notReadOnly
// These empty methods only mark the packet type as implementing the
// notReadOnly interface.
func (p *sshFxpWritePacket) notReadOnly() {}
func (p *sshFxpSetstatPacket) notReadOnly() {}
func (p *sshFxpFsetstatPacket) notReadOnly() {}
func (p *sshFxpRemovePacket) notReadOnly() {}
func (p *sshFxpMkdirPacket) notReadOnly() {}
func (p *sshFxpRmdirPacket) notReadOnly() {}
func (p *sshFxpRenamePacket) notReadOnly() {}
func (p *sshFxpSymlinkPacket) notReadOnly() {}
func (p *sshFxpExtendedPacketPosixRename) notReadOnly() {}
func (p *sshFxpExtendedPacketHardlink) notReadOnly() {}
// some packets with ID are missing id()
// Each id returns the packet's ID field; the version packet always reports 0.
func (p *sshFxpDataPacket) id() uint32 { return p.ID }
func (p *sshFxpStatusPacket) id() uint32 { return p.ID }
func (p *sshFxpStatResponse) id() uint32 { return p.ID }
func (p *sshFxpNamePacket) id() uint32 { return p.ID }
func (p *sshFxpHandlePacket) id() uint32 { return p.ID }
func (p *StatVFS) id() uint32 { return p.ID }
func (p *sshFxVersionPacket) id() uint32 { return 0 }
// makePacket takes raw incoming packet data and builds the matching packet
// object, selected by p.pktType and filled by unmarshalling p.pktBytes.
//
// An unknown packet type yields a nil packet and an "unhandled packet type"
// error. When unmarshalling fails, the partially unpacked packet is still
// returned alongside the error so callers can use its id() method when
// building an error reply.
func makePacket(p rxPacket) (requestPacket, error) {
	var pkt requestPacket

	switch p.pktType {
	case sshFxpInit:
		pkt = new(sshFxInitPacket)
	case sshFxpLstat:
		pkt = new(sshFxpLstatPacket)
	case sshFxpOpen:
		pkt = new(sshFxpOpenPacket)
	case sshFxpClose:
		pkt = new(sshFxpClosePacket)
	case sshFxpRead:
		pkt = new(sshFxpReadPacket)
	case sshFxpWrite:
		pkt = new(sshFxpWritePacket)
	case sshFxpFstat:
		pkt = new(sshFxpFstatPacket)
	case sshFxpSetstat:
		pkt = new(sshFxpSetstatPacket)
	case sshFxpFsetstat:
		pkt = new(sshFxpFsetstatPacket)
	case sshFxpOpendir:
		pkt = new(sshFxpOpendirPacket)
	case sshFxpReaddir:
		pkt = new(sshFxpReaddirPacket)
	case sshFxpRemove:
		pkt = new(sshFxpRemovePacket)
	case sshFxpMkdir:
		pkt = new(sshFxpMkdirPacket)
	case sshFxpRmdir:
		pkt = new(sshFxpRmdirPacket)
	case sshFxpRealpath:
		pkt = new(sshFxpRealpathPacket)
	case sshFxpStat:
		pkt = new(sshFxpStatPacket)
	case sshFxpRename:
		pkt = new(sshFxpRenamePacket)
	case sshFxpReadlink:
		pkt = new(sshFxpReadlinkPacket)
	case sshFxpSymlink:
		pkt = new(sshFxpSymlinkPacket)
	case sshFxpExtended:
		pkt = new(sshFxpExtendedPacket)
	default:
		return nil, fmt.Errorf("unhandled packet type: %s", p.pktType)
	}

	// Return the packet even on failure: a partially unpacked packet lets
	// callers answer with an error message carrying the right id().
	err := pkt.UnmarshalBinary(p.pktBytes)
	return pkt, err
}
================================================
FILE: gossh/sftp/packet.go
================================================
package sftp
import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"reflect"
)
// Sentinel errors used by the packet encoding and decoding helpers.
var (
// errLongPacket is returned by recvPacket when the announced length exceeds maxMsgLength.
errLongPacket = errors.New("packet too long")
// errShortPacket is returned when fewer bytes are available than a field requires,
// and by recvPacket for a zero-length packet.
errShortPacket = errors.New("packet too short")
errUnknownExtendedPacket = errors.New("unknown extended packet")
)
const (
// maxMsgLength is the largest packet length recvPacket accepts.
maxMsgLength = 256 * 1024
// The debugDump* flags enable debug output of sent (Tx) and received (Rx)
// packets in sendPacket and recvPacket; the *Bytes variants also dump the
// packet contents in hex.
debugDumpTxPacket = false
debugDumpRxPacket = false
debugDumpTxPacketBytes = false
debugDumpRxPacketBytes = false
)
// marshalUint32 appends v to b in big-endian byte order and returns the
// extended slice.
func marshalUint32(b []byte, v uint32) []byte {
	be := [4]byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
	return append(b, be[:]...)
}

// marshalUint64 appends v to b in big-endian byte order, as two 32-bit
// halves (high half first), and returns the extended slice.
func marshalUint64(b []byte, v uint64) []byte {
	hi, lo := uint32(v>>32), uint32(v)
	b = marshalUint32(b, hi)
	return marshalUint32(b, lo)
}

// marshalString appends v to b as a uint32 byte length followed by the raw
// bytes of the string, and returns the extended slice.
func marshalString(b []byte, v string) []byte {
	b = marshalUint32(b, uint32(len(v)))
	return append(b, v...)
}
// marshalFileInfo appends the attribute encoding of fi to b: first the
// uint32 flags word obtained from fileStatFromInfo, then the fields those
// flags select, as written by marshalFileStat.
func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
	// attributes variable struct, and also variable per protocol version
	// spec version 3 attributes:
	// uint32 flags
	// uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
	// uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
	// uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
	// uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
	// uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
	// uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
	// uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
	// string extended_type
	// string extended_data
	// ... more extended data (extended_type - extended_data pairs),
	// so that number of pairs equals extended_count
	flags, stat := fileStatFromInfo(fi)
	return marshalFileStat(marshalUint32(b, flags), flags, stat)
}
// marshalFileStat appends to b the fields of fileStat selected by flags, in
// the order size, uid/gid, permissions, atime/mtime, extended attributes.
// The flags word itself is not written here.
func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte {
	has := func(attr uint32) bool { return flags&attr != 0 }

	if has(sshFileXferAttrSize) {
		b = marshalUint64(b, fileStat.Size)
	}
	if has(sshFileXferAttrUIDGID) {
		b = marshalUint32(b, fileStat.UID)
		b = marshalUint32(b, fileStat.GID)
	}
	if has(sshFileXferAttrPermissions) {
		b = marshalUint32(b, fileStat.Mode)
	}
	if has(sshFileXferAttrACmodTime) {
		b = marshalUint32(b, fileStat.Atime)
		b = marshalUint32(b, fileStat.Mtime)
	}
	if has(sshFileXferAttrExtended) {
		// Pair count first, then every (type, data) pair.
		b = marshalUint32(b, uint32(len(fileStat.Extended)))
		for i := range fileStat.Extended {
			b = marshalString(b, fileStat.Extended[i].ExtType)
			b = marshalString(b, fileStat.Extended[i].ExtData)
		}
	}
	return b
}
// marshalStatus appends err to b as uint32(code), string(message),
// string(language) and returns the extended slice.
func marshalStatus(b []byte, err StatusError) []byte {
	withCode := marshalUint32(b, err.Code)
	withMsg := marshalString(withCode, err.msg)
	return marshalString(withMsg, err.lang)
}
// marshal appends the encoding of v to b and returns the extended slice.
//
// nil adds nothing; uint8, uint32, uint64 and string use the fixed
// encodings of the helpers above; a []byte is appended verbatim (without a
// length prefix); an os.FileInfo is encoded by marshalFileInfo. Any other
// struct or slice is walked with reflection and each field or element is
// marshalled in turn. Every other type causes a panic.
func marshal(b []byte, v any) []byte {
	switch v := v.(type) {
	case nil:
		return b
	case uint8:
		return append(b, v)
	case uint32:
		return marshalUint32(b, v)
	case uint64:
		return marshalUint64(b, v)
	case string:
		return marshalString(b, v)
	case []byte:
		return append(b, v...)
	case os.FileInfo:
		return marshalFileInfo(b, v)
	}

	// No direct encoding: fall back to reflection for composite values.
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Struct:
		for i, n := 0, rv.NumField(); i < n; i++ {
			b = marshal(b, rv.Field(i).Interface())
		}
		return b
	case reflect.Slice:
		for i, n := 0, rv.Len(); i < n; i++ {
			b = marshal(b, rv.Index(i).Interface())
		}
		return b
	}

	panic(fmt.Sprintf("marshal(%#v): cannot handle type %T", v, v))
}
// unmarshalUint32 decodes a big-endian uint32 from the start of b and
// returns it with the remaining bytes. It panics when b holds fewer than
// four bytes; use unmarshalUint32Safe for unchecked input.
func unmarshalUint32(b []byte) (uint32, []byte) {
	_ = b[3] // fail up front if b is too short
	v := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
	return v, b[4:]
}

// unmarshalUint32Safe is unmarshalUint32 with a length check: input shorter
// than four bytes yields errShortPacket and a nil remainder.
func unmarshalUint32Safe(b []byte) (uint32, []byte, error) {
	if len(b) < 4 {
		return 0, nil, errShortPacket
	}
	v, rest := unmarshalUint32(b)
	return v, rest, nil
}

// unmarshalUint64 decodes a big-endian uint64 from the start of b (as two
// 32-bit halves, high half first) and returns it with the remaining bytes.
// It panics when b holds fewer than eight bytes.
func unmarshalUint64(b []byte) (uint64, []byte) {
	hi, rest := unmarshalUint32(b)
	lo, rest := unmarshalUint32(rest)
	return uint64(hi)<<32 | uint64(lo), rest
}

// unmarshalUint64Safe is unmarshalUint64 with a length check: input shorter
// than eight bytes yields errShortPacket and a nil remainder.
func unmarshalUint64Safe(b []byte) (uint64, []byte, error) {
	if len(b) < 8 {
		return 0, nil, errShortPacket
	}
	v, rest := unmarshalUint64(b)
	return v, rest, nil
}

// unmarshalString decodes a uint32 length followed by that many bytes and
// returns them as a string together with the remaining bytes. It panics
// when b is shorter than the encoding requires.
func unmarshalString(b []byte) (string, []byte) {
	n, rest := unmarshalUint32(b)
	return string(rest[:n]), rest[n:]
}

// unmarshalStringSafe is unmarshalString with length checks: a missing
// length prefix or a length exceeding the remaining bytes yields
// errShortPacket and a nil remainder.
func unmarshalStringSafe(b []byte) (string, []byte, error) {
	n, rest, err := unmarshalUint32Safe(b)
	if err != nil {
		return "", nil, err
	}
	if int64(n) > int64(len(rest)) {
		return "", nil, errShortPacket
	}
	return string(rest[:n]), rest[n:], nil
}
// unmarshalAttrs decodes an attribute block from b: the uint32 flags word
// followed by the fields those flags select (see unmarshalFileStat). It
// returns the decoded FileStat and the bytes that follow it.
func unmarshalAttrs(b []byte) (*FileStat, []byte, error) {
	flags, rest, err := unmarshalUint32Safe(b)
	if err != nil {
		return nil, rest, err
	}
	return unmarshalFileStat(flags, rest)
}
// unmarshalFileStat decodes from b the attribute fields selected by flags,
// in the order size, uid/gid, permissions, atime/mtime, extended pairs.
//
// It returns the decoded FileStat together with the bytes following the
// attributes. If b ends before a selected field is complete, the result is
// a nil FileStat and errShortPacket.
//
// The extended-attribute count comes straight from the peer, so it is
// checked against the bytes actually remaining before any slice is
// allocated. Otherwise a tiny packet announcing billions of pairs would
// trigger an enormous allocation before the truncation was noticed.
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) {
	var fs FileStat
	var err error

	if flags&sshFileXferAttrSize == sshFileXferAttrSize {
		fs.Size, b, err = unmarshalUint64Safe(b)
		if err != nil {
			return nil, b, err
		}
	}
	if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
		fs.UID, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}
		fs.GID, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}
	}
	if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
		fs.Mode, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}
	}
	if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
		fs.Atime, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}
		fs.Mtime, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}
	}
	if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
		var count uint32
		count, b, err = unmarshalUint32Safe(b)
		if err != nil {
			return nil, b, err
		}

		// Each pair is two length-prefixed strings, i.e. at least 8 bytes.
		// A count that cannot fit in the remaining input is a truncated
		// packet; report it exactly as the per-pair decoding below would.
		if uint64(count) > uint64(len(b))/8 {
			return nil, nil, errShortPacket
		}

		ext := make([]StatExtended, count)
		for i := uint32(0); i < count; i++ {
			var typ string
			var data string
			typ, b, err = unmarshalStringSafe(b)
			if err != nil {
				return nil, b, err
			}
			data, b, err = unmarshalStringSafe(b)
			if err != nil {
				return nil, b, err
			}
			ext[i] = StatExtended{
				ExtType: typ,
				ExtData: data,
			}
		}
		fs.Extended = ext
	}
	return &fs, b, nil
}
func unmarshalStatus(id uint32, data []byte) error {
sid, data := unmarshalUint32(data)
if sid != id {
return &unexpectedIDErr{id, sid}
}
code, data := unmarshalUint32(data)
msg, data, _ := unmarshalStringSafe(data)
lang, _, _ := unmarshalStringSafe(data)
return &StatusError{
Code: code,
msg: msg,
lang: lang,
}
}
// packetMarshaler is implemented by packets that can encode themselves as a
// separate header and payload.
type packetMarshaler interface {
	marshalPacket() (header, payload []byte, err error)
}

// marshalPacket encodes m. When m implements packetMarshaler its two-part
// encoding is returned as is; otherwise the result of MarshalBinary is
// returned as the header, with a nil payload.
func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err error) {
	twoPart, ok := m.(packetMarshaler)
	if ok {
		return twoPart.marshalPacket()
	}

	header, err = m.MarshalBinary()
	return header, nil, err
}
// sendPacket marshals m and writes it to w as a length-prefixed packet.
//
// The first four bytes of the marshalled header are overwritten with the
// big-endian length of everything that follows them (rest of the header
// plus payload). The header is written first; the payload, if any, is
// written with a second Write call. Marshalling and write failures are
// returned wrapped.
func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
	header, payload, err := marshalPacket(m)
	if err != nil {
		return fmt.Errorf("binary marshaller failed: %w", err)
	}

	// The uint32(length) prefix does not count towards the length itself.
	length := len(header) + len(payload) - 4

	if debugDumpTxPacketBytes {
		debug("send packet: %s %d bytes %x%x", fxp(header[4]), length, header[5:], payload)
	} else if debugDumpTxPacket {
		debug("send packet: %s %d bytes", fxp(header[4]), length)
	}

	binary.BigEndian.PutUint32(header[:4], uint32(length))

	if _, err := w.Write(header); err != nil {
		return fmt.Errorf("failed to send packet: %w", err)
	}

	if len(payload) == 0 {
		return nil
	}
	if _, err := w.Write(payload); err != nil {
		return fmt.Errorf("failed to send packet payload: %w", err)
	}
	return nil
}
// recvPacket reads one length-prefixed packet from r and returns its type
// byte and the bytes that follow it.
//
// When alloc is non-nil the packet is read into the page obtained from
// alloc.GetPage(orderID); otherwise buffers are allocated with make. A
// length above maxMsgLength yields errLongPacket, a zero length yields
// errShortPacket, and an EOF in the middle of the body is reported as
// io.ErrUnexpectedEOF.
func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
	var buf []byte
	if alloc == nil {
		buf = make([]byte, 4)
	} else {
		buf = alloc.GetPage(orderID)
	}

	if _, err := io.ReadFull(r, buf[:4]); err != nil {
		return 0, nil, err
	}

	length, _ := unmarshalUint32(buf)
	switch {
	case length > maxMsgLength:
		debug("recv packet %d bytes too long", length)
		return 0, nil, errLongPacket
	case length == 0:
		debug("recv packet of 0 bytes too short")
		return 0, nil, errShortPacket
	}

	if alloc == nil {
		buf = make([]byte, length)
	}
	body := buf[:length]

	if _, err := io.ReadFull(r, body); err != nil {
		// ReadFull only returns EOF if it has read no bytes.
		// In this case, that means a partial packet, and thus unexpected.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		debug("recv packet %d bytes: err %v", length, err)
		return 0, nil, err
	}

	if debugDumpRxPacketBytes {
		debug("recv packet: %s %d bytes %x", fxp(body[0]), length, body[1:])
	} else if debugDumpRxPacket {
		debug("recv packet: %s %d bytes", fxp(body[0]), length)
	}

	return body[0], body[1:], nil
}
// extensionPair is a name/data pair of strings describing one extension.
type extensionPair struct {
	Name, Data string
}

// unmarshalExtensionPair decodes two consecutive strings from b into an
// extensionPair and returns the pair together with the remaining bytes.
// If decoding the name fails, the data string is not attempted.
func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) {
	name, rest, err := unmarshalStringSafe(b)
	if err != nil {
		return extensionPair{Name: name}, rest, err
	}

	data, rest, err := unmarshalStringSafe(rest)
	return extensionPair{Name: name, Data: data}, rest, err
}
// Here starts the definition of packets along with their MarshalBinary
// implementations.
// Manually writing the marshalling logic wins us a lot of time and
// allocation.
// sshFxInitPacket carries a protocol version and a list of extension pairs.
// It is encoded with packet type sshFxpInit.
type sshFxInitPacket struct {
	Version    uint32
	Extensions []extensionPair
}

// MarshalBinary encodes the packet. The first four bytes of the result are
// reserved for the length and left as zero.
func (p *sshFxInitPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version)
	for i := range p.Extensions {
		size += 4 + len(p.Extensions[i].Name) + 4 + len(p.Extensions[i].Data)
	}

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpInit)
	buf = marshalUint32(buf, p.Version)
	for i := range p.Extensions {
		buf = marshalString(buf, p.Extensions[i].Name)
		buf = marshalString(buf, p.Extensions[i].Data)
	}
	return buf, nil
}

// UnmarshalBinary decodes the version followed by as many extension pairs as
// b contains, appending each pair to p.Extensions.
func (p *sshFxInitPacket) UnmarshalBinary(b []byte) error {
	version, rest, err := unmarshalUint32Safe(b)
	p.Version = version
	if err != nil {
		return err
	}

	for len(rest) > 0 {
		pair, next, err := unmarshalExtensionPair(rest)
		if err != nil {
			return err
		}
		rest = next
		p.Extensions = append(p.Extensions, pair)
	}
	return nil
}
// sshFxVersionPacket carries a protocol version and a list of extension
// pairs. It is encoded with packet type sshFxpVersion.
type sshFxVersionPacket struct {
	Version    uint32
	Extensions []sshExtensionPair
}

// sshExtensionPair is a name/data pair of strings describing one extension.
type sshExtensionPair struct {
	Name string
	Data string
}

// MarshalBinary encodes the packet. The first four bytes of the result are
// reserved for the length and left as zero.
func (p *sshFxVersionPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version)
	for i := range p.Extensions {
		size += 4 + len(p.Extensions[i].Name) + 4 + len(p.Extensions[i].Data)
	}

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpVersion)
	buf = marshalUint32(buf, p.Version)
	for i := range p.Extensions {
		buf = marshalString(buf, p.Extensions[i].Name)
		buf = marshalString(buf, p.Extensions[i].Data)
	}
	return buf, nil
}
// marshalIDStringPacket encodes a packet made of a type byte, a request id
// and a single string. The first four bytes of the result are reserved for
// the length and left as zero. The returned error is always nil.
func marshalIDStringPacket(packetType byte, id uint32, str string) ([]byte, error) {
	// uint32(length) + byte(type) + uint32(id) + uint32(len(str))
	const fixed = 4 + 1 + 4 + 4

	buf := make([]byte, 4, fixed+len(str))
	buf = append(buf, packetType)
	buf = marshalUint32(buf, id)
	return marshalString(buf, str), nil
}
// unmarshalIDString decodes a uint32 followed by a string from b, storing
// them through id and str. The string is not decoded if the uint32 fails.
func unmarshalIDString(b []byte, id *uint32, str *string) error {
	v, rest, err := unmarshalUint32Safe(b)
	*id = v
	if err != nil {
		return err
	}

	s, _, err := unmarshalStringSafe(rest)
	*str = s
	return err
}
// sshFxpReaddirPacket holds a request ID and a handle string. It is encoded
// with packet type sshFxpReaddir.
type sshFxpReaddirPacket struct {
	ID     uint32
	Handle string
}

// id reports the request ID stored in the packet.
func (p *sshFxpReaddirPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and handle.
func (p *sshFxpReaddirPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpReaddir, p.ID, p.Handle)
}

// UnmarshalBinary decodes the ID and handle from b.
func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Handle)
	return err
}
// sshFxpOpendirPacket holds a request ID and a path string. It is encoded
// with packet type sshFxpOpendir.
type sshFxpOpendirPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpOpendirPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpOpendirPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpOpendir, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpLstatPacket holds a request ID and a path string. It is encoded with
// packet type sshFxpLstat.
type sshFxpLstatPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpLstatPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpLstatPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpLstat, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpStatPacket holds a request ID and a path string. It is encoded with
// packet type sshFxpStat.
type sshFxpStatPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpStatPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpStatPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpStat, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpFstatPacket holds a request ID and a handle string. It is encoded
// with packet type sshFxpFstat.
type sshFxpFstatPacket struct {
	ID     uint32
	Handle string
}

// id reports the request ID stored in the packet.
func (p *sshFxpFstatPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and handle.
func (p *sshFxpFstatPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpFstat, p.ID, p.Handle)
}

// UnmarshalBinary decodes the ID and handle from b.
func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Handle)
	return err
}
// sshFxpClosePacket holds a request ID and a handle string. It is encoded
// with packet type sshFxpClose.
type sshFxpClosePacket struct {
	ID     uint32
	Handle string
}

// id reports the request ID stored in the packet.
func (p *sshFxpClosePacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and handle.
func (p *sshFxpClosePacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpClose, p.ID, p.Handle)
}

// UnmarshalBinary decodes the ID and handle from b.
func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Handle)
	return err
}
// sshFxpRemovePacket holds a request ID and a filename string. It is encoded
// with packet type sshFxpRemove.
type sshFxpRemovePacket struct {
	ID       uint32
	Filename string
}

// id reports the request ID stored in the packet.
func (p *sshFxpRemovePacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and filename.
func (p *sshFxpRemovePacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpRemove, p.ID, p.Filename)
}

// UnmarshalBinary decodes the ID and filename from b.
func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Filename)
	return err
}
// sshFxpRmdirPacket holds a request ID and a path string. It is encoded with
// packet type sshFxpRmdir.
type sshFxpRmdirPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpRmdirPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpRmdirPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpRmdir, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpSymlinkPacket holds a request ID and the two paths of a symlink
// request. It is encoded with packet type sshFxpSymlink.
type sshFxpSymlinkPacket struct {
	ID uint32

	// The order of the arguments to the SSH_FXP_SYMLINK method was inadvertently reversed.
	// Unfortunately, the reversal was not noticed until the server was widely deployed.
	// Covered in Section 4.1 of https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
	Targetpath string
	Linkpath   string
}

// id reports the request ID stored in the packet.
func (p *sshFxpSymlinkPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, Targetpath and then
// Linkpath. The first four bytes are reserved for the length.
func (p *sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Targetpath) + 4 + len(p.Linkpath)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpSymlink)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, p.Targetpath)
	buf = marshalString(buf, p.Linkpath)
	return buf, nil
}

// UnmarshalBinary decodes ID, Targetpath and Linkpath from b, in that order,
// stopping at the first field that fails to decode.
func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Targetpath, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Linkpath, _, err = unmarshalStringSafe(b)
	return err
}
// sshFxpHardlinkPacket holds a request ID and two paths. It is encoded as an
// sshFxpExtended packet naming the "hardlink@openssh.com" extension.
type sshFxpHardlinkPacket struct {
	ID      uint32
	Oldpath string
	Newpath string
}

// id reports the request ID stored in the packet.
func (p *sshFxpHardlinkPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, extension name, Oldpath
// and Newpath. The first four bytes are reserved for the length.
func (p *sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) {
	const ext = "hardlink@openssh.com"

	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(ext) + 4 + len(p.Oldpath) + 4 + len(p.Newpath)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpExtended)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, ext)
	buf = marshalString(buf, p.Oldpath)
	buf = marshalString(buf, p.Newpath)
	return buf, nil
}
// sshFxpReadlinkPacket holds a request ID and a path string. It is encoded
// with packet type sshFxpReadlink.
type sshFxpReadlinkPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpReadlinkPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpReadlink, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpRealpathPacket holds a request ID and a path string. It is encoded
// with packet type sshFxpRealpath.
type sshFxpRealpathPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpRealpathPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID and path.
func (p *sshFxpRealpathPacket) MarshalBinary() ([]byte, error) {
	return marshalIDStringPacket(sshFxpRealpath, p.ID, p.Path)
}

// UnmarshalBinary decodes the ID and path from b.
func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error {
	err := unmarshalIDString(b, &p.ID, &p.Path)
	return err
}
// sshFxpNameAttr is one entry of an sshFxpNamePacket: a name, a long name and
// a list of attribute values.
type sshFxpNameAttr struct {
	Name     string
	LongName string
	Attrs    []any
}

// MarshalBinary encodes Name, then LongName, then every element of Attrs in
// order using marshal. The returned error is always nil.
func (p *sshFxpNameAttr) MarshalBinary() ([]byte, error) {
	out := marshalString(nil, p.Name)
	out = marshalString(out, p.LongName)
	for i := range p.Attrs {
		out = marshal(out, p.Attrs[i])
	}
	return out, nil
}
// sshFxpNamePacket holds a request ID and a list of name entries. It is
// encoded with packet type sshFxpName.
type sshFxpNamePacket struct {
	ID        uint32
	NameAttrs []*sshFxpNameAttr
}

// marshalPacket returns the header (type byte, ID and entry count, with four
// leading bytes reserved for the length) and a payload holding the encoded
// entries back to back. If any entry fails to encode, both slices are nil.
func (p *sshFxpNamePacket) marshalPacket() ([]byte, []byte, error) {
	// uint32(length) + byte(type) + uint32(id) + uint32(count)
	hdr := make([]byte, 4, 4+1+4+4)
	hdr = append(hdr, sshFxpName)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalUint32(hdr, uint32(len(p.NameAttrs)))

	var body []byte
	for _, entry := range p.NameAttrs {
		enc, err := entry.MarshalBinary()
		if err != nil {
			return nil, nil, err
		}
		body = append(body, enc...)
	}
	return hdr, body, nil
}

// MarshalBinary returns the header and payload of marshalPacket concatenated
// into a single slice.
func (p *sshFxpNamePacket) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}
// sshFxpOpenPacket holds a request ID, a path, the open flags (Pflags), the
// attribute flags (Flags) and the attributes themselves. It is encoded with
// packet type sshFxpOpen.
//
// Attrs is set to the raw remaining bytes by UnmarshalBinary; marshalPacket
// also accepts an os.FileInfo, a *FileStat, or any value marshal can encode.
type sshFxpOpenPacket struct {
	ID     uint32
	Path   string
	Pflags uint32
	Flags  uint32
	Attrs  any
}

// id reports the request ID stored in the packet.
func (p *sshFxpOpenPacket) id() uint32 {
	return p.ID
}

// marshalPacket returns the fixed fields as the header (four leading bytes
// reserved for the length) and the encoded attributes as the payload.
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Path) +
		4 + 4 // uint32(pflags) + uint32(flags)

	hdr := make([]byte, 4, size)
	hdr = append(hdr, sshFxpOpen)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalString(hdr, p.Path)
	hdr = marshalUint32(hdr, p.Pflags)
	hdr = marshalUint32(hdr, p.Flags)

	var payload []byte
	switch attrs := p.Attrs.(type) {
	case []byte:
		// Already encoded: pass it through untouched.
		payload = attrs
	case os.FileInfo:
		// The flags derived from the FileInfo are discarded in favour of
		// the ones stored in the packet.
		_, stat := fileStatFromInfo(attrs)
		payload = marshalFileStat(nil, p.Flags, stat)
	case *FileStat:
		payload = marshalFileStat(nil, p.Flags, attrs)
	default:
		payload = marshal(nil, p.Attrs)
	}
	return hdr, payload, nil
}

// MarshalBinary returns the header and payload of marshalPacket concatenated
// into a single slice.
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}

// UnmarshalBinary decodes ID, Path, Pflags and Flags from b and stores the
// bytes that remain, undecoded, in Attrs.
func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Path, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Pflags, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Flags, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}

	p.Attrs = b
	return nil
}

// unmarshalFileStat returns Attrs as a *FileStat. A *FileStat is returned
// as-is, a []byte is decoded according to flags, and any other dynamic type
// is reported as an error.
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
	if stat, ok := p.Attrs.(*FileStat); ok {
		return stat, nil
	}
	if raw, ok := p.Attrs.([]byte); ok {
		stat, _, err := unmarshalFileStat(flags, raw)
		return stat, err
	}
	return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", p.Attrs)
}
// sshFxpReadPacket holds a request ID, a handle, an offset and a length. It
// is encoded with packet type sshFxpRead.
type sshFxpReadPacket struct {
	ID     uint32
	Len    uint32
	Offset uint64
	Handle string
}

// id reports the request ID stored in the packet.
func (p *sshFxpReadPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, Handle, Offset and Len.
// The first four bytes are reserved for the length.
func (p *sshFxpReadPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Handle) +
		8 + 4 // uint64(offset) + uint32(len)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpRead)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, p.Handle)
	buf = marshalUint64(buf, p.Offset)
	buf = marshalUint32(buf, p.Len)
	return buf, nil
}

// UnmarshalBinary decodes ID, Handle, Offset and Len from b, in that order,
// stopping at the first field that fails to decode.
func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Handle, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Offset, b, err = unmarshalUint64Safe(b)
	if err != nil {
		return err
	}
	p.Len, _, err = unmarshalUint32Safe(b)
	return err
}
// We need to allocate bigger slices, with extra capacity, to avoid a re-allocation in sshFxpDataPacket.MarshalBinary.
// So we need room for: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4
// getDataSlice returns a buffer for the data requested by this read packet.
// Its length is p.Len, capped at maxTxPacket.
//
// With a non-nil alloc the buffer is a prefix of the page returned by
// alloc.GetPage(orderID). Otherwise a new slice is allocated with
// dataHeaderLen bytes of extra capacity so that a header can later be
// appended without reallocating.
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
	n := p.Len
	if n > maxTxPacket {
		n = maxTxPacket
	}

	if alloc == nil {
		// Allocate with extra space for the header.
		return make([]byte, n, n+dataHeaderLen)
	}

	// NOTE(review): the original comment states that GetPage returns a slice
	// with capacity maxMsgLength, which would leave room for the header;
	// confirm against the allocator.
	return alloc.GetPage(orderID)[:n]
}
// sshFxpRenamePacket holds a request ID and two paths. It is encoded with
// packet type sshFxpRename.
type sshFxpRenamePacket struct {
	ID      uint32
	Oldpath string
	Newpath string
}

// id reports the request ID stored in the packet.
func (p *sshFxpRenamePacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, Oldpath and Newpath.
// The first four bytes are reserved for the length.
func (p *sshFxpRenamePacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Oldpath) + 4 + len(p.Newpath)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpRename)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, p.Oldpath)
	buf = marshalString(buf, p.Newpath)
	return buf, nil
}

// UnmarshalBinary decodes ID, Oldpath and Newpath from b, in that order,
// stopping at the first field that fails to decode.
func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Oldpath, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Newpath, _, err = unmarshalStringSafe(b)
	return err
}
// sshFxpPosixRenamePacket holds a request ID and two paths. It is encoded as
// an sshFxpExtended packet naming the "posix-rename@openssh.com" extension.
type sshFxpPosixRenamePacket struct {
	ID      uint32
	Oldpath string
	Newpath string
}

// id reports the request ID stored in the packet.
func (p *sshFxpPosixRenamePacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, extension name, Oldpath
// and Newpath. The first four bytes are reserved for the length.
func (p *sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) {
	const ext = "posix-rename@openssh.com"

	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(ext) + 4 + len(p.Oldpath) + 4 + len(p.Newpath)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpExtended)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, ext)
	buf = marshalString(buf, p.Oldpath)
	buf = marshalString(buf, p.Newpath)
	return buf, nil
}
// sshFxpWritePacket holds a request ID, a handle, an offset, a data length
// and the data. It is encoded with packet type sshFxpWrite.
type sshFxpWritePacket struct {
	ID     uint32
	Length uint32
	Offset uint64
	Handle string
	Data   []byte
}

// id reports the request ID stored in the packet.
func (p *sshFxpWritePacket) id() uint32 {
	return p.ID
}

// marshalPacket returns the fixed fields as the header (four leading bytes
// reserved for the length) and p.Data, unchanged, as the payload.
func (p *sshFxpWritePacket) marshalPacket() ([]byte, []byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Handle) +
		8 + 4 // uint64(offset) + uint32(data length)

	hdr := make([]byte, 4, size)
	hdr = append(hdr, sshFxpWrite)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalString(hdr, p.Handle)
	hdr = marshalUint64(hdr, p.Offset)
	hdr = marshalUint32(hdr, p.Length)
	return hdr, p.Data, nil
}

// MarshalBinary returns the header and payload of marshalPacket concatenated
// into a single slice.
func (p *sshFxpWritePacket) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}

// UnmarshalBinary decodes ID, Handle, Offset and Length from b and points
// Data at the Length bytes that follow. It returns errShortPacket when fewer
// than Length bytes remain.
func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Handle, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Offset, b, err = unmarshalUint64Safe(b)
	if err != nil {
		return err
	}
	p.Length, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	if uint32(len(b)) < p.Length {
		return errShortPacket
	}

	// Data aliases b; no copy is made.
	p.Data = b[:p.Length]
	return nil
}
// sshFxpMkdirPacket holds a request ID, a path and a flags word. It is
// encoded with packet type sshFxpMkdir.
type sshFxpMkdirPacket struct {
	ID    uint32
	Flags uint32 // ignored
	Path  string
}

// id reports the request ID stored in the packet.
func (p *sshFxpMkdirPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, Path and Flags. The
// first four bytes are reserved for the length.
func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Path) +
		4 // uint32(flags)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpMkdir)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, p.Path)
	buf = marshalUint32(buf, p.Flags)
	return buf, nil
}

// UnmarshalBinary decodes ID, Path and Flags from b, in that order, stopping
// at the first field that fails to decode.
func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Path, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Flags, _, err = unmarshalUint32Safe(b)
	return err
}
// sshFxpSetstatPacket holds a request ID, a path, attribute flags and the
// attributes to apply to that path.
type sshFxpSetstatPacket struct {
	ID    uint32
	Flags uint32
	Path  string
	Attrs any
}

// sshFxpFsetstatPacket holds a request ID, a handle, attribute flags and the
// attributes to apply to that handle.
type sshFxpFsetstatPacket struct {
	ID     uint32
	Flags  uint32
	Handle string
	Attrs  any
}

// id reports the request ID stored in the packet.
func (p *sshFxpSetstatPacket) id() uint32 {
	return p.ID
}

// id reports the request ID stored in the packet.
func (p *sshFxpFsetstatPacket) id() uint32 {
	return p.ID
}
// marshalPacket returns the fixed fields as the header (four leading bytes
// reserved for the length, packet type sshFxpSetstat) and the encoded
// attributes as the payload.
func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Path) +
		4 // uint32(flags)

	hdr := make([]byte, 4, size)
	hdr = append(hdr, sshFxpSetstat)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalString(hdr, p.Path)
	hdr = marshalUint32(hdr, p.Flags)

	var payload []byte
	switch attrs := p.Attrs.(type) {
	case []byte:
		// Already encoded: pass it through untouched.
		payload = attrs
	case os.FileInfo:
		// The flags derived from the FileInfo are discarded in favour of
		// the ones stored in the packet.
		_, stat := fileStatFromInfo(attrs)
		payload = marshalFileStat(nil, p.Flags, stat)
	case *FileStat:
		payload = marshalFileStat(nil, p.Flags, attrs)
	default:
		payload = marshal(nil, p.Attrs)
	}
	return hdr, payload, nil
}

// MarshalBinary returns the header and payload of marshalPacket concatenated
// into a single slice.
func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}
// marshalPacket returns the fixed fields as the header (four leading bytes
// reserved for the length, packet type sshFxpFsetstat) and the encoded
// attributes as the payload.
func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(p.Handle) +
		4 // uint32(flags)

	hdr := make([]byte, 4, size)
	hdr = append(hdr, sshFxpFsetstat)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalString(hdr, p.Handle)
	hdr = marshalUint32(hdr, p.Flags)

	var payload []byte
	switch attrs := p.Attrs.(type) {
	case []byte:
		// Already encoded: pass it through untouched.
		payload = attrs
	case os.FileInfo:
		// The flags derived from the FileInfo are discarded in favour of
		// the ones stored in the packet.
		_, stat := fileStatFromInfo(attrs)
		payload = marshalFileStat(nil, p.Flags, stat)
	case *FileStat:
		payload = marshalFileStat(nil, p.Flags, attrs)
	default:
		payload = marshal(nil, p.Attrs)
	}
	return hdr, payload, nil
}

// MarshalBinary returns the header and payload of marshalPacket concatenated
// into a single slice.
func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}
// UnmarshalBinary decodes ID, Path and Flags from b and stores the bytes that
// remain, undecoded, in Attrs.
func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Path, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Flags, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}

	p.Attrs = b
	return nil
}

// unmarshalFileStat returns Attrs as a *FileStat. A *FileStat is returned
// as-is, a []byte is decoded according to flags, and any other dynamic type
// is reported as an error.
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
	if stat, ok := p.Attrs.(*FileStat); ok {
		return stat, nil
	}
	if raw, ok := p.Attrs.([]byte); ok {
		stat, _, err := unmarshalFileStat(flags, raw)
		return stat, err
	}
	return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", p.Attrs)
}
// UnmarshalBinary decodes ID, Handle and Flags from b and stores the bytes
// that remain, undecoded, in Attrs.
func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Handle, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Flags, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}

	p.Attrs = b
	return nil
}

// unmarshalFileStat returns Attrs as a *FileStat. A *FileStat is returned
// as-is, a []byte is decoded according to flags, and any other dynamic type
// is reported as an error.
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
	if stat, ok := p.Attrs.(*FileStat); ok {
		return stat, nil
	}
	if raw, ok := p.Attrs.([]byte); ok {
		stat, _, err := unmarshalFileStat(flags, raw)
		return stat, err
	}
	return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", p.Attrs)
}
// sshFxpHandlePacket holds a request ID and a handle string. It is encoded
// with packet type sshFxpHandle.
type sshFxpHandlePacket struct {
	ID     uint32
	Handle string
}

// MarshalBinary encodes the packet as type byte, ID and Handle. The first
// four bytes are reserved for the length.
func (p *sshFxpHandlePacket) MarshalBinary() ([]byte, error) {
	// uint32(length) + byte(type) + uint32(id) + uint32(len(handle))
	const fixed = 4 + 1 + 4 + 4

	buf := make([]byte, 4, fixed+len(p.Handle))
	buf = append(buf, sshFxpHandle)
	buf = marshalUint32(buf, p.ID)
	return marshalString(buf, p.Handle), nil
}
// sshFxpStatusPacket holds a request ID together with an embedded
// StatusError. It is encoded with packet type sshFxpStatus.
type sshFxpStatusPacket struct {
	ID uint32
	StatusError
}

// MarshalBinary encodes the packet as type byte, ID and the status as written
// by marshalStatus. The first four bytes are reserved for the length.
func (p *sshFxpStatusPacket) MarshalBinary() ([]byte, error) {
	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + // uint32(code)
		4 + len(p.StatusError.msg) +
		4 + len(p.StatusError.lang)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpStatus)
	buf = marshalUint32(buf, p.ID)
	return marshalStatus(buf, p.StatusError), nil
}
// sshFxpDataPacket holds a request ID, a data length and the data. It is
// encoded with packet type sshFxpData.
type sshFxpDataPacket struct {
	ID     uint32
	Length uint32
	Data   []byte
}

// marshalPacket returns the header (type byte, ID and Length, with four
// leading bytes reserved for the packet length) and p.Data, unchanged, as the
// payload.
func (p *sshFxpDataPacket) marshalPacket() ([]byte, []byte, error) {
	// uint32(length) + byte(type) + uint32(id) + uint32(data length)
	hdr := make([]byte, 4, 4+1+4+4)
	hdr = append(hdr, sshFxpData)
	hdr = marshalUint32(hdr, p.ID)
	hdr = marshalUint32(hdr, p.Length)
	return hdr, p.Data, nil
}
// MarshalBinary encodes the receiver into a binary form and returns the result.
// To avoid a new allocation the Data slice must have at least dataHeaderLen
// (4 + 1 + 4 + 4 = 13) bytes of spare capacity beyond len(Data), because the
// header space is obtained by appending dataHeaderLen bytes to Data.
//
// This is hand-coded rather than just append(header, payload...),
// in order to try and reuse the p.Data backing store in the packet.
//
// The returned slice has length len(p.Data) + dataHeaderLen. When no
// reallocation happens it shares its backing array with p.Data, whose
// contents are overwritten by the shift and the header.
func (p *sshFxpDataPacket) MarshalBinary() ([]byte, error) {
// Grow Data by dataHeaderLen bytes; only the resulting length matters here,
// the appended zero bytes are overwritten below.
b := append(p.Data, make([]byte, dataHeaderLen)...)
// Move the first Length payload bytes dataHeaderLen positions to the right,
// leaving room for the header at the front. Source and destination may
// overlap; copy supports that.
copy(b[dataHeaderLen:], p.Data[:p.Length])
// b[0:4] will be overwritten with the length in sendPacket
b[4] = sshFxpData
binary.BigEndian.PutUint32(b[5:9], p.ID)
binary.BigEndian.PutUint32(b[9:13], p.Length)
return b, nil
}
// UnmarshalBinary decodes ID and Length from b and points Data at the Length
// bytes that follow. It returns errShortPacket when fewer than Length bytes
// remain.
func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.Length, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	if uint32(len(b)) < p.Length {
		return errShortPacket
	}

	// Data aliases b; no copy is made.
	p.Data = b[:p.Length]
	return nil
}
// sshFxpStatvfsPacket holds a request ID and a path. It is encoded as an
// sshFxpExtended packet naming the "statvfs@openssh.com" extension.
type sshFxpStatvfsPacket struct {
	ID   uint32
	Path string
}

// id reports the request ID stored in the packet.
func (p *sshFxpStatvfsPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, extension name and Path.
// The first four bytes are reserved for the length.
func (p *sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) {
	const ext = "statvfs@openssh.com"

	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(ext) + 4 + len(p.Path)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpExtended)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, ext)
	buf = marshalString(buf, p.Path)
	return buf, nil
}
// A StatVFS contains statistics about a filesystem.
type StatVFS struct {
	ID      uint32
	Bsize   uint64 // file system block size
	Frsize  uint64 // fundamental file system block size
	Blocks  uint64 // number of blocks (unit f_frsize)
	Bfree   uint64 // free blocks in file system
	Bavail  uint64 // free blocks for non-root
	Files   uint64 // total file inodes
	Ffree   uint64 // free file inodes
	Favail  uint64 // free file inodes for non-root
	Fsid    uint64 // file system id
	Flag    uint64 // bit mask of f_flag values
	Namemax uint64 // maximum filename length
}

// TotalSpace calculates the amount of total space in a filesystem, as the
// fundamental block size multiplied by the number of blocks.
func (p *StatVFS) TotalSpace() uint64 {
	blockSize, blocks := p.Frsize, p.Blocks
	return blockSize * blocks
}

// FreeSpace calculates the amount of free space in a filesystem, as the
// fundamental block size multiplied by the number of free blocks.
func (p *StatVFS) FreeSpace() uint64 {
	blockSize, free := p.Frsize, p.Bfree
	return blockSize * free
}
// marshalPacket converts the StatVFS to SSH_FXP_EXTENDED_REPLY packet binary
// format. The header holds four bytes reserved for the length followed by the
// sshFxpExtendedReply type byte; the payload is the struct written
// field by field in big-endian order by binary.Write.
func (p *StatVFS) marshalPacket() ([]byte, []byte, error) {
	var payload bytes.Buffer
	err := binary.Write(&payload, binary.BigEndian, p)

	hdr := []byte{0, 0, 0, 0, sshFxpExtendedReply}
	return hdr, payload.Bytes(), err
}

// MarshalBinary encodes the StatVFS as an SSH_FXP_EXTENDED_REPLY packet by
// concatenating the header and payload of marshalPacket.
func (p *StatVFS) MarshalBinary() ([]byte, error) {
	hdr, body, err := p.marshalPacket()
	return append(hdr, body...), err
}
// sshFxpFsyncPacket holds a request ID and a handle. It is encoded as an
// sshFxpExtended packet naming the "fsync@openssh.com" extension.
type sshFxpFsyncPacket struct {
	ID     uint32
	Handle string
}

// id reports the request ID stored in the packet.
func (p *sshFxpFsyncPacket) id() uint32 {
	return p.ID
}

// MarshalBinary encodes the packet as type byte, ID, extension name and
// Handle. The first four bytes are reserved for the length.
func (p *sshFxpFsyncPacket) MarshalBinary() ([]byte, error) {
	const ext = "fsync@openssh.com"

	size := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
		4 + len(ext) + 4 + len(p.Handle)

	buf := make([]byte, 4, size)
	buf = append(buf, sshFxpExtended)
	buf = marshalUint32(buf, p.ID)
	buf = marshalString(buf, ext)
	buf = marshalString(buf, p.Handle)
	return buf, nil
}
// sshFxpExtendedPacket is the generic form of an extended request: a request
// ID, the extension name, and the extension-specific packet that
// UnmarshalBinary selects from that name.
type sshFxpExtendedPacket struct {
	ID              uint32
	ExtendedRequest string
	SpecificPacket  interface {
		serverRespondablePacket
		readonly() bool
	}
}

// id reports the request ID stored in the packet.
func (p *sshFxpExtendedPacket) id() uint32 {
	return p.ID
}

// readonly delegates to the specific packet. It reports true when no
// specific packet is set.
func (p *sshFxpExtendedPacket) readonly() bool {
	if sp := p.SpecificPacket; sp != nil {
		return sp.readonly()
	}
	return true
}

// respond delegates to the specific packet. When no specific packet is set it
// returns the status built by statusFromError for a nil error.
func (p *sshFxpExtendedPacket) respond(svr *Server) responsePacket {
	if sp := p.SpecificPacket; sp != nil {
		return sp.respond(svr)
	}
	return statusFromError(p.ID, nil)
}
// UnmarshalBinary decodes the request ID and extension name from b, selects
// the extension-specific packet type from that name, and then lets that
// packet decode the complete, original input.
//
// For an extension name that is not recognised it returns an error wrapping
// errUnknownExtendedPacket. The message quotes the requested extension name;
// previously it formatted SpecificPacket, which has not been assigned in that
// branch and so never identified the rejected request.
func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error {
	var err error

	// The specific packet re-parses the ID and extension name itself, so it
	// must be handed the input from the very beginning.
	bOrig := b

	if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
		return err
	} else if p.ExtendedRequest, _, err = unmarshalStringSafe(b); err != nil {
		return err
	}

	// specific unmarshalling
	switch p.ExtendedRequest {
	case "statvfs@openssh.com":
		p.SpecificPacket = &sshFxpExtendedPacketStatVFS{}
	case "posix-rename@openssh.com":
		p.SpecificPacket = &sshFxpExtendedPacketPosixRename{}
	case "hardlink@openssh.com":
		p.SpecificPacket = &sshFxpExtendedPacketHardlink{}
	default:
		// %q keeps a peer-supplied name unambiguous and escaped in logs.
		return fmt.Errorf("packet type %q: %w", p.ExtendedRequest, errUnknownExtendedPacket)
	}

	return p.SpecificPacket.UnmarshalBinary(bOrig)
}
// sshFxpExtendedPacketStatVFS is the decoded form of an extended request
// carrying an extension name and a single path.
type sshFxpExtendedPacketStatVFS struct {
	ID              uint32
	ExtendedRequest string
	Path            string
}

// id reports the request ID stored in the packet.
func (p *sshFxpExtendedPacketStatVFS) id() uint32 {
	return p.ID
}

// readonly always reports true for this packet.
func (p *sshFxpExtendedPacketStatVFS) readonly() bool {
	return true
}

// UnmarshalBinary decodes ID, ExtendedRequest and Path from b, in that order,
// stopping at the first field that fails to decode.
func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.ExtendedRequest, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Path, _, err = unmarshalStringSafe(b)
	return err
}
// sshFxpExtendedPacketPosixRename is the decoded form of an extended request
// carrying an extension name and two paths, answered by renaming Oldpath to
// Newpath.
type sshFxpExtendedPacketPosixRename struct {
	ID              uint32
	ExtendedRequest string
	Oldpath         string
	Newpath         string
}

// id reports the request ID stored in the packet.
func (p *sshFxpExtendedPacketPosixRename) id() uint32 {
	return p.ID
}

// readonly always reports false for this packet.
func (p *sshFxpExtendedPacketPosixRename) readonly() bool {
	return false
}

// UnmarshalBinary decodes ID, ExtendedRequest, Oldpath and Newpath from b, in
// that order, stopping at the first field that fails to decode.
func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.ExtendedRequest, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Oldpath, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Newpath, _, err = unmarshalStringSafe(b)
	return err
}

// respond renames the local path for Oldpath to the local path for Newpath
// with os.Rename and returns the outcome as a status response.
func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket {
	from := s.toLocalPath(p.Oldpath)
	to := s.toLocalPath(p.Newpath)
	return statusFromError(p.ID, os.Rename(from, to))
}
// sshFxpExtendedPacketHardlink is the decoded form of an extended request
// carrying an extension name and two paths. Its respond method answers the
// request by calling os.Link on the two paths.
//
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
type sshFxpExtendedPacketHardlink struct {
	ID              uint32
	ExtendedRequest string
	Oldpath         string
	Newpath         string
}

// id reports the request ID stored in the packet.
func (p *sshFxpExtendedPacketHardlink) id() uint32 { return p.ID }

// readonly reports false: responding to this packet calls os.Link, which
// creates a new directory entry and therefore modifies the filesystem. This
// matches sshFxpExtendedPacketPosixRename, the other mutating extended
// request. It previously reported true, which classified hardlink creation
// as a read-only operation.
func (p *sshFxpExtendedPacketHardlink) readonly() bool { return false }
// UnmarshalBinary decodes ID, ExtendedRequest, Oldpath and Newpath from b, in
// that order, stopping at the first field that fails to decode.
func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error {
	var err error

	p.ID, b, err = unmarshalUint32Safe(b)
	if err != nil {
		return err
	}
	p.ExtendedRequest, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Oldpath, b, err = unmarshalStringSafe(b)
	if err != nil {
		return err
	}
	p.Newpath, _, err = unmarshalStringSafe(b)
	return err
}
// respond links the local path for Newpath to the local path for Oldpath with
// os.Link and returns the outcome as a status response.
func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket {
	existing := s.toLocalPath(p.Oldpath)
	link := s.toLocalPath(p.Newpath)
	return statusFromError(p.ID, os.Link(existing, link))
}
================================================
FILE: gossh/sftp/pool.go
================================================
package sftp
// bufPool provides a pool of byte-slices to be reused in various parts of the package.
// It is safe to use concurrently through a pointer.
type bufPool struct {
	ch   chan []byte
	blen int
}

// newBufPool returns a pool that retains up to depth idle buffers and hands
// out buffers of length bufLen.
func newBufPool(depth, bufLen int) *bufPool {
	pool := new(bufPool)
	pool.ch = make(chan []byte, depth)
	pool.blen = bufLen
	return pool
}

// Get returns a buffer of length p.blen, reusing a pooled one when available
// and allocating otherwise. It panics if the pool's buffer length is not
// positive.
func (p *bufPool) Get() []byte {
	if p.blen <= 0 {
		panic("bufPool: new buffer creation length must be greater than zero")
	}

	for {
		select {
		case buf := <-p.ch:
			if cap(buf) >= p.blen {
				return buf[:p.blen]
			}
			// Just in case: drop any buffer with insufficient capacity
			// and try the pool again.
		default:
			return make([]byte, p.blen)
		}
	}
}

// Put offers b back to the pool. The buffer is silently dropped when the
// pool is nil, when the pool is full, or when the capacity of b is below
// p.blen or above twice p.blen.
func (p *bufPool) Put(b []byte) {
	if p == nil {
		// functional default: no reuse.
		return
	}

	// Too small a buffer could cause panics when resizing to p.blen;
	// too large a buffer could cause memory leaks.
	reusable := cap(b) >= p.blen && cap(b) <= p.blen*2
	if !reusable {
		return
	}

	select {
	case p.ch <- b:
	default:
		// Pool is full: let the buffer go.
	}
}
// resChanPool is a pool of result channels, backed by a buffered channel.
type resChanPool chan chan result

// newResChanPool returns a pool that retains up to depth idle channels.
func newResChanPool(depth int) resChanPool {
	return make(resChanPool, depth)
}

// Get returns a pooled channel when one is available, and otherwise a new
// result channel with a buffer of one.
func (p resChanPool) Get() chan result {
	select {
	case pooled := <-p:
		return pooled
	default:
	}
	return make(chan result, 1)
}

// Put offers ch back to the pool; it is dropped when the pool is full.
func (p resChanPool) Put(ch chan result) {
	select {
	case p <- ch:
	default:
		// Pool is full: let the channel go.
	}
}
================================================
FILE: gossh/sftp/release.go
================================================
//go:build !debug
// +build !debug
package sftp
// debug is the variant compiled when the "debug" build tag is not set.
// It accepts a format string and arguments and does nothing with them.
func debug(fmt string, args ...any) {
	// Intentionally empty.
}
================================================
FILE: gossh/sftp/request-attrs.go
================================================
package sftp
// Methods on the Request object to make working with the Flags bitmasks and
// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
// request and AttrFlags() and Attributes() when working with SetStat requests.
// FileOpenFlags defines Open and Write Flags. Correlate directly with os.OpenFile flags
// (https://golang.org/pkg/os/#pkg-constants).
type FileOpenFlags struct {
	Read   bool
	Write  bool
	Append bool
	Creat  bool
	Trunc  bool
	Excl   bool
}

// newFileOpenFlags expands the flags bitmask into a FileOpenFlags value, one
// boolean per sshFxf* bit.
func newFileOpenFlags(flags uint32) FileOpenFlags {
	has := func(bit uint32) bool { return flags&bit != 0 }

	var out FileOpenFlags
	out.Read = has(sshFxfRead)
	out.Write = has(sshFxfWrite)
	out.Append = has(sshFxfAppend)
	out.Creat = has(sshFxfCreat)
	out.Trunc = has(sshFxfTrunc)
	out.Excl = has(sshFxfExcl)
	return out
}

// Pflags converts the bitmap/uint32 from SFTP Open packet pflag values,
// into a FileOpenFlags struct with booleans set for flags set in bitmap.
func (r *Request) Pflags() FileOpenFlags {
	bits := r.Flags
	return newFileOpenFlags(bits)
}
// FileAttrFlags indicates which SFTP file attributes were passed. When a flag
// is true the corresponding attribute should be available from the FileStat
// object returned by the Attributes method. Used with SetStat.
type FileAttrFlags struct {
	Size        bool
	UidGid      bool
	Permissions bool
	Acmodtime   bool
}

// newFileAttrFlags expands the flags bitmask into a FileAttrFlags value, one
// boolean per sshFileXferAttr* bit.
func newFileAttrFlags(flags uint32) FileAttrFlags {
	has := func(bit uint32) bool { return flags&bit != 0 }

	var out FileAttrFlags
	out.Size = has(sshFileXferAttrSize)
	out.UidGid = has(sshFileXferAttrUIDGID)
	out.Permissions = has(sshFileXferAttrPermissions)
	out.Acmodtime = has(sshFileXferAttrACmodTime)
	return out
}

// AttrFlags returns a FileAttrFlags boolean struct based on the
// bitmap/uint32 file attribute flags from the SFTP packet.
func (r *Request) AttrFlags() FileAttrFlags {
	bits := r.Flags
	return newFileAttrFlags(bits)
}
// Attributes parses the file attributes byte blob of the request, as
// described by its Flags, and returns the result as a FileStat object.
// Any decoding error is discarded.
func (r *Request) Attributes() *FileStat {
	stat, _, _ := unmarshalFileStat(r.Flags, r.Attrs)
	return stat
}
================================================
FILE: gossh/sftp/request-errors.go
================================================
package sftp
// fxerr is an error whose value is an SFTP status code.
type fxerr uint32

// Error types that match the SFTP's SSH_FXP_STATUS codes. Gives you more
// direct control of the errors being sent vs. letting the library work them
// out from the standard os/io errors.
const (
	ErrSSHFxOk               = fxerr(sshFxOk)
	ErrSSHFxEOF              = fxerr(sshFxEOF)
	ErrSSHFxNoSuchFile       = fxerr(sshFxNoSuchFile)
	ErrSSHFxPermissionDenied = fxerr(sshFxPermissionDenied)
	ErrSSHFxFailure          = fxerr(sshFxFailure)
	ErrSSHFxBadMessage       = fxerr(sshFxBadMessage)
	ErrSSHFxNoConnection     = fxerr(sshFxNoConnection)
	ErrSSHFxConnectionLost   = fxerr(sshFxConnectionLost)
	ErrSSHFxOpUnsupported    = fxerr(sshFxOPUnsupported)
)

// Deprecated error types. These are aliases for the new ones; please use the
// new ones directly.
const (
	ErrSshFxOk               = ErrSSHFxOk
	ErrSshFxEof              = ErrSSHFxEOF
	ErrSshFxNoSuchFile       = ErrSSHFxNoSuchFile
	ErrSshFxPermissionDenied = ErrSSHFxPermissionDenied
	ErrSshFxFailure          = ErrSSHFxFailure
	ErrSshFxBadMessage       = ErrSSHFxBadMessage
	ErrSshFxNoConnection     = ErrSSHFxNoConnection
	ErrSshFxConnectionLost   = ErrSSHFxConnectionLost
	ErrSshFxOpUnsupported    = ErrSSHFxOpUnsupported
)

// Error returns a short description of the status code. ErrSSHFxFailure and
// any value without a dedicated description are reported as "failure".
func (e fxerr) Error() string {
	msg := "failure"
	switch e {
	case ErrSSHFxOk:
		msg = "OK"
	case ErrSSHFxEOF:
		msg = "EOF"
	case ErrSSHFxNoSuchFile:
		msg = "no such file"
	case ErrSSHFxPermissionDenied:
		msg = "permission denied"
	case ErrSSHFxBadMessage:
		msg = "bad message"
	case ErrSSHFxNoConnection:
		msg = "no connection"
	case ErrSSHFxConnectionLost:
		msg = "connection lost"
	case ErrSSHFxOpUnsupported:
		msg = "operation unsupported"
	}
	return msg
}
================================================
FILE: gossh/sftp/request-example.go
================================================
package sftp
// This serves as an example of how to implement the request server handler as
// well as a dummy backend for testing. It implements an in-memory backend that
// works as a very simple filesystem with simple flat key-value lookup system.
import (
"errors"
"io"
"os"
"path"
"sort"
"strings"
"sync"
"syscall"
"time"
)
const maxSymlinkFollows = 5
var errTooManySymlinks = errors.New("too many symbolic links")
// InMemHandler returns a Handlers object with the test handlers. All four
// handler slots are filled by the same in-memory root, which starts with a
// root directory named "/" and no files.
func InMemHandler() Handlers {
	backend := new(root)
	backend.rootFile = &memFile{name: "/", modtime: time.Now(), isdir: true}
	backend.files = make(map[string]*memFile)

	return Handlers{backend, backend, backend, backend}
}
// Example Handlers
// Fileread opens the requested file through OpenFile.
// Requests whose open flags lack the read bit are rejected with os.ErrInvalid.
func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
	if pf := r.Pflags(); pf.Read {
		return fs.OpenFile(r)
	}
	// sanity check: a read handler was invoked without the read flag
	return nil, os.ErrInvalid
}
// Filewrite opens the requested file through OpenFile.
// Requests whose open flags lack the write bit are rejected with os.ErrInvalid.
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
	if pf := r.Pflags(); pf.Write {
		return fs.OpenFile(r)
	}
	// sanity check: a write handler was invoked without the write flag
	return nil, os.ErrInvalid
}
// OpenFile opens r.Filepath with r.Flags while holding the filesystem lock.
// A configured mock error is returned before any other work is done.
func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) {
	if mocked := fs.mockErr; mocked != nil {
		return nil, mocked
	}

	// initialize context for deadlock testing
	_ = r.WithContext(r.Context())

	fs.mu.Lock()
	defer fs.mu.Unlock()

	return fs.openfile(r.Filepath, r.Flags)
}
// putfile registers file under the canonical form of pathname.
// It fails when the parent cannot be resolved, when the canonical name is not
// rooted at "/", or (os.ErrExist) when the lookup of that name reports
// anything other than os.ErrNotExist.
func (fs *root) putfile(pathname string, file *memFile) error {
	canon, err := fs.canonName(pathname)
	switch {
	case err != nil:
		return err
	case !strings.HasPrefix(canon, "/"):
		return os.ErrInvalid
	}

	if _, lookupErr := fs.lfetch(canon); lookupErr != os.ErrNotExist {
		return os.ErrExist
	}

	file.name = canon
	fs.files[canon] = file
	return nil
}
// openfile resolves pathname (following symlinks) and applies the SFTP open
// flags to it. The caller must hold fs.mu.
//
// If the path does not exist and the create flag is set, a new empty file is
// created; creation also works through dangling symlinks unless the exclusive
// flag is set. For an existing path it returns os.ErrExist when both create
// and exclusive are set, os.ErrInvalid for directories, and truncates the
// file to zero length when the truncate flag is set.
func (fs *root) openfile(pathname string, flags uint32) (*memFile, error) {
	pflags := newFileOpenFlags(flags)

	file, err := fs.fetch(pathname)
	if err == os.ErrNotExist {
		if !pflags.Creat {
			return nil, os.ErrNotExist
		}

		var count int
		// You can create files through dangling symlinks.
		link, err := fs.lfetch(pathname)
		for err == nil && link.symlink != "" {
			if pflags.Excl {
				// unless you also passed in O_EXCL
				return nil, os.ErrInvalid
			}

			if count++; count > maxSymlinkFollows {
				return nil, errTooManySymlinks
			}

			// Resolve relative targets against the directory holding the
			// link, exactly as fetch does; using the raw target would make
			// creation through a relative dangling symlink fail.
			target := link.symlink
			if !path.IsAbs(target) {
				target = path.Join(path.Dir(link.name), target)
			}

			pathname = target
			link, err = fs.lfetch(pathname)
		}

		created := &memFile{
			modtime: time.Now(),
		}

		if err := fs.putfile(pathname, created); err != nil {
			return nil, err
		}

		return created, nil
	}

	if err != nil {
		return nil, err
	}

	if pflags.Creat && pflags.Excl {
		return nil, os.ErrExist
	}

	if file.IsDir() {
		return nil, os.ErrInvalid
	}

	if pflags.Trunc {
		if err := file.Truncate(0); err != nil {
			return nil, err
		}
	}

	return file, nil
}
// Filecmd dispatches the command named by r.Method to the matching filesystem
// operation while holding the filesystem lock. Methods it does not know yield
// an "unsupported" error; a configured mock error is returned first.
func (fs *root) Filecmd(r *Request) error {
	if mocked := fs.mockErr; mocked != nil {
		return mocked
	}

	// initialize context for deadlock testing
	_ = r.WithContext(r.Context())

	fs.mu.Lock()
	defer fs.mu.Unlock()

	switch r.Method {
	case "Mkdir":
		return fs.mkdir(r.Filepath)

	case "Rmdir":
		return fs.rmdir(r.Filepath)

	case "Remove":
		// IEEE 1003.1 remove may unlink files and remove empty directories.
		// Here the semantics of unlink are used, which may refuse directories.
		return fs.unlink(r.Filepath)

	case "Link":
		return fs.link(r.Filepath, r.Target)

	case "Symlink":
		// NOTE: r.Filepath is the target, and r.Target is the linkpath.
		return fs.symlink(r.Filepath, r.Target)

	case "Rename":
		// SFTP-v2: "It is an error if there already exists a file with the name specified by newpath."
		// This differs from POSIX, which allows limited replacement of targets.
		if fs.exists(r.Target) {
			return os.ErrExist
		}
		return fs.rename(r.Filepath, r.Target)

	case "Setstat":
		file, err := fs.openfile(r.Filepath, sshFxfWrite)
		if err != nil {
			return err
		}
		if !r.AttrFlags().Size {
			return nil
		}
		return file.Truncate(int64(r.Attributes().Size))

	default:
		return errors.New("unsupported")
	}
}
// rename moves the entry stored under oldpath to the canonical form of
// newpath, following the IEEE 1003.1 replacement rules quoted below.
// The caller must hold fs.mu.
//
// oldpath is looked up without following a final symlink. When the moved
// entry is a directory, every map key below its old name is re-registered
// under the new name as well.
func (fs *root) rename(oldpath, newpath string) error {
	file, err := fs.lfetch(oldpath)
	if err != nil {
		return err
	}

	// Resolve the destination's parent directory to its canonical name.
	newpath, err = fs.canonName(newpath)
	if err != nil {
		return err
	}

	if !strings.HasPrefix(newpath, "/") {
		return os.ErrInvalid
	}

	target, err := fs.lfetch(newpath)
	if err != os.ErrNotExist {
		if target == file {
			// IEEE 1003.1: if oldpath and newpath are the same directory entry,
			// then return no error, and perform no further action.
			return nil
		}

		switch {
		case file.IsDir():
			// IEEE 1003.1: if oldpath is a directory, and newpath exists,
			// then newpath must be a directory, and empty.
			// It is to be removed prior to rename.
			if err := fs.rmdir(newpath); err != nil {
				return err
			}

		case target.IsDir():
			// IEEE 1003.1: if oldpath is not a directory, and newpath exists,
			// then newpath may not be a directory.
			return syscall.EISDIR
		}
	}

	fs.files[newpath] = file

	if file.IsDir() {
		// file.name still holds the old canonical name here; it is only
		// updated after the children have been moved.
		dirprefix := file.name + "/"

		// NOTE(review): entries are inserted and deleted while ranging over
		// the same map.
		for name, file := range fs.files {
			if strings.HasPrefix(name, dirprefix) {
				newname := path.Join(newpath, strings.TrimPrefix(name, dirprefix))

				fs.files[newname] = file
				file.name = newname
				delete(fs.files, name)
			}
		}
	}

	file.name = newpath
	// The key removed is oldpath exactly as it was passed in, which is the
	// key lfetch found the entry under.
	delete(fs.files, oldpath)

	return nil
}
// PosixRename renames r.Filepath to r.Target under the filesystem lock.
// Unlike the "Rename" command it does not first reject an existing target.
func (fs *root) PosixRename(r *Request) error {
	if mocked := fs.mockErr; mocked != nil {
		return mocked
	}

	// initialize context for deadlock testing
	_ = r.WithContext(r.Context())

	fs.mu.Lock()
	defer fs.mu.Unlock()

	return fs.rename(r.Filepath, r.Target)
}
// StatVFS returns the result of getStatVFSForPath for r.Filepath, or the
// configured mock error when one is set.
func (fs *root) StatVFS(r *Request) (*StatVFS, error) {
	if mocked := fs.mockErr; mocked != nil {
		return nil, mocked
	}
	return getStatVFSForPath(r.Filepath)
}
// mkdir registers a new, empty directory entry at pathname.
// The caller must hold fs.mu.
func (fs *root) mkdir(pathname string) error {
	return fs.putfile(pathname, &memFile{
		isdir:   true,
		modtime: time.Now(),
	})
}
// rmdir removes the empty directory at pathname. The caller must hold fs.mu.
// The final path element is not followed if it is a symlink, so a symlink
// yields ENOTDIR (IEEE 1003.1).
func (fs *root) rmdir(pathname string) error {
	dir, err := fs.lfetch(pathname)
	switch {
	case err != nil:
		return err
	case !dir.IsDir():
		return syscall.ENOTDIR
	}

	// Work with the directory's own name rather than the argument:
	// dir.name is always the canonical name of a directory.
	canon := dir.name

	for child := range fs.files {
		if path.Dir(child) == canon {
			return errors.New("directory not empty")
		}
	}

	delete(fs.files, canon)
	return nil
}
// link registers the entry found at oldpath under newpath as well (a hard
// link). Directories cannot be linked. The caller must hold fs.mu.
func (fs *root) link(oldpath, newpath string) error {
	src, err := fs.lfetch(oldpath)
	if err != nil {
		return err
	}

	if !src.IsDir() {
		return fs.putfile(newpath, src)
	}
	return errors.New("hard link not allowed for directory")
}
// symlink creates a symbolic link named linkpath whose stored content is the
// string target. The caller must hold fs.mu.
// NOTE! Handlers call this as symlink(req.Filepath, req.Target), because for
// symlink requests Filepath carries the target and Target the link path.
func (fs *root) symlink(target, linkpath string) error {
	return fs.putfile(linkpath, &memFile{
		symlink: target,
		modtime: time.Now(),
	})
}
// unlink removes the map entry for pathname without following symlinks.
// Directories are refused with os.ErrInvalid. The caller must hold fs.mu.
func (fs *root) unlink(pathname string) error {
	entry, err := fs.lfetch(pathname)
	switch {
	case err != nil:
		return err
	case entry.IsDir():
		// IEEE 1003.1: implementations may opt out of unlinking directories.
		// SFTP-v2: SSH_FXP_REMOVE may not remove directories.
		return os.ErrInvalid
	}

	// Delete by the name that was asked for, NOT by entry.name: with hard
	// links a file has no single canonical name.
	delete(fs.files, pathname)
	return nil
}
// listerat is an in-memory ListerAt backed by a slice of file infos.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and returns how many were
// copied. It is modeled after strings.Reader's ReadAt implementation:
// a negative offset is an error, an offset at or past the end yields
// (0, io.EOF), and a short copy is reported together with io.EOF.
func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if offset < 0 {
		// Slicing with a negative index would panic; report it instead.
		return 0, errors.New("listerat.ListAt: negative offset")
	}

	if offset >= int64(len(f)) {
		return 0, io.EOF
	}

	n := copy(ls, f[offset:])
	if n < len(ls) {
		return n, io.EOF
	}
	return n, nil
}
// Filelist serves the "List" and "Stat" methods under the filesystem lock:
// "List" returns the sorted entries of a directory, "Stat" the single entry
// the path resolves to (symlinks followed). Other methods yield "unsupported".
func (fs *root) Filelist(r *Request) (ListerAt, error) {
	if mocked := fs.mockErr; mocked != nil {
		return nil, mocked
	}

	// initialize context for deadlock testing
	_ = r.WithContext(r.Context())

	fs.mu.Lock()
	defer fs.mu.Unlock()

	switch r.Method {
	case "Stat":
		entry, err := fs.fetch(r.Filepath)
		if err != nil {
			return nil, err
		}
		return listerat{entry}, nil

	case "List":
		entries, err := fs.readdir(r.Filepath)
		if err != nil {
			return nil, err
		}
		return listerat(entries), nil

	default:
		return nil, errors.New("unsupported")
	}
}
// readdir returns the direct children of the directory pathname resolves to,
// sorted by base name. It returns ENOTDIR when the path is not a directory.
// The caller must hold fs.mu.
func (fs *root) readdir(pathname string) ([]os.FileInfo, error) {
	dir, err := fs.fetch(pathname)
	if err != nil {
		return nil, err
	}
	if !dir.IsDir() {
		return nil, syscall.ENOTDIR
	}

	var entries []os.FileInfo
	for key, f := range fs.files {
		if path.Dir(key) != dir.name {
			continue
		}
		entries = append(entries, f)
	}

	sort.Slice(entries, func(a, b int) bool {
		return entries[a].Name() < entries[b].Name()
	})
	return entries, nil
}
// Readlink returns the target stored in the symbolic link at pathname.
// The final path element is not followed. It returns os.ErrInvalid when the
// entry exists but is not a symlink.
//
// Like the other handlers it takes fs.mu: lfetch reads (and may delete from)
// the shared files map, so an unlocked call would race with concurrent
// requests.
func (fs *root) Readlink(pathname string) (string, error) {
	fs.mu.Lock()
	defer fs.mu.Unlock()

	file, err := fs.lfetch(pathname)
	if err != nil {
		return "", err
	}

	if file.symlink == "" {
		return "", os.ErrInvalid
	}

	return file.symlink, nil
}
// Lstat implements the LstatFileLister interface: it returns the entry stored
// at r.Filepath without following a final symlink.
func (fs *root) Lstat(r *Request) (ListerAt, error) {
	if mocked := fs.mockErr; mocked != nil {
		return nil, mocked
	}

	// initialize context for deadlock testing
	_ = r.WithContext(r.Context())

	fs.mu.Lock()
	defer fs.mu.Unlock()

	entry, err := fs.lfetch(r.Filepath)
	if err != nil {
		return nil, err
	}
	return listerat{entry}, nil
}
// root is the in-memory, file-system-like backend that the example Handlers
// live on. Entries are kept in a flat map keyed by absolute path.
type root struct {
	// rootFile is the entry returned for the path "/".
	rootFile *memFile
	// mockErr, when non-nil, is returned by the handler methods instead of
	// doing any work (see returnErr).
	mockErr error

	// mu is taken by the handler methods around accesses to files.
	mu sync.Mutex
	// files maps a path to its entry.
	files map[string]*memFile
}
// returnErr installs err as the mocked error that handler calls return
// instead of doing their work. Passing nil clears it again.
// NOTE(review): the field is written without taking fs.mu.
func (fs *root) returnErr(err error) {
	fs.mockErr = err
}
// lfetch looks up path in the flat file map without following symlinks.
// "/" always yields the root entry. A key that maps to a nil entry is removed
// and reported as os.ErrNotExist, just like a missing key.
func (fs *root) lfetch(path string) (*memFile, error) {
	if path == "/" {
		return fs.rootFile, nil
	}

	entry, present := fs.files[path]
	if entry != nil {
		return entry, nil
	}

	if present {
		// drop the stale nil entry
		delete(fs.files, path)
	}
	return nil, os.ErrNotExist
}
// canonName returns the "canonical" name of a file: the parent directory of
// pathname is resolved (symlinks followed) and its own name, which is the only
// valid canonical path of a directory, is joined with the base name.
// It returns ENOTDIR when the parent resolves to something that is not a
// directory.
func (fs *root) canonName(pathname string) (string, error) {
	parent, err := fs.fetch(path.Dir(pathname))
	if err != nil {
		return "", err
	}
	if !parent.IsDir() {
		return "", syscall.ENOTDIR
	}
	return path.Join(parent.name, path.Base(pathname)), nil
}
// exists reports whether an entry is stored under the canonical form of path.
// A path whose parent cannot be resolved counts as not existing.
func (fs *root) exists(path string) bool {
	canon, err := fs.canonName(path)
	if err != nil {
		return false
	}

	if _, err := fs.lfetch(canon); err == os.ErrNotExist {
		return false
	}
	return true
}
// fetch looks up pathname and then follows symlinks until a non-link entry is
// reached. Relative link targets are resolved against the directory of the
// link's own name. Following more than maxSymlinkFollows links fails with
// errTooManySymlinks.
func (fs *root) fetch(pathname string) (*memFile, error) {
	cur, err := fs.lfetch(pathname)
	if err != nil {
		return nil, err
	}

	for hops := 1; cur.symlink != ""; hops++ {
		if hops > maxSymlinkFollows {
			return nil, errTooManySymlinks
		}

		next := cur.symlink
		if !path.IsAbs(next) {
			next = path.Join(path.Dir(cur.name), next)
		}

		if cur, err = fs.lfetch(next); err != nil {
			return nil, err
		}
	}

	return cur, nil
}
// memFile is one entry of the in-memory filesystem.
// It implements os.FileInfo, io.ReaderAt and io.WriterAt, the three
// interfaces necessary for the Handlers, plus the optional TransferError
// interface.
type memFile struct {
	name    string    // path the entry is registered under
	modtime time.Time // reported by ModTime
	symlink string    // link target; empty for non-links
	isdir   bool      // true for directories

	mu      sync.RWMutex // guards content and err
	content []byte
	err     error // set by TransferError; returned by ReadAt, WriteAt and Truncate
}

// These are helper functions, they must be called while holding the memFile.mu mutex

// size returns the current content length.
func (f *memFile) size() int64 { return int64(len(f.content)) }

// grow appends n zero bytes to the content.
func (f *memFile) grow(n int64) { f.content = append(f.content, make([]byte, n)...) }

// Have memFile fulfill os.FileInfo interface

// Name returns the base name of the entry.
func (f *memFile) Name() string { return path.Base(f.name) }

// Size returns the content length. It only reads, so a read lock suffices.
func (f *memFile) Size() int64 {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.size()
}

// Mode returns 0755|ModeDir for directories, 0777|ModeSymlink for symlinks
// and 0644 for everything else.
func (f *memFile) Mode() os.FileMode {
	if f.isdir {
		return os.FileMode(0755) | os.ModeDir
	}
	if f.symlink != "" {
		return os.FileMode(0777) | os.ModeSymlink
	}
	return os.FileMode(0644)
}

// ModTime returns the stored modification time.
func (f *memFile) ModTime() time.Time { return f.modtime }

// IsDir reports whether the entry is a directory.
func (f *memFile) IsDir() bool { return f.isdir }

// Sys returns the platform-specific fake produced by fakeFileInfoSys.
func (f *memFile) Sys() any {
	return fakeFileInfoSys()
}

// ReadAt copies content starting at off into b.
// It returns the stored transfer error if one is set, an error for a negative
// offset, io.EOF when off is at or past the end, and io.EOF together with the
// byte count on a short read.
func (f *memFile) ReadAt(b []byte, off int64) (int, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	if f.err != nil {
		return 0, f.err
	}

	if off < 0 {
		return 0, errors.New("memFile.ReadAt: negative offset")
	}

	if off >= f.size() {
		return 0, io.EOF
	}

	n := copy(b, f.content[off:])
	if n < len(b) {
		return n, io.EOF
	}

	return n, nil
}

// WriteAt writes b at off, growing the content with zero bytes as needed.
// It returns the stored transfer error if one is set and rejects negative
// offsets instead of panicking on the slice expression.
func (f *memFile) WriteAt(b []byte, off int64) (int, error) {
	// mimic write delays, should be optional
	time.Sleep(time.Microsecond * time.Duration(len(b)))

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.err != nil {
		return 0, f.err
	}

	if off < 0 {
		return 0, errors.New("memFile.WriteAt: negative offset")
	}

	grow := int64(len(b)) + off - f.size()
	if grow > 0 {
		f.grow(grow)
	}

	return copy(f.content[off:], b), nil
}

// Truncate sets the content length to size, cutting it or extending it with
// zero bytes. It returns the stored transfer error if one is set, and
// os.ErrInvalid for a negative size (which would otherwise panic; Setstat
// converts a client-supplied uint64 to int64, so negative values can arrive).
func (f *memFile) Truncate(size int64) error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.err != nil {
		return f.err
	}

	if size < 0 {
		return os.ErrInvalid
	}

	grow := size - f.size()
	if grow <= 0 {
		f.content = f.content[:size]
	} else {
		f.grow(grow)
	}

	return nil
}

// TransferError records err so that later ReadAt, WriteAt and Truncate calls
// fail with it.
func (f *memFile) TransferError(err error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	f.err = err
}
================================================
FILE: gossh/sftp/request-interfaces.go
================================================
package sftp
import (
"io"
"os"
)
// WriterAtReaderAt is the interface to return when a file is to be opened for
// both reading and writing: it combines io.WriterAt and io.ReaderAt.
type WriterAtReaderAt interface {
	io.WriterAt
	io.ReaderAt
}
// Interfaces are differentiated based on required returned values.
// All input arguments are to be pulled from Request (the only arg).
// The Handler interfaces all take the Request object as its only argument.
// All the data you should need to handle the call are in the Request object.
// The request.Method attribute is initially the most important one as it
// determines which Handler gets called.
// FileReader should return an io.ReaderAt for the filepath.
// Note that in case of an error, the error text will be sent to the client.
// Called for Methods: Get
type FileReader interface {
	Fileread(*Request) (io.ReaderAt, error)
}
// FileWriter should return an io.WriterAt for the filepath.
//
// The request server code will call Close() on the returned io.WriterAt
// object if an io.Closer type assertion succeeds.
// Note that in case of an error, the error text will be sent to the client.
// Note that when receiving an Append flag it is important not to open files
// using O_APPEND if you plan to use WriteAt, as the two conflict.
// Called for Methods: Put, Open
type FileWriter interface {
	Filewrite(*Request) (io.WriterAt, error)
}
// OpenFileWriter is a FileWriter that also implements the generic OpenFile
// method. Implement this optional interface if you want to be able to read
// from and write to the same handle.
// Called for Methods: Open
type OpenFileWriter interface {
	FileWriter
	OpenFile(*Request) (WriterAtReaderAt, error)
}
// FileCmder executes a command that has no result other than an error.
// Note that in case of an error, the error text will be sent to the client.
// Called for Methods: Setstat, Rename, Rmdir, Mkdir, Link, Symlink, Remove
type FileCmder interface {
	Filecmd(*Request) error
}
// PosixRenameFileCmder is a FileCmder that implements the PosixRename method.
// If this interface is implemented, PosixRename requests will call it;
// otherwise they will be handled in the same way as Rename.
type PosixRenameFileCmder interface {
	FileCmder
	PosixRename(*Request) error
}
// StatVFSFileCmder is a FileCmder that implements the StatVFS method.
// You need to implement this interface if you want to handle statvfs requests.
// Please also make sure that the statvfs@openssh.com extension is enabled.
type StatVFSFileCmder interface {
	FileCmder
	StatVFS(*Request) (*StatVFS, error)
}
// FileLister should return an object that fulfils the ListerAt interface.
// Note that in case of an error, the error text will be sent to the client.
// Called for Methods: List, Stat, Readlink
//
// Since Filelist returns os.FileInfo values, it is non-ideal for implementing
// Readlink: the Name method of that interface is defined to return only the
// base name, whereas Readlink must be able to return essentially any valid
// path, relative or absolute. To meet this more expressive requirement,
// implement [ReadlinkFileLister], which will then be used instead.
type FileLister interface {
	Filelist(*Request) (ListerAt, error)
}
// LstatFileLister is a FileLister that implements the Lstat method.
// If this interface is implemented, Lstat requests will call it;
// otherwise they will be handled in the same way as Stat.
type LstatFileLister interface {
	FileLister
	Lstat(*Request) (ListerAt, error)
}
// RealPathFileLister is a FileLister that implements the RealPath method.
// The built-in RealPath implementation does not resolve symbolic links.
// By implementing this interface you can customize the returned path
// and, for example, resolve symbolic links if needed for your use case.
// You have to return an absolute POSIX path.
//
// Up to v1.13.5 the signature for the RealPath method was:
//
//	RealPath(string) string
//
// A legacyRealPathFileLister interface with the old method has been added
// to ensure that your code does not break.
// You should use the new method signature to avoid future issues.
type RealPathFileLister interface {
	FileLister
	RealPath(string) (string, error)
}
// ReadlinkFileLister is a FileLister that implements the Readlink method.
// By implementing Readlink it is possible to return any valid path, relative
// or absolute. This allows a better response than the default FileLister,
// which is limited to os.FileInfo, whose Name method should only return the
// base name of a file.
type ReadlinkFileLister interface {
	FileLister
	Readlink(string) (string, error)
}
// legacyRealPathFileLister is the old form of RealPathFileLister, whose
// RealPath returns no error. It is here for backward compatibility only.
type legacyRealPathFileLister interface {
	FileLister
	RealPath(string) string
}
// NameLookupFileLister is a FileLister that implements the LookupUserName and
// LookupGroupName methods. If this interface is implemented, longname ls
// formatting will use these methods to convert user names and group names.
type NameLookupFileLister interface {
	FileLister
	LookupUserName(string) string
	LookupGroupName(string) string
}
// ListerAt does for file lists what io.ReaderAt does for files: a
// []os.FileInfo buffer is passed to the ListAt function, and the entries that
// are populated in the buffer will be passed to the client.
//
// ListAt should return the number of entries copied, and an io.EOF error when
// at the end of the list. This is testable by comparing how many entries were
// copied to how many could have been copied (e.g. n < len(ls)).
// The copy() builtin is best for the copying.
//
// On unix systems, uid and gid information will be retrieved from
// [os.FileInfo.Sys] if this function returns a [syscall.Stat_t] when called
// on a populated entry. Alternatively, if the entry implements
// [FileInfoUidGid], it will be used for uid and gid information.
//
// If a populated entry implements [FileInfoExtendedData], extended attributes
// will also be returned to the client.
//
// Note that in case of an error, the error text will be sent to the client.
type ListerAt interface {
	ListAt([]os.FileInfo, int64) (int, error)
}
// TransferError is an optional interface that a readerAt or writerAt can
// implement to be notified about the error that caused Serve() to exit
// while the request was still open.
type TransferError interface {
	TransferError(err error)
}
================================================
FILE: gossh/sftp/request-plan9.go
================================================
//go:build plan9
// +build plan9
package sftp
import (
"syscall"
)
// fakeFileInfoSys returns the value the in-memory example files report from
// Sys on plan9: a pointer to a zero syscall.Dir.
func fakeFileInfoSys() any {
	return new(syscall.Dir)
}
// testOsSys is the plan9 variant of the Sys check; it inspects nothing and
// always succeeds.
func testOsSys(sys any) error {
	return nil
}
================================================
FILE: gossh/sftp/request-server.go
================================================
package sftp
import (
"context"
"errors"
"io"
"path"
"path/filepath"
"strconv"
"sync"
)
// maxTxPacket is 1 << 15 (32768).
// NOTE(review): presumably the upper bound, in bytes, on the payload of an
// outgoing data packet — confirm against its users, which are outside this
// section.
var maxTxPacket uint32 = 1 << 15
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
	FileGet  FileReader // serves "Get"
	FilePut  FileWriter // serves "Put" and "Open"
	FileCmd  FileCmder  // serves commands such as Setstat, Rename, Rmdir, Mkdir, Link, Symlink, Remove
	FileList FileLister // serves List, Stat, Lstat and Readlink
}
// RequestServer abstracts the sftp protocol with an http request-like protocol.
type RequestServer struct {
	// Handlers are the user-supplied request handlers.
	Handlers Handlers

	*serverConn
	pktMgr *packetManager

	// startDirectory is the base used to resolve relative paths; "/" by default.
	startDirectory string

	// mu guards handleCount and openRequests.
	mu sync.RWMutex
	// handleCount is incremented for every new handle; its decimal string
	// form is the handle.
	handleCount int
	// openRequests maps a handle to its open Request.
	openRequests map[string]*Request
}
// A RequestServerOption is a function which applies configuration to a
// RequestServer. Options are applied by NewRequestServer in the order given.
type RequestServerOption func(*RequestServer)
// WithRSAllocator enables the allocator.
// After a packet has been processed, the allocated slices are kept in memory
// and reused for new packets. One allocator instance is shared by the packet
// manager and the connection.
// The allocator is experimental.
func WithRSAllocator() RequestServerOption {
	return func(rs *RequestServer) {
		shared := newAllocator()

		rs.pktMgr.alloc = shared
		rs.conn.alloc = shared
	}
}
// WithStartDirectory sets the start directory used as the base for relative
// paths. The value is cleaned into an absolute POSIX path.
// If unset, the default is "/".
func WithStartDirectory(startDirectory string) RequestServerOption {
	base := cleanPath(startDirectory)

	return func(rs *RequestServer) {
		rs.startDirectory = base
	}
}
// NewRequestServer allocates and returns a new RequestServer that reads from
// and writes to rwc and dispatches to the handlers in h. The given options
// are applied in order after the defaults have been set.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
	sc := &serverConn{
		conn: conn{
			Reader:      rwc,
			WriteCloser: rwc,
		},
	}

	rs := &RequestServer{
		Handlers:   h,
		serverConn: sc,
		pktMgr:     newPktMgr(sc),

		startDirectory: "/",
		openRequests:   make(map[string]*Request),
	}

	for i := range options {
		options[i](rs)
	}
	return rs
}
// nextRequest registers r as an open request under a freshly generated handle
// (the decimal form of an incrementing counter), stores that handle on r and
// returns it.
func (rs *RequestServer) nextRequest(r *Request) string {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	rs.handleCount++
	handle := strconv.Itoa(rs.handleCount)

	r.handle = handle
	rs.openRequests[handle] = r
	return handle
}
// getRequest returns the Request registered under handle; the bool is false
// if there is none.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What is being done with one is denoted by
// the first packet of that type (read/write/etc).
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
	rs.mu.RLock()
	r, ok := rs.openRequests[handle]
	rs.mu.RUnlock()

	return r, ok
}
// closeRequest removes the Request registered under handle from openRequests
// and closes it, returning the close error. An unknown handle yields EBADF.
func (rs *RequestServer) closeRequest(handle string) error {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	r, ok := rs.openRequests[handle]
	if !ok {
		return EBADF
	}

	delete(rs.openRequests, handle)
	return r.close()
}
// Close closes the underlying read/write/closer, which triggers the exit of
// the main server loop.
func (rs *RequestServer) Close() error {
	return rs.conn.Close()
}
// serveLoop receives packets from the connection, decodes them and hands them
// to the workers through pktChan until receiving or decoding fails.
// pktChan is closed on return, which shuts down the workers.
// An unknown extended packet is not treated as a failure: whatever makePacket
// returned for it is still forwarded.
func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
	defer close(pktChan) // shuts down sftpServerWorkers

	var (
		err      error
		pkt      requestPacket
		pktType  uint8
		pktBytes []byte
	)

	for {
		pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
		if err != nil {
			// allocated pages are not released here: the server is about to
			// quit and the allocator will be freed
			return err
		}

		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
		if err != nil && !errors.Is(err, errUnknownExtendedPacket) {
			debug("makePacket err: %v", err)
			rs.conn.Close() // shuts down recvPacket
			return err
		}

		pktChan <- rs.pktMgr.newOrderedRequest(pkt)
	}
}
// Serve serves the requests of a user session until the receive loop ends,
// and returns the error that ended it.
//
// Packet workers are started through the packet manager. Once the loop has
// returned and all workers have exited, every request that is still open is
// notified of the error, removed and closed. If requests are still open and
// the loop ended with io.EOF, the error is turned into io.ErrUnexpectedEOF,
// both for the notification and for the return value.
func (rs *RequestServer) Serve() error {
	// Release the allocator's memory, if one is configured, on exit.
	defer func() {
		if rs.pktMgr.alloc != nil {
			rs.pktMgr.alloc.Free()
		}
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup
	// runWorker starts one packet worker goroutine for the given channel.
	runWorker := func(ch chan orderedRequest) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := rs.packetWorker(ctx, ch); err != nil {
				rs.conn.Close() // shuts down recvPacket
			}
		}()
	}
	pktChan := rs.pktMgr.workerChan(runWorker)

	err := rs.serveLoop(pktChan)

	wg.Wait() // wait for all workers to exit

	rs.mu.Lock()
	defer rs.mu.Unlock()

	// make sure all open requests are properly closed
	// (eg. possible on dropped connections, client crashes, etc.)
	for handle, req := range rs.openRequests {
		// A clean EOF with requests still open means the client went away
		// mid-transfer.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		req.transferError(err)

		delete(rs.openRequests, handle)
		// The close error is deliberately not reported here.
		req.close()
	}

	return err
}
// packetWorker processes request packets from pktChan until the channel is
// closed. For each packet it builds a response and hands it to the packet
// manager tagged with the request's order ID. It always returns nil.
//
// ctx is the parent of the per-request contexts created for path-based
// packets.
func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
	for pkt := range pktChan {
		orderID := pkt.orderID()
		// Unwrap extended packets that carry a recognised specific packet so
		// the type switch below sees the specific type.
		if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
			if epkt.SpecificPacket != nil {
				pkt.requestPacket = epkt.SpecificPacket
			}
		}

		var rpkt responsePacket
		// NOTE: the specific packet types must come before the generic
		// hasHandle / hasPath cases, since a type switch takes the first match.
		switch pkt := pkt.requestPacket.(type) {
		case *sshFxInitPacket:
			rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
		case *sshFxpClosePacket:
			handle := pkt.getHandle()
			rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
		case *sshFxpRealpathPacket:
			var realPath string
			var err error

			// Prefer a handler-provided RealPath (new or legacy signature);
			// otherwise just clean the path against the start directory.
			switch pather := rs.Handlers.FileList.(type) {
			case RealPathFileLister:
				realPath, err = pather.RealPath(pkt.getPath())
			case legacyRealPathFileLister:
				realPath = pather.RealPath(pkt.getPath())
			default:
				realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath())
			}
			if err != nil {
				rpkt = statusFromError(pkt.ID, err)
			} else {
				rpkt = cleanPacketPath(pkt, realPath)
			}
		case *sshFxpOpendirPacket:
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			handle := rs.nextRequest(request)
			rpkt = request.opendir(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				rs.closeRequest(handle)
			}
		case *sshFxpOpenPacket:
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			handle := rs.nextRequest(request)
			rpkt = request.open(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				rs.closeRequest(handle)
			}
		case *sshFxpFstatPacket:
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				// Serve it as a path-based "Stat" on the open request's path.
				request = &Request{
					Method:   "Stat",
					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
				}
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpFsetstatPacket:
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				// Serve it as a path-based "Setstat" on the open request's path.
				request = &Request{
					Method:   "Setstat",
					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
				}
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpExtendedPacketPosixRename:
			request := &Request{
				Method:   "PosixRename",
				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
				Target:   cleanPathWithBase(rs.startDirectory, pkt.Newpath),
			}
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case *sshFxpExtendedPacketStatVFS:
			request := &Request{
				Method:   "StatVFS",
				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
			}
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case hasHandle:
			// Any other handle-based packet runs against its open request.
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.id(), EBADF)
			} else {
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case hasPath:
			// Any other path-based packet is a one-shot request, closed
			// right after the call.
			request := requestFromPacket(ctx, pkt, rs.startDirectory)
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			request.close()
		default:
			rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
		}

		rs.pktMgr.readyPacket(
			rs.pktMgr.newOrderedResponse(rpkt, orderID))
	}
	return nil
}
// cleanPacketPath builds the name packet answering a realpath request: it has
// a single entry whose name and long name are both realPath, with empty
// attributes.
func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
	entry := &sshFxpNameAttr{
		Name:     realPath,
		LongName: realPath,
		Attrs:    emptyFileStat,
	}

	return &sshFxpNamePacket{
		ID:        pkt.id(),
		NameAttrs: []*sshFxpNameAttr{entry},
	}
}
// cleanPath returns p as a clean, absolute POSIX (/) path; relative paths are
// resolved against "/".
func cleanPath(p string) string {
	const base = "/"
	return cleanPathWithBase(base, p)
}
// cleanPathWithBase cleans p and converts it to slash-separated form.
// An absolute result is returned as is; a relative one is joined onto base.
func cleanPathWithBase(base, p string) string {
	cleaned := filepath.ToSlash(filepath.Clean(p))
	if path.IsAbs(cleaned) {
		return cleaned
	}
	return path.Join(base, cleaned)
}
================================================
FILE: gossh/sftp/request-unix.go
================================================
//go:build !windows && !plan9
// +build !windows,!plan9
package sftp
import (
"errors"
"syscall"
)
// fakeFileInfoSys returns the value the in-memory example files report from
// Sys on unix-like systems: a syscall.Stat_t whose Uid and Gid are both 65534.
func fakeFileInfoSys() any {
	const fakeID = 65534

	return &syscall.Stat_t{
		Uid: fakeID,
		Gid: fakeID,
	}
}
// testOsSys checks that sys, which must be a *FileStat, carries the uid and
// gid produced by fakeFileInfoSys (65534 each). The uid is checked first.
func testOsSys(sys any) error {
	fstat := sys.(*FileStat)

	switch want := uint32(65534); {
	case fstat.UID != want:
		return errors.New("Uid failed to match")
	case fstat.GID != want:
		return errors.New("Gid failed to match")
	}
	return nil
}
================================================
FILE: gossh/sftp/request.go
================================================
package sftp
import (
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"sync"
"syscall"
)
// MaxFilelist is the maximum number of files to return in one readdir batch.
var MaxFilelist int64 = 100
// state holds the reader/writer/lister objects obtained from the handlers,
// together with the directory listing offset. All access goes through mu.
type state struct {
	mu sync.RWMutex

	writerAt         io.WriterAt
	readerAt         io.ReaderAt
	writerAtReaderAt WriterAtReaderAt
	listerAt         ListerAt
	lsoffset         int64
}

// copy returns a shallow copy of the state.
// The fields are copied one by one because the mutex itself must not be
// copied; the result carries a fresh, unlocked mutex.
func (s *state) copy() state {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var dup state
	dup.writerAt = s.writerAt
	dup.readerAt = s.readerAt
	dup.writerAtReaderAt = s.writerAtReaderAt
	dup.listerAt = s.listerAt
	dup.lsoffset = s.lsoffset
	return dup
}

// setReaderAt stores the reader.
func (s *state) setReaderAt(rd io.ReaderAt) {
	s.mu.Lock()
	s.readerAt = rd
	s.mu.Unlock()
}

// getReaderAt returns the stored reader.
func (s *state) getReaderAt() io.ReaderAt {
	s.mu.RLock()
	rd := s.readerAt
	s.mu.RUnlock()
	return rd
}

// setWriterAt stores the writer.
func (s *state) setWriterAt(rd io.WriterAt) {
	s.mu.Lock()
	s.writerAt = rd
	s.mu.Unlock()
}

// getWriterAt returns the stored writer.
func (s *state) getWriterAt() io.WriterAt {
	s.mu.RLock()
	wr := s.writerAt
	s.mu.RUnlock()
	return wr
}

// setWriterAtReaderAt stores the combined reader/writer.
func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) {
	s.mu.Lock()
	s.writerAtReaderAt = rw
	s.mu.Unlock()
}

// getWriterAtReaderAt returns the stored combined reader/writer.
func (s *state) getWriterAtReaderAt() WriterAtReaderAt {
	s.mu.RLock()
	rw := s.writerAtReaderAt
	s.mu.RUnlock()
	return rw
}

// getAllReaderWriters returns the reader, the writer and the combined
// reader/writer, all read under a single lock acquisition.
func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) {
	s.mu.RLock()
	rd, wr, rw := s.readerAt, s.writerAt, s.writerAtReaderAt
	s.mu.RUnlock()
	return rd, wr, rw
}

// lsNext returns the current offset into the file list.
func (s *state) lsNext() int64 {
	s.mu.RLock()
	next := s.lsoffset
	s.mu.RUnlock()
	return next
}

// lsInc advances the file list offset by offset.
func (s *state) lsInc(offset int64) {
	s.mu.Lock()
	s.lsoffset += offset
	s.mu.Unlock()
}

// setListerAt stores the lister used for directory listings.
func (s *state) setListerAt(la ListerAt) {
	s.mu.Lock()
	s.listerAt = la
	s.mu.Unlock()
}

// getListerAt returns the stored lister.
func (s *state) getListerAt() ListerAt {
	s.mu.RLock()
	la := s.listerAt
	s.mu.RUnlock()
	return la
}
// Request contains the data and state for the incoming service request.
type Request struct {
	// Method names the operation, e.g.
	// Get, Put, Setstat, Stat, Rename, Remove
	// Rmdir, Mkdir, List, Readlink, Link, Symlink
	Method string
	// Filepath is the path the request operates on. For Symlink requests it
	// holds the link target instead.
	Filepath string
	// Flags carries the flags of open and setstat packets.
	Flags uint32
	Attrs []byte // convert to sub-struct
	// Target is the second path of a request: the new path for renames and
	// hard links, the link path for symlinks.
	Target string // for renames and sym-links
	// handle is the handle under which the request server registered the request.
	handle string

	// reader/writer/readdir from handlers
	state

	// context lasts duration of request
	ctx context.Context
	// cancelCtx cancels ctx; it is invoked when the request is closed.
	cancelCtx context.CancelFunc
}
// NewRequest creates a new Request for the given method. path is cleaned into
// an absolute POSIX path before it is stored as the Filepath.
func NewRequest(method, path string) *Request {
	req := new(Request)
	req.Method = method
	req.Filepath = cleanPath(path)
	return req
}
// copy returns a shallow copy of the request.
// The copy is assembled field by field because the embedded state contains a
// mutex, which must not be copied along.
func (r *Request) copy() *Request {
	dup := &Request{
		state:     r.state.copy(),
		ctx:       r.ctx,
		cancelCtx: r.cancelCtx,
	}

	dup.Method = r.Method
	dup.Filepath = r.Filepath
	dup.Flags = r.Flags
	dup.Attrs = r.Attrs
	dup.Target = r.Target
	dup.handle = r.handle

	return dup
}
// requestFromPacket builds a new Request from the data of a path-carrying
// packet. Paths are cleaned and resolved against baseDir, and the request
// gets its own cancellable context derived from ctx.
func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Request {
	request := new(Request)
	request.Method = requestMethod(pkt)
	request.Filepath = cleanPathWithBase(baseDir, pkt.getPath())
	request.ctx, request.cancelCtx = context.WithCancel(ctx)

	switch p := pkt.(type) {
	case *sshFxpRenamePacket:
		request.Target = cleanPathWithBase(baseDir, p.Newpath)

	case *sshFxpExtendedPacketHardlink:
		request.Target = cleanPathWithBase(baseDir, p.Newpath)

	case *sshFxpSymlinkPacket:
		// NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
		// Request.Target becomes the linkpath and Request.Filepath the target,
		// the latter being stored exactly as sent.
		request.Target = cleanPathWithBase(baseDir, p.Linkpath)
		request.Filepath = p.Targetpath

	case *sshFxpOpenPacket:
		request.Flags = p.Pflags
		request.Attrs = p.Attrs.([]byte)

	case *sshFxpSetstatPacket:
		request.Flags = p.Flags
		request.Attrs = p.Attrs.([]byte)
	}

	return request
}
// Context returns the request's context. To change the context,
// use WithContext.
//
// The returned context is always non-nil; when none has been set it is the
// background context.
//
// For incoming server requests, the context is canceled when the
// request is complete or the client's connection closes.
func (r *Request) Context() context.Context {
	if ctx := r.ctx; ctx != nil {
		return ctx
	}
	return context.Background()
}
// WithContext returns a copy of r that uses ctx as its context.
// ctx must be non-nil; passing nil panics.
//
// The copy does not inherit the original's cancel function.
func (r *Request) WithContext(ctx context.Context) *Request {
	if ctx == nil {
		panic("nil context")
	}

	clone := r.copy()
	clone.ctx, clone.cancelCtx = ctx, nil
	return clone
}
// close shuts down whichever reader, writer and read-writer are attached to
// the request, provided they implement io.Closer, and finally cancels the
// request context when a cancel function is present.
//
// The writer is closed first because its Close error is the one most likely
// to signal lost data; only the first non-nil error is returned.
func (r *Request) close() error {
	defer func() {
		if cancel := r.cancelCtx; cancel != nil {
			cancel()
		}
	}()

	reader, writer, readWriter := r.getAllReaderWriters()

	var firstErr error

	if closer, ok := writer.(io.Closer); ok {
		closeErr := closer.Close()
		if firstErr == nil {
			firstErr = closeErr
		}
	}

	if closer, ok := readWriter.(io.Closer); ok {
		closeErr := closer.Close()
		if firstErr == nil {
			firstErr = closeErr

			// The read-writer is detached only when no earlier close failed.
			r.setWriterAtReaderAt(nil)
		}
	}

	if closer, ok := reader.(io.Closer); ok {
		closeErr := closer.Close()
		if firstErr == nil {
			firstErr = closeErr
		}
	}

	return firstErr
}
// transferError forwards a non-nil err to every attached stream (writer,
// read-writer, reader, in that order) that implements TransferError.
// A nil err is ignored.
func (r *Request) transferError(err error) {
	if err == nil {
		return
	}

	reader, writer, readWriter := r.getAllReaderWriters()

	for _, stream := range []any{writer, readWriter, reader} {
		if notifier, ok := stream.(TransferError); ok {
			notifier.TransferError(err)
		}
	}
}
// call dispatches the packet to the handler wrapper matching r.Method and
// returns the response packet. It is invoked by a worker for each
// packet/request; an unrecognised method yields an error status.
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
	switch method := r.Method; method {
	case "Get":
		return fileget(handlers.FileGet, r, pkt, alloc, orderID)

	case "Put":
		return fileput(handlers.FilePut, r, pkt, alloc, orderID)

	case "Open":
		return fileputget(handlers.FilePut, r, pkt, alloc, orderID)

	case "List":
		return filelist(handlers.FileList, r, pkt)

	case "Stat", "Lstat":
		return filestat(handlers.FileList, r, pkt)

	case "Readlink":
		// Prefer the dedicated Readlink implementation; otherwise fall back
		// to the generic stat-based path.
		rl, ok := handlers.FileList.(ReadlinkFileLister)
		if !ok {
			return filestat(handlers.FileList, r, pkt)
		}
		return readlink(rl, r, pkt)

	case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
		return filecmd(handlers.FileCmd, r, pkt)
	}

	return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
}
// open performs the extra set-up required for Open packets: it picks the
// request method from the open flags, asks the matching handler for a
// stream, stores that stream on the request and replies with the handle.
//
//   - any of Write/Append/Creat/Trunc together with Read, and a FilePut
//     handler implementing OpenFileWriter: method "Open" (read-write stream)
//   - any of Write/Append/Creat/Trunc otherwise: method "Put"
//   - Read only: method "Get"
//   - none of the above: "bad file flags" error status
func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
	flags := r.Pflags()
	id := pkt.id()

	handleReply := func() responsePacket {
		return &sshFxpHandlePacket{
			ID:     id,
			Handle: r.handle,
		}
	}

	wantsWrite := flags.Write || flags.Append || flags.Creat || flags.Trunc

	if wantsWrite {
		if flags.Read {
			if opener, ok := h.FilePut.(OpenFileWriter); ok {
				r.Method = "Open"
				stream, err := opener.OpenFile(r)
				if err != nil {
					return statusFromError(id, err)
				}
				r.setWriterAtReaderAt(stream)
				return handleReply()
			}
		}

		r.Method = "Put"
		writer, err := h.FilePut.Filewrite(r)
		if err != nil {
			return statusFromError(id, err)
		}
		r.setWriterAt(writer)
		return handleReply()
	}

	if !flags.Read {
		return statusFromError(id, errors.New("bad file flags"))
	}

	r.Method = "Get"
	reader, err := h.FileGet.Fileread(r)
	if err != nil {
		return statusFromError(id, err)
	}
	r.setReaderAt(reader)
	return handleReply()
}
// opendir prepares the request for directory listing: it switches the method
// to "List", obtains a lister from the FileList handler and stores it on the
// request. A handler error is wrapped with the request path and returned as
// a status packet; on success the request handle is returned.
func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
	r.Method = "List"

	lister, err := h.FileList.Filelist(r)
	if err != nil {
		return statusFromError(pkt.id(), wrapPathError(r.Filepath, err))
	}
	r.setListerAt(lister)

	reply := &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
	return reply
}
// fileget serves a read packet from the reader previously stored on the
// request. The handler argument h is not consulted here.
//
// io.EOF is only reported to the client when no bytes were read; a short
// read that ends at EOF is returned as ordinary data.
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
	reader := r.getReaderAt()
	if reader == nil {
		return statusFromError(pkt.id(), errors.New("unexpected read packet"))
	}

	buf, off, _ := packetData(pkt, alloc, orderID)
	n, err := reader.ReadAt(buf, off)

	switch {
	case err == nil:
		// full read
	case err == io.EOF && n != 0:
		// partial read that hit EOF: still deliver the bytes we got
	default:
		return statusFromError(pkt.id(), err)
	}

	return &sshFxpDataPacket{
		ID:     pkt.id(),
		Length: uint32(n),
		Data:   buf[:n],
	}
}
// fileput serves a write packet using the writer previously stored on the
// request and reports the outcome as a status packet. The handler argument
// h is not consulted here.
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
	writer := r.getWriterAt()
	if writer == nil {
		return statusFromError(pkt.id(), errors.New("unexpected write packet"))
	}

	payload, off, _ := packetData(pkt, alloc, orderID)
	_, writeErr := writer.WriteAt(payload, off)

	return statusFromError(pkt.id(), writeErr)
}
// fileputget serves read and write packets for a request opened in
// read-write mode, using the combined stream stored on the request.
// The handler argument h is not consulted here.
//
// Reads follow the same EOF rule as fileget: io.EOF is only reported when no
// data was read. Packets that are neither reads nor writes are rejected.
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
	stream := r.getWriterAtReaderAt()
	if stream == nil {
		return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
	}

	if rd, ok := pkt.(*sshFxpReadPacket); ok {
		buf := rd.getDataSlice(alloc, orderID)
		n, err := stream.ReadAt(buf, int64(rd.Offset))
		if err != nil && (err != io.EOF || n == 0) {
			return statusFromError(pkt.id(), err)
		}
		return &sshFxpDataPacket{
			ID:     pkt.id(),
			Length: uint32(n),
			Data:   buf[:n],
		}
	}

	if wr, ok := pkt.(*sshFxpWritePacket); ok {
		_, err := stream.WriteAt(wr.Data, int64(wr.Offset))
		return statusFromError(pkt.id(), err)
	}

	return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write"))
}
// packetData extracts the buffer, file offset and length carried by a read
// or write packet. For reads the buffer is obtained from the packet's data
// slice (via the allocator); for writes it is the packet payload.
// Any other packet type yields zero values.
func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
	if rd, ok := p.(*sshFxpReadPacket); ok {
		return rd.getDataSlice(alloc, orderID), int64(rd.Offset), rd.Len
	}
	if wr, ok := p.(*sshFxpWritePacket); ok {
		return wr.Data, int64(wr.Offset), wr.Length
	}
	return nil, 0, 0
}
// filecmd runs a command-style request through the FileCmder handler.
//
// Fsetstat packets first copy their flags and attributes onto the request.
// "StatVFS" requires the handler to implement StatVFSFileCmder and is
// otherwise answered with ErrSSHFxOpUnsupported. "PosixRename" uses
// PosixRenameFileCmder when available and is otherwise downgraded to a
// plain "Rename". Everything else goes straight to h.Filecmd.
func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
	if fset, ok := pkt.(*sshFxpFsetstatPacket); ok {
		r.Flags = fset.Flags
		r.Attrs = fset.Attrs.([]byte)
	}

	id := pkt.id()

	if r.Method == "StatVFS" {
		vfs, ok := h.(StatVFSFileCmder)
		if !ok {
			return statusFromError(id, ErrSSHFxOpUnsupported)
		}
		stat, err := vfs.StatVFS(r)
		if err != nil {
			return statusFromError(id, err)
		}
		stat.ID = id
		return stat
	}

	if r.Method == "PosixRename" {
		if renamer, ok := h.(PosixRenameFileCmder); ok {
			return statusFromError(id, renamer.PosixRename(r))
		}
		// PosixRenameFileCmder is not implemented: treat as a Rename.
		r.Method = "Rename"
	}

	return statusFromError(id, h.Filecmd(r))
}
// filelist serves one batch of directory entries from the lister stored on
// the request, starting at the request's current listing offset and
// advancing that offset by the number of entries returned.
//
// io.EOF is only reported when the batch is empty. Long names are produced
// by runLs, using h for name lookups when it implements
// NameLookupFileLister.
func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
	lister := r.getListerAt()
	if lister == nil {
		return statusFromError(pkt.id(), errors.New("unexpected dir packet"))
	}

	start := r.lsNext()
	entries := make([]os.FileInfo, MaxFilelist)
	count, err := lister.ListAt(entries, start)
	r.lsInc(int64(count))
	entries = entries[:count] // drop unused slots so no nil checks are needed

	if r.Method != "List" {
		return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
	}

	if err != nil && (err != io.EOF || count == 0) {
		return statusFromError(pkt.id(), err)
	}

	// A failed assertion leaves idLookup as an untyped nil, in which case
	// no names are looked up.
	idLookup, _ := h.(NameLookupFileLister)

	names := make([]*sshFxpNameAttr, len(entries))
	for i, fi := range entries {
		names[i] = &sshFxpNameAttr{
			Name:     fi.Name(),
			LongName: runLs(idLookup, fi),
			Attrs:    []any{fi},
		}
	}

	return &sshFxpNamePacket{
		ID:        pkt.id(),
		NameAttrs: names,
	}
}
// filestat answers Stat, Lstat and (as a fallback) Readlink requests by
// asking the handler for a lister and reading a single entry from it.
//
// For "Lstat" the handler's Lstat method is used when it implements
// LstatFileLister; otherwise the request is downgraded to "Stat". An empty
// listing is reported as ENOENT for the request path. Stat/Lstat reply with
// the entry's attributes, Readlink replies with the entry's name.
func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
	var (
		lister ListerAt
		err    error
	)

	lstater, hasLstat := h.(LstatFileLister)
	if r.Method == "Lstat" && hasLstat {
		lister, err = lstater.Lstat(r)
	} else {
		if r.Method == "Lstat" {
			// LstatFileLister is not implemented: handle as a Stat.
			r.Method = "Stat"
		}
		lister, err = h.Filelist(r)
	}
	if err != nil {
		return statusFromError(pkt.id(), err)
	}

	infos := make([]os.FileInfo, 1)
	n, err := lister.ListAt(infos, 0)
	infos = infos[:n] // drop unused slots so no nil checks are needed

	isReadlink := r.Method == "Readlink"
	if !isReadlink && r.Method != "Stat" && r.Method != "Lstat" {
		return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
	}

	if err != nil && err != io.EOF {
		return statusFromError(pkt.id(), err)
	}

	if n == 0 {
		op := "readlink"
		if !isReadlink {
			op = strings.ToLower(r.Method)
		}
		return statusFromError(pkt.id(), &os.PathError{
			Op:   op,
			Path: r.Filepath,
			Err:  syscall.ENOENT,
		})
	}

	if isReadlink {
		name := infos[0].Name()
		entry := &sshFxpNameAttr{
			Name:     name,
			LongName: name,
			Attrs:    emptyFileStat,
		}
		return &sshFxpNamePacket{
			ID:        pkt.id(),
			NameAttrs: []*sshFxpNameAttr{entry},
		}
	}

	return &sshFxpStatResponse{
		ID:   pkt.id(),
		info: infos[0],
	}
}
// readlink resolves the request path through the handler's Readlink method
// and replies with a name packet holding the resolved path as both the short
// and long name. A handler error is returned as a status packet.
func readlink(readlinkFileLister ReadlinkFileLister, r *Request, pkt requestPacket) responsePacket {
	target, err := readlinkFileLister.Readlink(r.Filepath)
	if err != nil {
		return statusFromError(pkt.id(), err)
	}

	entry := &sshFxpNameAttr{
		Name:     target,
		LongName: target,
		Attrs:    emptyFileStat,
	}

	return &sshFxpNamePacket{
		ID:        pkt.id(),
		NameAttrs: []*sshFxpNameAttr{entry},
	}
}
// requestMethod maps a packet type to the initial Request.Method string.
//
// Read, write, open, opendir and readdir packets deliberately map to the
// empty string: their method is assigned later by Request.open and
// Request.opendir. Packet types not listed here also yield "".
func requestMethod(p requestPacket) (method string) {
	switch p.(type) {
	case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
		return "Setstat"
	case *sshFxpRenamePacket:
		return "Rename"
	case *sshFxpSymlinkPacket:
		return "Symlink"
	case *sshFxpRemovePacket:
		return "Remove"
	case *sshFxpStatPacket, *sshFxpFstatPacket:
		return "Stat"
	case *sshFxpLstatPacket:
		return "Lstat"
	case *sshFxpRmdirPacket:
		return "Rmdir"
	case *sshFxpReadlinkPacket:
		return "Readlink"
	case *sshFxpMkdirPacket:
		return "Mkdir"
	case *sshFxpExtendedPacketHardlink:
		return "Link"
	}
	return ""
}
================================================
FILE: gossh/sftp/request_windows.go
================================================
package sftp
import (
"syscall"
)
// fakeFileInfoSys returns a zero-valued syscall.Win32FileAttributeData,
// the Windows stand-in for a FileInfo's Sys() value.
func fakeFileInfoSys() any {
	var attrs syscall.Win32FileAttributeData
	return attrs
}
// testOsSys is the Windows variant of the Sys() check; it accepts any value
// and always succeeds.
func testOsSys(sys any) error {
	_ = sys // nothing to validate on this platform
	return nil
}
================================================
FILE: gossh/sftp/server.go
================================================
package sftp
// sftp server counterpart
import (
"encoding"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"sync"
"syscall"
"time"
)
const (
	// SftpServerWorkerCount defines the number of workers for the SFTP server.
	// NOTE(review): the constant is not referenced in this file section;
	// presumably the packet manager uses it to size the worker pool — confirm.
	SftpServerWorkerCount = 8
)
// Server is an SSH File Transfer Protocol (sftp) server.
// This is intended to provide the sftp subsystem to an ssh server daemon.
// This implementation currently supports most of sftp server protocol version 3,
// as specified at https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
type Server struct {
	// serverConn wraps the underlying stream used to receive packets.
	*serverConn
	// debugStream receives diagnostic output (e.g. files left open when
	// Serve exits). NewServer points it at a discarding writer; WithDebug
	// overrides it.
	debugStream io.Writer
	// readOnly makes the workers answer write operations with EPERM.
	// It is enabled by the ReadOnly option.
	readOnly bool
	// pktMgr hands requests to the workers and queues their responses.
	pktMgr *packetManager
	// openFiles maps handle strings returned to the client to open files.
	openFiles map[string]*os.File
	// openFilesLock guards openFiles and handleCount.
	openFilesLock sync.RWMutex
	// handleCount is the counter used to generate the next handle string.
	handleCount int
	// workDir is the base directory for relative paths; empty means none
	// was configured. It is set by WithServerWorkingDirectory.
	workDir string
}
// nextHandle registers f in the open-file table under a freshly generated
// handle (the decimal form of an incrementing counter) and returns that
// handle. The table lock is held for the duration of the call.
func (svr *Server) nextHandle(f *os.File) string {
	svr.openFilesLock.Lock()
	defer svr.openFilesLock.Unlock()

	svr.handleCount++
	id := strconv.Itoa(svr.handleCount)
	svr.openFiles[id] = f

	return id
}
// closeHandle removes the file registered under handle from the open-file
// table and closes it, returning the Close error. An unknown handle yields
// EBADF.
func (svr *Server) closeHandle(handle string) error {
	svr.openFilesLock.Lock()
	defer svr.openFilesLock.Unlock()

	file, found := svr.openFiles[handle]
	if !found {
		return EBADF
	}

	delete(svr.openFiles, handle)
	return file.Close()
}
// getHandle looks up the open file registered under handle. The boolean
// result reports whether the handle is known. Only a read lock is taken.
func (svr *Server) getHandle(handle string) (*os.File, bool) {
	svr.openFilesLock.RLock()
	file, found := svr.openFiles[handle]
	svr.openFilesLock.RUnlock()

	return file, found
}
// serverRespondablePacket is implemented by request packets that know how to
// produce their own response against a Server. handlePacket uses it as the
// fallback for packet types it does not handle inline.
type serverRespondablePacket interface {
	encoding.BinaryUnmarshaler
	// id returns the request ID of the packet.
	id() uint32
	// respond executes the request against svr and returns the reply packet.
	respond(svr *Server) responsePacket
}
// NewServer builds a Server that reads requests from and writes responses to
// rwc, serving content from the root of the filesystem. Debug output is
// discarded unless an option says otherwise.
//
// Each ServerOption is applied in order; the first one that fails aborts
// construction and its error is returned.
//
// A subsequent call to Serve() is required to begin serving files over SFTP.
func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) {
	svrConn := &serverConn{
		conn: conn{
			Reader:      rwc,
			WriteCloser: rwc,
		},
	}

	srv := &Server{
		serverConn:  svrConn,
		debugStream: ioutil.Discard,
		pktMgr:      newPktMgr(svrConn),
		openFiles:   make(map[string]*os.File),
	}

	for i := range options {
		err := options[i](srv)
		if err != nil {
			return nil, err
		}
	}

	return srv, nil
}
// A ServerOption is a function which applies configuration to a Server.
// NewServer runs every option against the new Server and fails with the
// first error an option returns.
type ServerOption func(*Server) error
// WithDebug returns an option that directs the Server's debugging output to
// w instead of discarding it.
func WithDebug(w io.Writer) ServerOption {
	setDebug := func(s *Server) error {
		s.debugStream = w
		return nil
	}
	return setDebug
}
// ReadOnly returns an option that puts the Server into read-only mode.
func ReadOnly() ServerOption {
	setReadOnly := func(s *Server) error {
		s.readOnly = true
		return nil
	}
	return setReadOnly
}
// WithAllocator returns an option that enables the allocator.
// With it, slices allocated while processing a packet are kept in memory
// and reused for later packets. One allocator instance is shared between
// the packet manager and the connection.
// The allocator is experimental.
func WithAllocator() ServerOption {
	enable := func(s *Server) error {
		shared := newAllocator()
		s.pktMgr.alloc, s.conn.alloc = shared, shared
		return nil
	}
	return enable
}
// WithServerWorkingDirectory returns an option that sets the directory used
// as the base for relative paths. The directory is normalised with
// cleanPath before being stored.
// If unset the default is current working directory (os.Getwd).
func WithServerWorkingDirectory(workDir string) ServerOption {
	setWorkDir := func(s *Server) error {
		s.workDir = cleanPath(workDir)
		return nil
	}
	return setWorkDir
}
// rxPacket is a raw packet as received from the connection, before it is
// decoded by makePacket.
type rxPacket struct {
	// pktType is the packet type byte.
	pktType fxp
	// pktBytes is the undecoded packet body.
	pktBytes []byte
}
// sftpServerWorker drains pktChan, handling one request at a time; several
// of these run in parallel.
//
// When the server is read-only, requests that would modify the filesystem
// are answered with EPERM instead of being executed. A request counts as
// modifying when it implements notReadOnly, or when it is an open or
// extended packet whose readonly() method reports false.
//
// The worker stops with an error as soon as handlePacket fails, and returns
// nil once the channel is closed.
func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error {
	for req := range pktChan {
		mutates := false
		switch typed := req.requestPacket.(type) {
		case notReadOnly:
			mutates = true
		case *sshFxpOpenPacket:
			mutates = !typed.readonly()
		case *sshFxpExtendedPacket:
			mutates = !typed.readonly()
		}

		if mutates && svr.readOnly {
			svr.pktMgr.readyPacket(
				svr.pktMgr.newOrderedResponse(statusFromError(req.id(), syscall.EPERM), req.orderID()),
			)
			continue
		}

		err := handlePacket(svr, req)
		if err != nil {
			return err
		}
	}
	return nil
}
// handlePacket services a single request against the local filesystem and
// queues the reply with the packet manager under the request's order ID.
//
// Path-based operations map the client path through s.toLocalPath, and
// handle-based operations look the handle up in the server's open-file
// table (an unknown handle is answered with EBADF). Operation failures are
// reported to the client as status packets; the returned error is non-nil
// only when the packet type is not recognised, in which case no reply is
// queued.
func handlePacket(s *Server, p orderedRequest) error {
	var rpkt responsePacket
	orderID := p.orderID()
	switch p := p.requestPacket.(type) {
	case *sshFxInitPacket:
		// Protocol handshake: advertise our version and extensions.
		rpkt = &sshFxVersionPacket{
			Version:    sftpProtocolVersion,
			Extensions: sftpExtensions,
		}
	case *sshFxpStatPacket:
		// stat the requested file
		info, err := os.Stat(s.toLocalPath(p.Path))
		rpkt = &sshFxpStatResponse{
			ID:   p.ID,
			info: info,
		}
		// On failure the attrs reply built above is replaced by a status.
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpLstatPacket:
		// stat the requested file without following a final symlink
		info, err := os.Lstat(s.toLocalPath(p.Path))
		rpkt = &sshFxpStatResponse{
			ID:   p.ID,
			info: info,
		}
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpFstatPacket:
		f, ok := s.getHandle(p.Handle)
		// err stays EBADF unless the handle is found.
		var err error = EBADF
		var info os.FileInfo
		if ok {
			info, err = f.Stat()
			rpkt = &sshFxpStatResponse{
				ID:   p.ID,
				info: info,
			}
		}
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpMkdirPacket:
		// TODO FIXME: ignore flags field
		err := os.Mkdir(s.toLocalPath(p.Path), 0o755)
		rpkt = statusFromError(p.ID, err)
	case *sshFxpRmdirPacket:
		err := os.Remove(s.toLocalPath(p.Path))
		rpkt = statusFromError(p.ID, err)
	case *sshFxpRemovePacket:
		err := os.Remove(s.toLocalPath(p.Filename))
		rpkt = statusFromError(p.ID, err)
	case *sshFxpRenamePacket:
		err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath))
		rpkt = statusFromError(p.ID, err)
	case *sshFxpSymlinkPacket:
		err := os.Symlink(s.toLocalPath(p.Targetpath), s.toLocalPath(p.Linkpath))
		rpkt = statusFromError(p.ID, err)
	case *sshFxpClosePacket:
		rpkt = statusFromError(p.ID, s.closeHandle(p.Handle))
	case *sshFxpReadlinkPacket:
		f, err := os.Readlink(s.toLocalPath(p.Path))
		rpkt = &sshFxpNamePacket{
			ID: p.ID,
			NameAttrs: []*sshFxpNameAttr{
				{
					Name:     f,
					LongName: f,
					Attrs:    emptyFileStat,
				},
			},
		}
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpRealpathPacket:
		// Resolve to an absolute local path, then normalise it with cleanPath.
		f, err := filepath.Abs(s.toLocalPath(p.Path))
		f = cleanPath(f)
		rpkt = &sshFxpNamePacket{
			ID: p.ID,
			NameAttrs: []*sshFxpNameAttr{
				{
					Name:     f,
					LongName: f,
					Attrs:    emptyFileStat,
				},
			},
		}
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpOpendirPacket:
		lp := s.toLocalPath(p.Path)
		if stat, err := os.Stat(lp); err != nil {
			rpkt = statusFromError(p.ID, err)
		} else if !stat.IsDir() {
			rpkt = statusFromError(p.ID, &os.PathError{
				Path: lp, Err: syscall.ENOTDIR,
			})
		} else {
			// A directory is opened like a file opened for reading; the
			// resulting handle is later consumed by readdir requests.
			rpkt = (&sshFxpOpenPacket{
				ID:     p.ID,
				Path:   p.Path,
				Pflags: sshFxfRead,
			}).respond(s)
		}
	case *sshFxpReadPacket:
		var err error = EBADF
		f, ok := s.getHandle(p.Handle)
		if ok {
			err = nil
			data := p.getDataSlice(s.pktMgr.alloc, orderID)
			n, _err := f.ReadAt(data, int64(p.Offset))
			// EOF is only an error when nothing at all was read.
			if _err != nil && (_err != io.EOF || n == 0) {
				err = _err
			}
			rpkt = &sshFxpDataPacket{
				ID:     p.ID,
				Length: uint32(n),
				Data:   data[:n],
				// do not use data[:n:n] here to clamp the capacity, we allocated extra capacity above to avoid reallocations
			}
		}
		if err != nil {
			rpkt = statusFromError(p.ID, err)
		}
	case *sshFxpWritePacket:
		f, ok := s.getHandle(p.Handle)
		var err error = EBADF
		if ok {
			_, err = f.WriteAt(p.Data, int64(p.Offset))
		}
		rpkt = statusFromError(p.ID, err)
	case *sshFxpExtendedPacket:
		// An extended packet without a decoded specific packet is an
		// extension this server does not support.
		if p.SpecificPacket == nil {
			rpkt = statusFromError(p.ID, ErrSSHFxOpUnsupported)
		} else {
			rpkt = p.respond(s)
		}
	case serverRespondablePacket:
		// Remaining packet types implement their own respond method.
		rpkt = p.respond(s)
	default:
		return fmt.Errorf("unexpected packet type %T", p)
	}
	s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID))
	return nil
}
// Serve serves SFTP connections until the streams stop or the SFTP subsystem
// is stopped. It returns nil if the server exits cleanly.
//
// Serve reads packets from the connection, decodes them and hands them to
// the worker goroutines started through the packet manager. It returns once
// reading fails or a packet cannot be decoded: the worker channel is closed,
// all workers are awaited, and any files still open are closed (their names
// are written to the debug stream).
//
// The returned error is the one that stopped the receive loop; a clean
// io.EOF between packets is reported as nil.
func (svr *Server) Serve() error {
	defer func() {
		if svr.pktMgr.alloc != nil {
			svr.pktMgr.alloc.Free()
		}
	}()
	var wg sync.WaitGroup
	runWorker := func(ch chan orderedRequest) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := svr.sftpServerWorker(ch); err != nil {
				svr.conn.Close() // shuts down recvPacket
			}
		}()
	}
	pktChan := svr.pktMgr.workerChan(runWorker)

	var err error
	var pkt requestPacket
	var pktType uint8
	var pktBytes []byte

recvLoop:
	for {
		pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
		if err != nil {
			// Check whether the connection terminated cleanly in-between packets.
			if err == io.EOF {
				err = nil
			}
			// we don't care about releasing allocated pages here, the server will quit and the allocator freed
			break
		}

		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
		if err != nil {
			switch {
			case errors.Is(err, errUnknownExtendedPacket):
				// Unknown extensions are not fatal: the packet is still
				// forwarded so that a worker can answer it.
				//if err := svr.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil {
				//	debug("failed to send err packet: %v", err)
				//	svr.conn.Close() // shuts down recvPacket
				//	break
				//}
			default:
				debug("makePacket err: %v", err)
				svr.conn.Close() // shuts down recvPacket
				// A plain break would only leave the switch and the
				// undecodable packet would still be dispatched below;
				// leave the receive loop instead.
				break recvLoop
			}
		}

		pktChan <- svr.pktMgr.newOrderedRequest(pkt)
	}

	close(pktChan) // shuts down sftpServerWorkers
	wg.Wait()      // wait for all workers to exit

	// close any still-open files
	for handle, file := range svr.openFiles {
		fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name())
		file.Close()
	}
	return err // error from recvPacket or makePacket
}
// ider is implemented by packets that expose a request ID.
type ider interface {
	// id returns the packet's request ID.
	id() uint32
}
// id returns 0: the init packet carries no request ID, so the zero value
// is used in its place.
func (p *sshFxInitPacket) id() uint32 { return 0 }
// sshFxpStatResponse is the reply to stat-style requests; it is marshalled
// as an SSH_FXP_ATTRS packet.
type sshFxpStatResponse struct {
	// ID is the request ID being answered.
	ID uint32
	// info holds the file attributes to send.
	info os.FileInfo
}
// marshalPacket encodes the response as a header and a payload.
//
// The header consists of four bytes reserved for the length, the
// SSH_FXP_ATTRS type byte and the request ID; the payload is the marshalled
// file info. The error result is always nil.
func (p *sshFxpStatResponse) marshalPacket() ([]byte, []byte, error) {
	// uint32(length) + byte(type) + uint32(id)
	const headerLen = 4 + 1 + 4

	header := make([]byte, 4, headerLen)
	header = append(header, sshFxpAttrs)
	header = marshalUint32(header, p.ID)

	payload := marshalFileInfo(nil, p.info)

	return header, payload, nil
}
// MarshalBinary encodes the response as a single byte slice: the header
// produced by marshalPacket followed by its payload.
func (p *sshFxpStatResponse) MarshalBinary() ([]byte, error) {
	header, payload, err := p.marshalPacket()
	packet := append(header, payload...)
	return packet, err
}
// emptyFileStat is the attribute list used for name entries that carry no
// file attributes (realpath and readlink replies): a single zero uint32.
var emptyFileStat = []any{uint32(0)}
// readonly reports whether serving this open request leaves the filesystem
// unmodified, i.e. whether it may be allowed on a read-only server.
//
// Besides SSH_FXF_WRITE, the CREAT and TRUNC flags also count as modifying:
// respond passes them on to os.OpenFile even when the file is opened
// O_RDONLY, so a request without the write flag could otherwise still
// create or truncate a file.
func (p *sshFxpOpenPacket) readonly() bool {
	const mutating = sshFxfWrite | sshFxfCreat | sshFxfTrunc
	return p.Pflags&mutating == 0
}
// hasPflags reports whether every one of the given flags has at least one
// bit in common with the packet's Pflags. With no flags it returns true.
func (p *sshFxpOpenPacket) hasPflags(flags ...uint32) bool {
	for i := range flags {
		if p.Pflags&flags[i] == 0 {
			return false
		}
	}
	return true
}
// respond opens the requested file on the local filesystem and replies with
// a new handle for it, or with a status packet on failure.
//
// The SFTP open flags are translated to os.OpenFile flags; a request with
// neither the read nor the write flag is rejected with EINVAL. The file mode
// defaults to 0644 and is replaced by the permission bits from the packet's
// attributes when those are present.
func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
	var osFlags int
	switch {
	case p.hasPflags(sshFxfRead, sshFxfWrite):
		osFlags = os.O_RDWR
	case p.hasPflags(sshFxfWrite):
		osFlags = os.O_WRONLY
	case p.hasPflags(sshFxfRead):
		osFlags = os.O_RDONLY
	default:
		// how are they opening?
		return statusFromError(p.ID, syscall.EINVAL)
	}

	// Don't use O_APPEND flag as it conflicts with WriteAt.
	// The sshFxfAppend flag is a no-op here as the client sends the offsets.
	optional := []struct {
		pflag  uint32
		osFlag int
	}{
		{sshFxfCreat, os.O_CREATE},
		{sshFxfTrunc, os.O_TRUNC},
		{sshFxfExcl, os.O_EXCL},
	}
	for _, opt := range optional {
		if p.hasPflags(opt.pflag) {
			osFlags |= opt.osFlag
		}
	}

	// Like OpenSSH, we only handle permissions here, and only when the file
	// is being created. Otherwise, the permissions are ignored.
	perm := os.FileMode(0o644)
	if p.Flags&sshFileXferAttrPermissions != 0 {
		attrs, err := p.unmarshalFileStat(p.Flags)
		if err != nil {
			return statusFromError(p.ID, err)
		}
		perm = attrs.FileMode() & os.ModePerm
	}

	file, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, perm)
	if err != nil {
		return statusFromError(p.ID, err)
	}

	return &sshFxpHandlePacket{ID: p.ID, Handle: svr.nextHandle(file)}
}
// respond reads up to 128 entries from the directory behind the packet's
// handle and replies with a name packet describing them. An unknown handle
// yields EBADF; a Readdir error is returned as a status packet.
func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
	dir, ok := svr.getHandle(p.Handle)
	if !ok {
		return statusFromError(p.ID, EBADF)
	}

	entries, err := dir.Readdir(128)
	if err != nil {
		return statusFromError(p.ID, err)
	}

	lookup := osIDLookup{}

	var names []*sshFxpNameAttr
	for i := range entries {
		entry := entries[i]
		names = append(names, &sshFxpNameAttr{
			Name:     entry.Name(),
			LongName: runLs(lookup, entry),
			Attrs:    []any{entry},
		})
	}

	return &sshFxpNamePacket{ID: p.ID, NameAttrs: names}
}
// respond applies the attributes carried by the packet to the file at the
// packet's path and replies with the resulting status.
//
// Attributes are applied in the order size, permissions, owner/group and
// access/modification times, each only when its flag is set. Processing
// stops at the first error.
func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
	target := svr.toLocalPath(p.Path)
	debug("setstat name %q", target)

	attrs, err := p.unmarshalFileStat(p.Flags)
	if err != nil {
		return statusFromError(p.ID, err)
	}

	if p.Flags&sshFileXferAttrSize != 0 {
		err = os.Truncate(target, int64(attrs.Size))
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrPermissions != 0 {
		err = os.Chmod(target, attrs.FileMode())
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrUIDGID != 0 {
		err = os.Chown(target, int(attrs.UID), int(attrs.GID))
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrACmodTime != 0 {
		err = os.Chtimes(target, attrs.AccessTime(), attrs.ModTime())
	}

	return statusFromError(p.ID, err)
}
// respond applies the attributes carried by the packet to the open file
// behind the packet's handle and replies with the resulting status.
// An unknown handle yields EBADF.
//
// Attributes are applied in the order size, permissions, owner/group and
// access/modification times, each only when its flag is set. Processing
// stops at the first error. Times are set through the file itself when it
// offers a Chtimes method, and by file name otherwise.
func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
	file, ok := svr.getHandle(p.Handle)
	if !ok {
		return statusFromError(p.ID, EBADF)
	}

	name := file.Name()
	debug("fsetstat name %q", name)

	attrs, err := p.unmarshalFileStat(p.Flags)
	if err != nil {
		return statusFromError(p.ID, err)
	}

	if p.Flags&sshFileXferAttrSize != 0 {
		err = file.Truncate(int64(attrs.Size))
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrPermissions != 0 {
		err = file.Chmod(attrs.FileMode())
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrUIDGID != 0 {
		err = file.Chown(int(attrs.UID), int(attrs.GID))
		if err != nil {
			return statusFromError(p.ID, err)
		}
	}

	if p.Flags&sshFileXferAttrACmodTime != 0 {
		type chtimer interface {
			Chtimes(atime, mtime time.Time) error
		}

		if ct, ok := any(file).(chtimer); ok {
			// future-compatible, for when/if *os.File supports Chtimes.
			err = ct.Chtimes(attrs.AccessTime(), attrs.ModTime())
		} else {
			err = os.Chtimes(name, attrs.AccessTime(), attrs.ModTime())
		}
	}

	return statusFromError(p.ID, err)
}
// statusFromError builds the status packet for request id that reports err.
//
// A nil err produces an OK status. Otherwise the message is err.Error() and
// the code is chosen by the first matching rule:
//
//  1. os.IsNotExist(err)              -> no such file
//  2. translateSyscallError succeeds  -> the translated code
//  3. err wraps io.EOF                -> EOF
//  4. err wraps an fxerr              -> that fxerr's value
//  5. anything else                   -> failure
func statusFromError(id uint32, err error) *sshFxpStatusPacket {
	pkt := &sshFxpStatusPacket{
		ID:          id,
		StatusError: StatusError{Code: sshFxOk},
	}
	if err == nil {
		return pkt
	}

	debug("statusFromError: error is %T %#v", err, err)

	status := &pkt.StatusError
	status.Code = sshFxFailure
	status.msg = err.Error()

	var asFx fxerr
	if os.IsNotExist(err) {
		status.Code = sshFxNoSuchFile
	} else if code, ok := translateSyscallError(err); ok {
		status.Code = code
	} else if errors.Is(err, io.EOF) {
		status.Code = sshFxEOF
	} else if errors.As(err, &asFx) {
		status.Code = uint32(asFx)
	}

	return pkt
}
================================================
FILE: gossh/sftp/server_plan9.go
================================================
package sftp
import (
"path"
"path/filepath"
)
// toLocalPath converts a slash-separated client path into a local path.
//
// A relative path is first joined onto the configured working directory,
// if any. For an absolute path, when the local form without its leading
// character is itself absolute, that stripped form is returned
// (e.g. "/#s/boot" becomes "#s/boot").
func (s *Server) toLocalPath(p string) string {
	if s.workDir != "" && !path.IsAbs(p) {
		p = path.Join(s.workDir, p)
	}

	lp := filepath.FromSlash(p)
	if !path.IsAbs(p) {
		return lp
	}

	// The path was encoded with a '/' prefix in front of a local absolute path.
	if stripped := lp[1:]; filepath.IsAbs(stripped) {
		return stripped
	}
	return lp
}
================================================
FILE: gossh/sftp/server_statvfs_darwin.go
================================================
package sftp
import (
"syscall"
)
// statvfsFromStatfst converts the Darwin syscall.Statfs_t into the StatVFS
// reply structure. The error result is always nil.
//
// Darwin has no separate fragment size or count of files available to
// unprivileged users, so Frsize mirrors Bsize and Favail mirrors Ffree.
func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
	// The filesystem ID is stored as two 32-bit halves. Each half is
	// converted through uint32 first: converting a negative signed value
	// directly to uint64 sign-extends it, and a negative low half would
	// then overwrite the high half with all ones.
	fsid := uint64(uint32(stat.Fsid.Val[1]))<<32 | uint64(uint32(stat.Fsid.Val[0])) // endianness?

	return &StatVFS{
		Bsize:   uint64(stat.Bsize),
		Frsize:  uint64(stat.Bsize), // fragment size is a linux thing; use block size here
		Blocks:  stat.Blocks,
		Bfree:   stat.Bfree,
		Bavail:  stat.Bavail,
		Files:   stat.Files,
		Ffree:   stat.Ffree,
		Favail:  stat.Ffree, // not sure how to calculate Favail
		Fsid:    fsid,
		Flag:    uint64(stat.Flags), // assuming POSIX?
		Namemax: 1024,               // man 2 statfs shows: #define MAXPATHLEN 1024
	}, nil
}
================================================
FILE: gossh/sftp/server_statvfs_impl.go
================================================
//go:build darwin || linux
// +build darwin linux
// fill in statvfs structure with OS specific values
// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance)
package sftp
import (
"syscall"
)
// respond answers a statvfs extended request with the filesystem statistics
// for the packet's path, or with a status packet when they cannot be read.
func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
	vfs, err := getStatVFSForPath(p.Path)
	if err == nil {
		vfs.ID = p.ID
		return vfs
	}
	return statusFromError(p.ID, err)
}
// getStatVFSForPath runs statfs on name and converts the result into a
// StatVFS. The statfs error is returned unchanged.
func getStatVFSForPath(name string) (*StatVFS, error) {
	stat := new(syscall.Statfs_t)

	err := syscall.Statfs(name, stat)
	if err != nil {
		return nil, err
	}

	return statvfsFromStatfst(stat)
}
================================================
FILE: gossh/sftp/server_statvfs_linux.go
================================================
//go:build linux
// +build linux
package sftp
import (
"syscall"
)
// statvfsFromStatfst converts the Linux syscall.Statfs_t into the StatVFS
// reply structure. Favail mirrors Ffree, and the filesystem ID is left at
// zero. The error result is always nil.
func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
	vfs := new(StatVFS)

	vfs.Bsize = uint64(stat.Bsize)
	vfs.Frsize = uint64(stat.Frsize)

	vfs.Blocks = stat.Blocks
	vfs.Bfree = stat.Bfree
	vfs.Bavail = stat.Bavail

	vfs.Files = stat.Files
	vfs.Ffree = stat.Ffree
	vfs.Favail = stat.Ffree // not sure how to calculate Favail

	vfs.Flag = uint64(stat.Flags) // assuming POSIX?
	vfs.Namemax = uint64(stat.Namelen)

	return vfs, nil
}
================================================
FILE: gossh/sftp/server_statvfs_plan9.go
================================================
package sftp
import (
"syscall"
)
// respond rejects statvfs requests on Plan 9 with an EPLAN9 status.
func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
	status := statusFromError(p.ID, syscall.EPLAN9)
	return status
}
// getStatVFSForPath always fails with EPLAN9 on Plan 9; name is ignored.
func getStatVFSForPath(name string) (*StatVFS, error) {
	var none *StatVFS
	return none, syscall.EPLAN9
}
================================================
FILE: gossh/sftp/server_statvfs_stubs.go
================================================
//go:build !darwin && !linux && !plan9
// +build !darwin,!linux,!plan9
package sftp
import (
"syscall"
)
// respond rejects statvfs requests on platforms without an implementation
// with an ENOTSUP status.
func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
	status := statusFromError(p.ID, syscall.ENOTSUP)
	return status
}
// getStatVFSForPath always fails with ENOTSUP on platforms without an
// implementation; name is ignored.
func getStatVFSForPath(name string) (*StatVFS, error) {
	var none *StatVFS
	return none, syscall.ENOTSUP
}
================================================
FILE: gossh/sftp/server_unix.go
================================================
//go:build !windows && !plan9
// +build !windows,!plan9
package sftp
import (
"path"
)
// toLocalPath converts a client path into a local path. Absolute paths are
// returned unchanged, as are all paths when no working directory is
// configured; relative paths are otherwise joined onto the working
// directory.
func (s *Server) toLocalPath(p string) string {
	if s.workDir == "" || path.IsAbs(p) {
		return p
	}
	return path.Join(s.workDir, p)
}
================================================
FILE: gossh/sftp/server_windows.go
================================================
package sftp
import (
"path"
"path/filepath"
)
// toLocalPath converts a slash-separated client path into a Windows path.
//
// A relative path is first joined onto the configured working directory,
// if any. An absolute client path may be a Windows path encoded behind a
// leading '/': after dropping the leading separators, the remainder is
// returned when it is absolute on its own ("/C:/Windows" -> "C:\\Windows")
// or becomes absolute once a trailing separator is added ("/C:" -> "C:\\").
// In every other case the plain separator-converted path is returned.
func (s *Server) toLocalPath(p string) string {
	if s.workDir != "" && !path.IsAbs(p) {
		p = path.Join(s.workDir, p)
	}

	lp := filepath.FromSlash(p)
	if !path.IsAbs(p) {
		return lp
	}

	// Skip every leading backslash.
	start := 0
	for start < len(lp) && lp[start] == '\\' {
		start++
	}
	trimmed := lp[start:]

	if filepath.IsAbs(trimmed) {
		return trimmed
	}

	// A bare drive such as "C:" only becomes absolute with a trailing
	// separator, which the client-side encoding dropped.
	if withSep := trimmed + "\\"; filepath.IsAbs(withSep) {
		return withSep
	}

	return lp
}
================================================
FILE: gossh/sftp/sftp.go
================================================
// Package sftp implements the SSH File Transfer Protocol as described in
// https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
package sftp
import (
"fmt"
)
// SFTP packet type identifiers (the SSH_FXP_* values). fxp.String maps each
// of them to its protocol name.
const (
	// Requests sent by the client (plus the version reply, 2).
	sshFxpInit     = 1
	sshFxpVersion  = 2
	sshFxpOpen     = 3
	sshFxpClose    = 4
	sshFxpRead     = 5
	sshFxpWrite    = 6
	sshFxpLstat    = 7
	sshFxpFstat    = 8
	sshFxpSetstat  = 9
	sshFxpFsetstat = 10
	sshFxpOpendir  = 11
	sshFxpReaddir  = 12
	sshFxpRemove   = 13
	sshFxpMkdir    = 14
	sshFxpRmdir    = 15
	sshFxpRealpath = 16
	sshFxpStat     = 17
	sshFxpRename   = 18
	sshFxpReadlink = 19
	sshFxpSymlink  = 20

	// Response packet types.
	sshFxpStatus = 101
	sshFxpHandle = 102
	sshFxpData   = 103
	sshFxpName   = 104
	sshFxpAttrs  = 105

	// Extension request and reply.
	sshFxpExtended      = 200
	sshFxpExtendedReply = 201
)
// Status codes (the SSH_FX_* values) carried in StatusError.Code.
// fx.String names the codes 0 through 8 only.
const (
	sshFxOk               = 0
	sshFxEOF              = 1
	sshFxNoSuchFile       = 2
	sshFxPermissionDenied = 3
	sshFxFailure          = 4
	sshFxBadMessage       = 5
	sshFxNoConnection     = 6
	sshFxConnectionLost   = 7
	sshFxOPUnsupported    = 8

	// The codes below come from a later protocol draft:
	// see draft-ietf-secsh-filexfer-13
	// https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1
	sshFxInvalidHandle           = 9
	sshFxNoSuchPath              = 10
	sshFxFileAlreadyExists       = 11
	sshFxWriteProtect            = 12
	sshFxNoMedia                 = 13
	sshFxNoSpaceOnFilesystem     = 14
	sshFxQuotaExceeded           = 15
	sshFxUnknownPrincipal        = 16
	sshFxLockConflict            = 17
	sshFxDirNotEmpty             = 18
	sshFxNotADirectory           = 19
	sshFxInvalidFilename         = 20
	sshFxLinkLoop                = 21
	sshFxCannotDelete            = 22
	sshFxInvalidParameter        = 23
	sshFxFileIsADirectory        = 24
	sshFxByteRangeLockConflict   = 25
	sshFxByteRangeLockRefused    = 26
	sshFxDeletePending           = 27
	sshFxFileCorrupt             = 28
	sshFxOwnerInvalid            = 29
	sshFxGroupInvalid            = 30
	sshFxNoMatchingByteRangeLock = 31
)
// Open flags (the SSH_FXF_* bit values) found in an open packet's Pflags.
// They are bit masks and may be combined.
const (
	sshFxfRead   = 0x00000001
	sshFxfWrite  = 0x00000002
	sshFxfAppend = 0x00000004
	sshFxfCreat  = 0x00000008
	sshFxfTrunc  = 0x00000010
	sshFxfExcl   = 0x00000020
)
var (
	// supportedSFTPExtensions defines the supported extensions
	// as name/version pairs.
	supportedSFTPExtensions = []sshExtensionPair{
		{"hardlink@openssh.com", "1"},
		{"posix-rename@openssh.com", "1"},
		{"statvfs@openssh.com", "2"},
	}
	// sftpExtensions is the list advertised in the version packet sent in
	// reply to an init packet. It starts out as the full supported list.
	sftpExtensions = supportedSFTPExtensions
)
// fxp is an SFTP packet type.
type fxp uint8

// fxpPacketNames maps each known packet type to its protocol name.
var fxpPacketNames = map[fxp]string{
	sshFxpInit:          "SSH_FXP_INIT",
	sshFxpVersion:       "SSH_FXP_VERSION",
	sshFxpOpen:          "SSH_FXP_OPEN",
	sshFxpClose:         "SSH_FXP_CLOSE",
	sshFxpRead:          "SSH_FXP_READ",
	sshFxpWrite:         "SSH_FXP_WRITE",
	sshFxpLstat:         "SSH_FXP_LSTAT",
	sshFxpFstat:         "SSH_FXP_FSTAT",
	sshFxpSetstat:       "SSH_FXP_SETSTAT",
	sshFxpFsetstat:      "SSH_FXP_FSETSTAT",
	sshFxpOpendir:       "SSH_FXP_OPENDIR",
	sshFxpReaddir:       "SSH_FXP_READDIR",
	sshFxpRemove:        "SSH_FXP_REMOVE",
	sshFxpMkdir:         "SSH_FXP_MKDIR",
	sshFxpRmdir:         "SSH_FXP_RMDIR",
	sshFxpRealpath:      "SSH_FXP_REALPATH",
	sshFxpStat:          "SSH_FXP_STAT",
	sshFxpRename:        "SSH_FXP_RENAME",
	sshFxpReadlink:      "SSH_FXP_READLINK",
	sshFxpSymlink:       "SSH_FXP_SYMLINK",
	sshFxpStatus:        "SSH_FXP_STATUS",
	sshFxpHandle:        "SSH_FXP_HANDLE",
	sshFxpData:          "SSH_FXP_DATA",
	sshFxpName:          "SSH_FXP_NAME",
	sshFxpAttrs:         "SSH_FXP_ATTRS",
	sshFxpExtended:      "SSH_FXP_EXTENDED",
	sshFxpExtendedReply: "SSH_FXP_EXTENDED_REPLY",
}

// String returns the protocol name of the packet type, or "unknown" for a
// value that is not a known packet type.
func (f fxp) String() string {
	if name, known := fxpPacketNames[f]; known {
		return name
	}
	return "unknown"
}
// fx is an SFTP status code (one of the sshFx* constants).
type fx uint8

// fxNames maps the status codes that have a printable name to that name.
var fxNames = map[fx]string{
	sshFxOk:               "SSH_FX_OK",
	sshFxEOF:              "SSH_FX_EOF",
	sshFxNoSuchFile:       "SSH_FX_NO_SUCH_FILE",
	sshFxPermissionDenied: "SSH_FX_PERMISSION_DENIED",
	sshFxFailure:          "SSH_FX_FAILURE",
	sshFxBadMessage:       "SSH_FX_BAD_MESSAGE",
	sshFxNoConnection:     "SSH_FX_NO_CONNECTION",
	sshFxConnectionLost:   "SSH_FX_CONNECTION_LOST",
	sshFxOPUnsupported:    "SSH_FX_OP_UNSUPPORTED",
}

// String returns the protocol name of the status code, or "unknown" for any
// code without an entry in the table.
func (f fx) String() string {
	if name, known := fxNames[f]; known {
		return name
	}
	return "unknown"
}
// unexpectedPacketErr records a packet type mismatch: the type that was
// expected and the type that was actually received.
type unexpectedPacketErr struct {
	want uint8
	got  uint8
}

// Error formats both packet types using their fxp names.
func (u *unexpectedPacketErr) Error() string {
	expected, received := fxp(u.want), fxp(u.got)
	return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", expected, received)
}
// unimplementedPacketErr builds an error naming a packet type that has no
// implementation; the type is rendered using its fxp name.
func unimplementedPacketErr(u uint8) error {
	packetType := fxp(u)
	return fmt.Errorf("sftp: unimplemented packet type: got %v", packetType)
}
// unexpectedIDErr records an id mismatch: the id that was expected and the
// id that was actually seen.
type unexpectedIDErr struct {
	want uint32
	got  uint32
}

// Error reports both ids in decimal.
func (u *unexpectedIDErr) Error() string {
	const layout = "sftp: unexpected id: want %d, got %d"
	return fmt.Sprintf(layout, u.want, u.got)
}
// unimplementedSeekWhence builds an error for a seek whence value that is
// not handled.
func unimplementedSeekWhence(whence int) error {
	const layout = "sftp: unimplemented seek whence %d"
	return fmt.Errorf(layout, whence)
}
// unexpectedCount builds an error reporting a count that differs from the
// expected one.
func unexpectedCount(want, got uint32) error {
	const layout = "sftp: unexpected count: want %d, got %d"
	return fmt.Errorf(layout, want, got)
}
// unexpectedVersionErr records a protocol version mismatch: the version that
// was expected and the one the server reported.
type unexpectedVersionErr struct {
	want uint32
	got  uint32
}

// Error reports both versions.
func (u *unexpectedVersionErr) Error() string {
	const layout = "sftp: unexpected server version: want %v, got %v"
	return fmt.Sprintf(layout, u.want, u.got)
}
// A StatusError is returned when an SFTP operation fails, and provides
// additional information about the failure.
type StatusError struct {
	// Code is the numeric status code of the failure.
	Code uint32
	msg  string
	lang string
}

// Error renders the quoted message followed by the name of the status code.
// The code is converted through fx (a uint8), so only its low 8 bits select
// the printed name.
func (s *StatusError) Error() string {
	code := fx(s.Code)
	return fmt.Sprintf("sftp: %q (%v)", s.msg, code)
}

// FxCode returns the error code typed to match against the exported codes
func (s *StatusError) FxCode() fxerr {
	code := fxerr(s.Code)
	return code
}
// getSupportedExtensionByName looks extensionName up in
// supportedSFTPExtensions and returns the matching pair. An error is
// returned, together with a zero pair, when no entry has that name.
func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) {
	for i := range supportedSFTPExtensions {
		if candidate := supportedSFTPExtensions[i]; candidate.Name == extensionName {
			return candidate, nil
		}
	}
	return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName)
}
// SetSFTPExtensions replaces the list of enabled server extensions.
// Each argument is an extension name such as 'hardlink@openssh.com' and must
// be one of the entries of supportedSFTPExtensions.
// If any name is not supported an error is returned and the enabled list is
// left untouched.
func SetSFTPExtensions(extensions ...string) error {
	selected := make([]sshExtensionPair, 0, len(extensions))
	for i := range extensions {
		pair, err := getSupportedExtensionByName(extensions[i])
		if err != nil {
			return err
		}
		selected = append(selected, pair)
	}
	// Only commit once every requested name has been validated.
	sftpExtensions = selected
	return nil
}
================================================
FILE: gossh/sftp/stat_plan9.go
================================================
package sftp
import (
"os"
"syscall"
)
// EBADF is the Plan 9 error value for a bad file descriptor. It is built from
// an error string via syscall.NewError.
var EBADF = syscall.NewError("fd out of range or not open")
// wrapPathError wraps err in an *os.PathError carrying filepath when err is a
// syscall.ErrorString; any other error (including nil) is returned unchanged.
// The Op field of the resulting PathError is left empty.
func wrapPathError(filepath string, err error) error {
	errno, isErrno := err.(syscall.ErrorString)
	if !isErrno {
		return err
	}
	return &os.PathError{Path: filepath, Err: errno}
}
// translateErrno translates a syscall error string to a SFTP error code.
// The empty string maps to sshFxOk, ENOENT to sshFxNoSuchFile and EPERM to
// sshFxPermissionDenied; everything else becomes sshFxFailure.
func translateErrno(errno syscall.ErrorString) uint32 {
	if errno == "" {
		return sshFxOk
	}
	if errno == syscall.ENOENT {
		return sshFxNoSuchFile
	}
	if errno == syscall.EPERM {
		return sshFxPermissionDenied
	}
	return sshFxFailure
}
// translateSyscallError maps err to an SFTP status code when err is a
// syscall.ErrorString, or an *os.PathError wrapping one. The boolean result
// is false (and the code 0) when err is neither.
func translateSyscallError(err error) (uint32, bool) {
	if errno, ok := err.(syscall.ErrorString); ok {
		return translateErrno(errno), true
	}
	if pathErr, ok := err.(*os.PathError); ok {
		debug("statusFromError,pathError: error is %T %#v", pathErr.Err, pathErr.Err)
		if errno, ok := pathErr.Err.(syscall.ErrorString); ok {
			return translateErrno(errno), true
		}
	}
	return 0, false
}
// isRegular reports whether the file-type bits of mode (those selected by
// S_IFMT) describe a regular file.
func isRegular(mode uint32) bool {
	fileType := mode & S_IFMT
	return fileType == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification.
// The low nine permission bits are copied and the file-type bits selected by
// S_IFMT are translated to the matching os.Mode* flags. Unrecognised type
// bits yield the permission bits alone.
func toFileMode(mode uint32) os.FileMode {
	perm := os.FileMode(mode & 0777)
	switch fileType := mode & S_IFMT; fileType {
	case syscall.S_IFDIR:
		return perm | os.ModeDir
	case syscall.S_IFLNK:
		return perm | os.ModeSymlink
	case syscall.S_IFBLK:
		return perm | os.ModeDevice
	case syscall.S_IFCHR:
		return perm | os.ModeDevice | os.ModeCharDevice
	case syscall.S_IFIFO:
		return perm | os.ModeNamedPipe
	case syscall.S_IFSOCK:
		return perm | os.ModeSocket
	case syscall.S_IFREG:
		// regular files carry no extra mode flag
		return perm
	}
	return perm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode
// bits. The permission bits are copied and the os.ModeType bits are
// translated to the matching S_IF* value; a mode with no type bits is a
// regular file. Unrecognised type combinations yield the permission bits
// alone.
func fromFileMode(mode os.FileMode) uint32 {
	perm := uint32(mode & os.ModePerm)
	switch fileType := mode & os.ModeType; fileType {
	case 0:
		return perm | syscall.S_IFREG
	case os.ModeDir:
		return perm | syscall.S_IFDIR
	case os.ModeSymlink:
		return perm | syscall.S_IFLNK
	case os.ModeDevice:
		return perm | syscall.S_IFBLK
	case os.ModeDevice | os.ModeCharDevice:
		return perm | syscall.S_IFCHR
	case os.ModeNamedPipe:
		return perm | syscall.S_IFIFO
	case os.ModeSocket:
		return perm | syscall.S_IFSOCK
	}
	return perm
}
// Plan 9 doesn't have setuid, setgid or sticky, but a Plan 9 client should
// be able to send these bits to a POSIX server.
const (
	s_ISUID = 1 << 11 // 04000
	s_ISGID = 1 << 10 // 02000
	s_ISVTX = 1 << 9  // 01000
)
================================================
FILE: gossh/sftp/stat_posix.go
================================================
//go:build !plan9
// +build !plan9
package sftp
import (
"os"
"syscall"
)
// EBADF re-exports syscall.EBADF under the package's own name.
const EBADF = syscall.EBADF
// wrapPathError wraps err in an *os.PathError carrying filepath when err is a
// bare syscall.Errno; any other error (including nil) is returned unchanged.
// The Op field of the resulting PathError is left empty.
func wrapPathError(filepath string, err error) error {
	errno, isErrno := err.(syscall.Errno)
	if !isErrno {
		return err
	}
	return &os.PathError{Path: filepath, Err: errno}
}
// translateErrno translates a syscall error number to a SFTP error code.
// Zero maps to sshFxOk, ENOENT to sshFxNoSuchFile, EACCES and EPERM to
// sshFxPermissionDenied; everything else becomes sshFxFailure.
func translateErrno(errno syscall.Errno) uint32 {
	if errno == 0 {
		return sshFxOk
	}
	if errno == syscall.ENOENT {
		return sshFxNoSuchFile
	}
	if errno == syscall.EACCES || errno == syscall.EPERM {
		return sshFxPermissionDenied
	}
	return sshFxFailure
}
// translateSyscallError maps err to an SFTP status code when err is a
// syscall.Errno, or an *os.PathError wrapping one. The boolean result is
// false (and the code 0) when err is neither.
func translateSyscallError(err error) (uint32, bool) {
	if errno, ok := err.(syscall.Errno); ok {
		return translateErrno(errno), true
	}
	if pathErr, ok := err.(*os.PathError); ok {
		debug("statusFromError,pathError: error is %T %#v", pathErr.Err, pathErr.Err)
		if errno, ok := pathErr.Err.(syscall.Errno); ok {
			return translateErrno(errno), true
		}
	}
	return 0, false
}
// isRegular reports whether the file-type bits of mode (those selected by
// S_IFMT) describe a regular file.
func isRegular(mode uint32) bool {
	fileType := mode & S_IFMT
	return fileType == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification.
// The low nine permission bits are copied, the file-type bits selected by
// S_IFMT are translated to os.Mode* flags, and the setuid, setgid and sticky
// bits are carried over.
func toFileMode(mode uint32) os.FileMode {
	fm := os.FileMode(mode & 0777)

	switch fileType := mode & S_IFMT; fileType {
	case syscall.S_IFDIR:
		fm |= os.ModeDir
	case syscall.S_IFLNK:
		fm |= os.ModeSymlink
	case syscall.S_IFBLK:
		fm |= os.ModeDevice
	case syscall.S_IFCHR:
		fm |= os.ModeDevice | os.ModeCharDevice
	case syscall.S_IFIFO:
		fm |= os.ModeNamedPipe
	case syscall.S_IFSOCK:
		fm |= os.ModeSocket
	case syscall.S_IFREG:
		// nothing to do
	}

	// Special permission bits live outside the 0777 mask.
	special := [...]struct {
		bit  uint32
		flag os.FileMode
	}{
		{syscall.S_ISUID, os.ModeSetuid},
		{syscall.S_ISGID, os.ModeSetgid},
		{syscall.S_ISVTX, os.ModeSticky},
	}
	for _, s := range special {
		if mode&s.bit != 0 {
			fm |= s.flag
		}
	}
	return fm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode
// bits. The permission bits are copied, the os.ModeType bits are translated
// to the matching S_IF* value (no type bits means a regular file), and the
// setuid, setgid and sticky flags are carried over.
func fromFileMode(mode os.FileMode) uint32 {
	ret := uint32(mode & os.ModePerm)

	switch fileType := mode & os.ModeType; fileType {
	case 0:
		ret |= syscall.S_IFREG
	case os.ModeDir:
		ret |= syscall.S_IFDIR
	case os.ModeSymlink:
		ret |= syscall.S_IFLNK
	case os.ModeDevice:
		ret |= syscall.S_IFBLK
	case os.ModeDevice | os.ModeCharDevice:
		ret |= syscall.S_IFCHR
	case os.ModeNamedPipe:
		ret |= syscall.S_IFIFO
	case os.ModeSocket:
		ret |= syscall.S_IFSOCK
	}

	// Special permission bits are not part of os.ModePerm.
	special := [...]struct {
		flag os.FileMode
		bit  uint32
	}{
		{os.ModeSetuid, syscall.S_ISUID},
		{os.ModeSetgid, syscall.S_ISGID},
		{os.ModeSticky, syscall.S_ISVTX},
	}
	for _, s := range special {
		if mode&s.flag != 0 {
			ret |= s.bit
		}
	}
	return ret
}
// s_ISUID, s_ISGID and s_ISVTX alias the syscall setuid, setgid and sticky
// bit values under package-local names.
const (
s_ISUID = syscall.S_ISUID
s_ISGID = syscall.S_ISGID
s_ISVTX = syscall.S_ISVTX
)
================================================
FILE: gossh/sftp/syscall_fixed.go
================================================
//go:build plan9 || windows || (js && wasm)
// +build plan9 windows js,wasm
// Go defines S_IFMT on windows, plan9 and js/wasm as 0x1f000 instead of
// 0xf000. None of the other S_IFxyz values include the "1" (in 0x1f000),
// which prevents them from matching the bitmask.
package sftp
// S_IFMT is the file-type bitmask used by this package on the platforms
// selected by this file's build constraints.
const S_IFMT = 0xf000
================================================
FILE: gossh/sftp/syscall_good.go
================================================
//go:build !plan9 && !windows && (!js || !wasm)
// +build !plan9
// +build !windows
// +build !js !wasm
package sftp
import "syscall"
// S_IFMT is the file-type bitmask; on the platforms selected by this file's
// build constraints the value from package syscall is used directly.
const S_IFMT = syscall.S_IFMT
================================================
FILE: gossh/sqlite/ddlmod.go
================================================
package sqlite
import (
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"gossh/gorm/migrator"
)
// Regular expressions used to take SQLite DDL statements apart.
var (
// sqliteSeparator lists the characters treated as identifier/string
// delimiters. It is spliced into character classes ("[%v]") below, so the
// '|' characters become literal members of the class as well.
sqliteSeparator = "`|\"|'|\t"
// uniqueRegexp matches a named UNIQUE constraint; group 1 is the text
// following UNIQUE.
uniqueRegexp = regexp.MustCompile(fmt.Sprintf(`^CONSTRAINT [%v]?[\w-]+[%v]? UNIQUE (.*)$`, sqliteSeparator, sqliteSeparator))
// indexRegexp matches CREATE [UNIQUE] INDEX statements (case-insensitive).
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
// tableRegexp matches CREATE TABLE statements; group 1 is the head
// ("CREATE TABLE name") and group 2 the optional parenthesised body.
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
// separatorRegexp matches a single separator character.
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
// columnsRegexp captures each identifier that follows '(' or ','.
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
// columnRegexp splits a column definition into name (1), type (2) and the
// remaining text (3).
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
// defaultValueRegexp captures the value of a DEFAULT clause in group 1.
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
// regRealDataType captures a run of digits preceded by a non-digit; it is
// used to extract a length from a data type such as "varchar(10)".
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)
// getAllColumns returns every identifier captured by columnsRegexp in s, in
// order of appearance. The result is never nil.
func getAllColumns(s string) []string {
	found := columnsRegexp.FindAllStringSubmatch(s, -1)
	names := make([]string, 0, len(found))
	for i := range found {
		if groups := found[i]; len(groups) > 1 {
			names = append(names, groups[1])
		}
	}
	return names
}
// ddl is the parsed form of a CREATE TABLE statement.
type ddl struct {
// head is the "CREATE TABLE <name>" prefix of the statement.
head string
// fields holds the top-level, comma-separated entries of the table body
// (column definitions and table constraints), each trimmed of whitespace.
fields []string
// columns holds the column metadata derived from fields.
columns []migrator.ColumnType
}
// parseDDL parses the given SQL statements into a single ddl value.
//
// A statement matching tableRegexp has its body split into top-level fields
// (commas inside brackets or quotes do not split) and each field is then
// interpreted as a column definition or a table constraint. A statement
// matching indexRegexp is accepted but contributes nothing. Any other
// statement, or a table body with unbalanced brackets, yields an error.
func parseDDL(strs ...string) (*ddl, error) {
var result ddl
for _, str := range strs {
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
var (
ddlBody = sections[2]
ddlBodyRunes = []rune(ddlBody)
bracketLevel int
quote rune
buf string
)
ddlBodyRunesLen := len(ddlBodyRunes)
result.head = sections[1]
// Walk the body rune by rune, tracking quote and bracket state so that
// only top-level commas terminate a field.
for idx := 0; idx < ddlBodyRunesLen; idx++ {
var (
next rune = 0
c = ddlBodyRunes[idx]
)
if idx+1 < ddlBodyRunesLen {
next = ddlBodyRunes[idx+1]
}
if sc := string(c); separatorRegexp.MatchString(sc) {
if c == next {
buf += sc // Skip escaped quote
idx++
} else if quote > 0 {
// Any separator character ends the current quoted section.
quote = 0
} else {
quote = c
}
} else if quote == 0 {
if c == '(' {
bracketLevel++
} else if c == ')' {
bracketLevel--
} else if bracketLevel == 0 {
if c == ',' {
// Top-level comma: the buffered text is one complete field.
result.fields = append(result.fields, strings.TrimSpace(buf))
buf = ""
continue
}
}
}
if bracketLevel < 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
buf += string(c)
}
if bracketLevel != 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
// Flush the last field, which has no trailing comma.
if buf != "" {
result.fields = append(result.fields, strings.TrimSpace(buf))
}
// Interpret every field collected so far.
for _, f := range result.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "CHECK") {
continue
}
if strings.HasPrefix(fUpper, "CONSTRAINT") {
// A named UNIQUE constraint over exactly one column marks that
// column as unique.
matches := uniqueRegexp.FindStringSubmatch(f)
if len(matches) > 0 {
if columns := getAllColumns(matches[1]); len(columns) == 1 {
for idx, column := range result.columns {
if column.NameValue.String == columns[0] {
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
}
continue
}
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
// Table-level PRIMARY KEY: flag every listed column already parsed.
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
// Column definition: matches[1] is the name, matches[2] the type and
// matches[3] the remaining modifiers.
columnType := migrator.ColumnType{
NameValue: sql.NullString{String: matches[1], Valid: true},
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Bool: true, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
matchUpper := strings.ToUpper(matches[3])
if strings.Contains(matchUpper, " NOT NULL") {
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
} else if strings.Contains(matchUpper, " NULL") {
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " UNIQUE") {
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " PRIMARY") {
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
}
// A DEFAULT of "null" (any case) is treated as having no default.
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
if strings.ToLower(defaultMatches[1]) != "null" {
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
}
}
// data type length
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
if len(matches) == 1 && len(matches[0]) == 2 {
size, _ := strconv.Atoi(matches[0][1])
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
}
result.columns = append(result.columns, columnType)
}
}
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
// don't report Unique by UniqueIndex
} else {
return nil, errors.New("invalid DDL")
}
}
return &result, nil
}
// clone returns a copy of d whose fields and columns slices have their own
// backing arrays, so the copy can be modified without affecting d.
func (d *ddl) clone() *ddl {
	dup := *d
	dup.fields = append(make([]string, 0, len(d.fields)), d.fields...)
	dup.columns = append(make([]migrator.ColumnType, 0, len(d.columns)), d.columns...)
	return &dup
}
// compile renders the ddl back into SQL: the head followed by the
// comma-joined fields in parentheses, or just the head when there are no
// fields.
func (d *ddl) compile() string {
	if len(d.fields) > 0 {
		return d.head + " (" + strings.Join(d.fields, ",") + ")"
	}
	return d.head
}
// renameTable replaces the table name src in the DDL head with dst (written
// in backticks). It returns an error when the head does not change, i.e. when
// src could not be found in it.
func (d *ddl) renameTable(dst, src string) error {
	pattern := "\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*"
	nameRe, err := regexp.Compile(pattern)
	if err != nil {
		return err
	}

	newHead := nameRe.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
	if newHead != d.head {
		d.head = newHead
		return nil
	}
	return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
}
// addConstraint stores the constraint definition sql under the given name:
// an existing "CONSTRAINT <name>" field is overwritten in place, otherwise
// the definition is appended as a new field.
func (d *ddl) addConstraint(name string, sql string) {
	matcher := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
	for idx, field := range d.fields {
		if matcher.MatchString(field) {
			d.fields[idx] = sql
			return
		}
	}
	d.fields = append(d.fields, sql)
}
// removeConstraint deletes the first field that defines the constraint with
// the given name and reports whether such a field was found.
func (d *ddl) removeConstraint(name string) bool {
	matcher := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
	for idx, field := range d.fields {
		if !matcher.MatchString(field) {
			continue
		}
		d.fields = append(d.fields[:idx], d.fields[idx+1:]...)
		return true
	}
	return false
}
// hasConstraint reports whether any field defines the constraint with the
// given name.
func (d *ddl) hasConstraint(name string) bool {
	matcher := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
	for idx := range d.fields {
		if matcher.MatchString(d.fields[idx]) {
			return true
		}
	}
	return false
}
// getColumns returns the backtick-quoted names of the plain columns of the
// table. Fields that are table constraints (PRIMARY KEY, CHECK, CONSTRAINT)
// and generated columns ("GENERATED ALWAYS AS") are skipped. The result is
// never nil.
func (d *ddl) getColumns() []string {
	// The pattern does not depend on the field, so compile it once per call
	// instead of once per loop iteration.
	nameRe := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")

	res := make([]string, 0, len(d.fields))
	for _, f := range d.fields {
		fUpper := strings.ToUpper(f)
		if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
			strings.HasPrefix(fUpper, "CHECK") ||
			strings.HasPrefix(fUpper, "CONSTRAINT") ||
			strings.Contains(fUpper, "GENERATED ALWAYS AS") {
			continue
		}
		if match := nameRe.FindStringSubmatch(f); match != nil {
			res = append(res, "`"+match[1]+"`")
		}
	}
	return res
}
// removeColumn deletes the first field whose column name equals name and
// reports whether a field was removed. The name must be preceded by a quote
// character or a space for the field to match.
func (d *ddl) removeColumn(name string) bool {
	matcher := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
	for idx, field := range d.fields {
		if !matcher.MatchString(field) {
			continue
		}
		d.fields = append(d.fields[:idx], d.fields[idx+1:]...)
		return true
	}
	return false
}
================================================
FILE: gossh/sqlite/errors.go
================================================
package sqlite
import "errors"
var (
// ErrConstraintsNotImplemented is the sentinel error stating that
// constraints are not implemented for SQLite.
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
)
================================================
FILE: gossh/sqlite/migrator.go
================================================
package sqlite
import (
"database/sql"
"fmt"
"strings"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/migrator"
"gossh/gorm/schema"
)
// Migrator is the SQLite schema migrator. It embeds the generic
// migrator.Migrator and overrides the operations that need SQLite-specific
// SQL.
type Migrator struct {
migrator.Migrator
}
// RunWithoutForeignKey runs fc with foreign key enforcement switched off.
// If "PRAGMA foreign_keys" reports 1, enforcement is turned OFF before fc
// runs and back ON once it returns; otherwise fc simply runs as is. The
// result of fc is returned.
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
	var enabled int
	m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
	if enabled != 1 {
		return fc()
	}

	m.DB.Exec("PRAGMA foreign_keys = OFF")
	defer m.DB.Exec("PRAGMA foreign_keys = ON")
	return fc()
}
// HasTable reports whether sqlite_master contains a table with the name
// resolved for value. Query errors are not reported; they result in false.
func (m Migrator) HasTable(value interface{}) bool {
	var matches int
	lookup := func(stmt *gorm.Statement) error {
		row := m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row()
		return row.Scan(&matches)
	}
	m.Migrator.RunWithValue(value, lookup)
	return matches > 0
}
// DropTable drops the tables of the given models with foreign key
// enforcement disabled. The models are reordered via ReorderModels and then
// dropped from last to first; the first failure aborts and is returned.
func (m Migrator) DropTable(values ...interface{}) error {
	dropAll := func() error {
		values = m.ReorderModels(values, false)
		tx := m.DB.Session(&gorm.Session{})
		dropOne := func(stmt *gorm.Statement) error {
			return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
		}
		for remaining := len(values); remaining > 0; remaining-- {
			if err := m.RunWithValue(values[remaining-1], dropOne); err != nil {
				return err
			}
		}
		return nil
	}
	return m.RunWithoutForeignKey(dropAll)
}
// GetTables returns the names of all tables recorded in sqlite_master.
func (m Migrator) GetTables() (tableList []string, err error) {
	// Run the query first and return afterwards: Go does not specify whether
	// a plain variable operand of a return statement is read before or after
	// a function call in the same statement, so combining both in one return
	// could hand back the list as it was before Scan filled it.
	err = m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
	return tableList, err
}
// HasColumn reports whether the CREATE TABLE statement stored in
// sqlite_master for the table of value mentions the column. name may be a
// struct field name, in which case it is translated to its DB column name.
// The check is a LIKE match against several quoting styles. Query errors are
// not reported; they result in false.
func (m Migrator) HasColumn(value interface{}, name string) bool {
	var matches int
	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(name); field != nil {
				name = field.DBName
			}
		}
		if name == "" {
			return nil
		}

		patterns := []interface{}{
			`%"` + name + `" %`,
			`%` + name + ` %`,
			"%`" + name + "`%",
			"%[" + name + "]%",
			"%\t" + name + "\t%",
		}
		args := append([]interface{}{"table", stmt.Table}, patterns...)
		m.DB.Raw(
			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
			args...,
		).Row().Scan(&matches)
		return nil
	})
	return matches > 0
}
// AlterColumn changes the definition of the column backing the named field
// by recreating the table (with foreign key enforcement disabled). The
// field's definition in the parsed DDL is replaced with its current full
// data type. An error is returned when the field cannot be found in the
// schema.
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
if field := stmt.Schema.LookUpField(name); field != nil {
var sqlArgs []interface{}
for i, f := range ddl.fields {
if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName {
// Replace the definition with a placeholder that is bound to the
// field's full data type when the CREATE statement is executed.
ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName)
sqlArgs = []interface{}{m.FullDataTypeOf(field)}
// table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`.
// FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint.
if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") {
uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName)
if uni != nil {
uniSQL, uniArgs := uni.Build()
ddl.addConstraint(uniName, uniSQL)
sqlArgs = append(sqlArgs, uniArgs...)
}
}
break
}
}
return ddl, sqlArgs, nil
}
return nil, nil, fmt.Errorf("failed to alter field with name %v", name)
})
})
}
// ColumnTypes returns the column types of the table behind value.
//
// The table and index DDL stored in sqlite_master is parsed to obtain column
// metadata, and a one-row query supplies the driver-reported column types.
// Each driver column is merged with the parsed metadata of the same name;
// columns without parsed metadata carry only the driver information.
//
// The returned error is the first failure encountered. An error from closing
// the row set is reported only when nothing else failed.
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
		var (
			sqls   []string
			sqlDDL *ddl
		)
		if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
			return err
		}
		if sqlDDL, err = parseDDL(sqls...); err != nil {
			return err
		}

		rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
		if err != nil {
			return err
		}
		defer func() {
			// Always close the rows, but never let a successful Close hide
			// an earlier failure stored in the named result.
			if closeErr := rows.Close(); err == nil {
				err = closeErr
			}
		}()

		var rawColumnTypes []*sql.ColumnType
		rawColumnTypes, err = rows.ColumnTypes()
		if err != nil {
			return err
		}

		for _, c := range rawColumnTypes {
			columnType := migrator.ColumnType{SQLColumnType: c}
			for _, column := range sqlDDL.columns {
				if column.NameValue.String == c.Name() {
					column.SQLColumnType = c
					columnType = column
					break
				}
			}
			columnTypes = append(columnTypes, columnType)
		}
		return nil
	})
	return columnTypes, execErr
}
// DropColumn removes a column by recreating the table without it. name may
// be a struct field name, in which case it is translated to its DB column
// name first. Whether the column was actually found in the DDL is not
// checked.
func (m Migrator) DropColumn(value interface{}, name string) error {
	withoutColumn := func(def *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
		column := name
		if field := stmt.Schema.LookUpField(column); field != nil {
			column = field.DBName
		}
		def.removeColumn(column)
		return def, nil, nil
	}
	return m.recreateTable(value, nil, withoutColumn)
}
// CreateConstraint adds the named constraint by recreating the table that
// owns it with the constraint's SQL added to the DDL. When the constraint
// cannot be resolved nothing is changed and nil is returned.
func (m Migrator) CreateConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		withConstraint := func(def *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
			if constraint == nil {
				// A nil ddl tells recreateTable there is nothing to do.
				return nil, nil, nil
			}
			constraintName := constraint.GetName()
			constraintSQL, constraintArgs := constraint.Build()
			def.addConstraint(constraintName, constraintSQL)
			return def, constraintArgs, nil
		}
		return m.recreateTable(value, &table, withConstraint)
	})
}
// DropConstraint removes the named constraint by recreating the table that
// owns it without the constraint's field. When the constraint can be
// resolved from the schema its canonical name is used. Whether the
// constraint was actually present in the DDL is not checked.
func (m Migrator) DropConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			name = constraint.GetName()
		}
		withoutConstraint := func(def *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
			def.removeConstraint(name)
			return def, nil, nil
		}
		return m.recreateTable(value, &table, withoutConstraint)
	})
}
// HasConstraint reports whether the CREATE TABLE statement stored in
// sqlite_master for the owning table mentions "CONSTRAINT <name>" in one of
// several quoting styles. When the constraint can be resolved from the schema
// its canonical name is used. Query errors are not reported; they result in
// false.
func (m Migrator) HasConstraint(value interface{}, name string) bool {
	var matches int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			name = constraint.GetName()
		}

		patterns := []interface{}{
			`%CONSTRAINT "` + name + `" %`,
			`%CONSTRAINT ` + name + ` %`,
			"%CONSTRAINT `" + name + "`%",
			"%CONSTRAINT [" + name + "]%",
			"%CONSTRAINT \t" + name + "\t%",
		}
		args := append([]interface{}{"table", table}, patterns...)
		m.DB.Raw(
			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
			args...,
		).Row().Scan(&matches)
		return nil
	})
	return matches > 0
}
// CurrentDatabase returns the second column of the first row produced by
// "PRAGMA database_list"; the first and third columns are discarded. A scan
// failure is not reported and leaves the name empty.
func (m Migrator) CurrentDatabase() (name string) {
	var first, third interface{}
	m.DB.Raw("PRAGMA database_list").Row().Scan(&first, &name, &third)
	return name
}
// BuildIndexOptions turns index field options into SQL expressions, one per
// option: the quoted column name (or the option's raw expression when set),
// followed by optional COLLATE and sort direction parts.
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
	for i := range opts {
		opt := opts[i]

		column := stmt.Quote(opt.DBName)
		if opt.Expression != "" {
			column = opt.Expression
		}

		var suffix string
		if opt.Collate != "" {
			suffix += " COLLATE " + opt.Collate
		}
		if opt.Sort != "" {
			suffix += " " + opt.Sort
		}

		results = append(results, clause.Expr{SQL: column + suffix})
	}
	return results
}
// CreateIndex creates the index called name as declared on the schema of
// value. The statement is assembled from the index's class, type, fields and
// optional WHERE condition. An error is returned when the schema or the
// index cannot be found.
func (m Migrator) CreateIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		notFound := func() error {
			return fmt.Errorf("failed to create index with name %v", name)
		}
		if stmt.Schema == nil {
			return notFound()
		}
		idx := stmt.Schema.LookIndex(name)
		if idx == nil {
			return notFound()
		}

		opts := m.BuildIndexOptions(idx.Fields, stmt)
		values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}

		var query strings.Builder
		query.WriteString("CREATE ")
		if idx.Class != "" {
			query.WriteString(idx.Class + " ")
		}
		query.WriteString("INDEX ?")
		if idx.Type != "" {
			query.WriteString(" USING " + idx.Type)
		}
		query.WriteString(" ON ??")
		if idx.Where != "" {
			query.WriteString(" WHERE " + idx.Where)
		}

		return m.DB.Exec(query.String(), values...).Error
	})
}
// HasIndex reports whether sqlite_master lists an index with the given name
// on the table of value. When the schema declares an index under name, its
// resolved name is used. Query errors are not reported; they result in
// false.
func (m Migrator) HasIndex(value interface{}, name string) bool {
	var matches int
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}
		if name == "" {
			return nil
		}
		row := m.DB.Raw(
			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
		).Row()
		row.Scan(&matches)
		return nil
	})
	return matches > 0
}
// RenameIndex renames an index by dropping it and re-executing its stored
// CREATE statement with the first occurrence of oldName replaced by newName.
// An error is returned when no CREATE statement is found for oldName.
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var createSQL string
		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&createSQL)
		if createSQL == "" {
			return fmt.Errorf("failed to find index with name %v", oldName)
		}

		if err := m.DropIndex(value, oldName); err != nil {
			return err
		}
		renamed := strings.Replace(createSQL, oldName, newName, 1)
		return m.DB.Exec(renamed).Error
	})
}
// DropIndex drops the index with the given name. When the schema of value
// declares an index under name, its resolved name is used.
func (m Migrator) DropIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		target := name
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(target); idx != nil {
				target = idx.Name
			}
		}
		result := m.DB.Exec("DROP INDEX ?", clause.Column{Name: target})
		return result.Error
	})
}
// Index is the scan target for one row of the PRAGMA_index_list query issued
// by GetIndexes.
type Index struct {
Seq int
Name string
Unique bool
// Origin is compared against "u" (created by a UNIQUE constraint) and "pk"
// (primary key) in GetIndexes.
Origin string
Partial bool
}
// GetIndexes returns the indexes of the table behind value.
//
// Indexes are listed with PRAGMA_index_list and their columns with
// PRAGMA_index_info. Indexes that were created implicitly by a UNIQUE
// constraint (origin "u") are skipped. The first query error aborts the
// listing and is returned.
// See the [doc]
//
// [doc]: https://www.sqlite.org/pragma.html#pragma_index_list
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
	indexes := make([]gorm.Index, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		rst := make([]*Index, 0)
		// No Debug() here: the query should follow the logger configuration
		// of the DB handle instead of forcing debug output on every call.
		if err := m.DB.Raw("SELECT * FROM PRAGMA_index_list(?)", stmt.Table).Scan(&rst).Error; err != nil { // alias `PRAGMA index_list(?)`
			return err
		}
		for _, index := range rst {
			if index.Origin == "u" { // skip the index was created by a UNIQUE constraint
				continue
			}
			var columns []string
			if err := m.DB.Raw("SELECT name FROM PRAGMA_index_info(?)", index.Name).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)`
				return err
			}
			indexes = append(indexes, &migrator.Index{
				TableName:       stmt.Table,
				NameValue:       index.Name,
				ColumnList:      columns,
				PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY
				UniqueValue:     sql.NullBool{Bool: index.Unique, Valid: true},
			})
		}
		return nil
	})
	return indexes, err
}
// getRawDDL returns the CREATE TABLE statement stored in sqlite_master for
// the given table.
//
// A failure to read the statement (including the table not being listed) is
// returned as an error rather than being silently turned into an empty
// string.
func (m Migrator) getRawDDL(table string) (string, error) {
	var createSQL string
	row := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row()
	if err := row.Scan(&createSQL); err != nil {
		return "", fmt.Errorf("reading DDL of table %q: %w", table, err)
	}

	if m.DB.Error != nil {
		return "", m.DB.Error
	}
	return createSQL, nil
}
// recreateTable rebuilds a table with a modified definition.
//
// The current DDL of the table (stmt.Table, or *tablePtr when tablePtr is
// non-nil) is parsed and a clone is handed to getCreateSQL, which returns the
// modified ddl plus the arguments for its placeholders. A nil ddl from
// getCreateSQL means "nothing to do". Otherwise, inside one transaction, a
// "<table>__temp" table is created from the new DDL, the rows are copied
// over for the columns of the new definition, the old table is dropped and
// the temporary table is renamed to the original name.
func (m Migrator) recreateTable(
value interface{}, tablePtr *string,
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
table = *tablePtr
}
rawDDL, err := m.getRawDDL(table)
if err != nil {
return err
}
originDDL, err := parseDDL(rawDDL)
if err != nil {
return err
}
// The callback works on a clone so the parsed original stays untouched.
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
if err != nil {
return err
}
if createDDL == nil {
return nil
}
newTableName := table + "__temp"
if err := createDDL.renameTable(newTableName, table); err != nil {
return err
}
columns := createDDL.getColumns()
createSQL := createDDL.compile()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
return err
}
// Copy the data, drop the old table, then move the new one into place.
queries := []string{
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
fmt.Sprintf("DROP TABLE `%v`", table),
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
}
for _, query := range queries {
if err := tx.Exec(query).Error; err != nil {
return err
}
}
return nil
})
})
}
// init currently performs no work.
func init() {
}
================================================
FILE: gossh/sqlite/sqlite.go
================================================
package sqlite
import (
"context"
"database/sql"
"strconv"
"gossh/gorm/callbacks"
gosqlite "github.com/glebarez/go-sqlite"
sqlite3 "modernc.org/sqlite/lib"
"gossh/gorm"
"gossh/gorm/clause"
"gossh/gorm/logger"
"gossh/gorm/migrator"
"gossh/gorm/schema"
)
// DriverName is the default driver name for SQLite. Initialize falls back to
// it when Dialector.DriverName is empty.
const DriverName = "sqlite"
// Dialector is the SQLite dialector for gorm.
type Dialector struct {
// DriverName is the database/sql driver name passed to sql.Open; when empty
// the package constant DriverName is used.
DriverName string
// DSN is the data source name passed to sql.Open.
DSN string
// Conn, when non-nil, is used as the connection pool instead of opening a
// new connection from DSN.
Conn gorm.ConnPool
}
// Open returns a Dialector for the given data source name. No connection is
// made here; that happens in Initialize.
func Open(dsn string) gorm.Dialector {
	d := new(Dialector)
	d.DSN = dsn
	return d
}
// Name returns the dialect name, "sqlite".
func (dialector Dialector) Name() string {
return "sqlite"
}
// Initialize prepares db for use with SQLite: it installs the connection
// pool (the provided Conn, or a new one opened from DSN), queries the SQLite
// version, registers gorm's default callbacks and installs the dialect's
// clause builders. RETURNING clauses are enabled only for SQLite 3.35.0 and
// later.
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// dialector is a value receiver, so this default applies to this call only.
if dialector.DriverName == "" {
dialector.DriverName = DriverName
}
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open(dialector.DriverName, dialector.DSN)
if err != nil {
return err
}
db.ConnPool = conn
}
var version string
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
return err
}
// https://www.sqlite.org/releaselog/3_35_0.html
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
} else {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
}
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
// ClauseBuilders returns the SQLite-specific builders for the INSERT, LIMIT
// and FOR clauses, keyed by clause name.
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
// INSERT: writes "INSERT [modifier] INTO <table>", using the statement's
// table when the clause does not name one.
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
stmt.WriteString("INSERT ")
if insert.Modifier != "" {
stmt.WriteString(insert.Modifier)
stmt.WriteByte(' ')
}
stmt.WriteString("INTO ")
if insert.Table.Name == "" {
stmt.WriteQuoted(stmt.Table)
} else {
stmt.WriteQuoted(insert.Table)
}
return
}
}
c.Build(builder)
},
// LIMIT: when only an offset is given, "LIMIT -1" is written in front of
// the OFFSET part.
"LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
var lmt = -1
if limit.Limit != nil && *limit.Limit >= 0 {
lmt = *limit.Limit
}
if lmt >= 0 || limit.Offset > 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(lmt))
}
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
},
// FOR: locking expressions are dropped; anything else is built normally.
"FOR": func(c clause.Clause, builder clause.Builder) {
if _, ok := c.Expression.(clause.Locking); ok {
// SQLite3 does not support row-level locking.
return
}
c.Build(builder)
},
}
}
// DefaultValueOf returns the SQL expression used for a field's default
// value: NULL for auto-increment fields and DEFAULT for everything else.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	// doesn't work, will raise error
	expr := clause.Expr{SQL: "DEFAULT"}
	if field.AutoIncrement {
		expr.SQL = "NULL"
	}
	return expr
}
// Migrator returns the SQLite migrator bound to db. Indexes are created
// after the table itself (CreateIndexAfterCreateTable).
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	cfg := migrator.Config{
		DB:                          db,
		Dialector:                   dialector,
		CreateIndexAfterCreateTable: true,
	}
	return Migrator{migrator.Migrator{Config: cfg}}
}
// BindVarTo writes the bind-variable placeholder, a single '?', to writer.
// stmt and v are not used.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
// QuoteTo writes str to writer as a backtick-quoted identifier. The input is
// processed byte by byte: '.' separates parts that are quoted individually,
// and backticks already present in str are tracked so that pre-quoted parts
// are not quoted a second time and embedded backticks are doubled.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
var (
// underQuoted: an opening backtick has been written for the current part.
// selfQuoted: the current part arrived already wrapped in backticks.
underQuoted, selfQuoted bool
// continuousBacktick counts backticks seen in a row and not yet emitted.
continuousBacktick int8
// shiftDelimiter counts bytes handled since the start of the current part.
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '`':
continuousBacktick++
if continuousBacktick == 2 {
writer.WriteString("``")
continuousBacktick = 0
}
case '.':
// Close the current part (if needed) and reset the per-part state.
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteString("`")
}
writer.WriteByte(v)
continue
default:
// First ordinary byte of a part: emit the opening backtick.
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteString("`")
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString("``")
}
writer.WriteByte(v)
}
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
writer.WriteString("`")
}
// Explain returns the string produced by logger.ExplainSQL for sql and vars,
// passing nil as its second argument and a double quote as its third.
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
	const escaper = `"`
	explained := logger.ExplainSQL(sql, nil, escaper, vars...)
	return explained
}
// DataTypeOf maps field's schema data type to the column type text used for
// SQLite. Data types without an explicit mapping are returned as-is.
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	switch dt := field.DataType; dt {
	case schema.Bool:
		return "numeric"
	case schema.Int, schema.Uint:
		if !field.AutoIncrement {
			return "integer"
		}
		// doesn't check `PrimaryKey`, to keep backward compatibility
		// https://www.sqlite.org/autoinc.html
		return "integer PRIMARY KEY AUTOINCREMENT"
	case schema.Float:
		return "real"
	case schema.String:
		return "text"
	case schema.Time:
		// Distinguish between schema.Time and tag time: an explicit TYPE tag
		// setting takes precedence over the default.
		if tagged, ok := field.TagSettings["TYPE"]; ok {
			return tagged
		}
		return "datetime"
	case schema.Bytes:
		return "blob"
	default:
		return string(dt)
	}
}
// SavePoint executes "SAVEPOINT <name>" on tx and returns the error reported
// by that statement, or nil when it succeeds.
//
// name is concatenated into the SQL text unchanged, so callers must pass a
// trusted identifier.
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
	// Propagate the statement's error instead of unconditionally returning nil,
	// so a failed SAVEPOINT is not reported as success.
	return tx.Exec("SAVEPOINT " + name).Error
}
// RollbackTo executes "ROLLBACK TO SAVEPOINT <name>" on tx and returns the
// error reported by that statement, or nil when it succeeds.
//
// name is concatenated into the SQL text unchanged, so callers must pass a
// trusted identifier.
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
	// Propagate the statement's error instead of unconditionally returning nil,
	// so a failed rollback is not reported as success.
	return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}
// Translate converts SQLite constraint errors into gorm's sentinel errors.
// Unique and primary-key violations become gorm.ErrDuplicatedKey, foreign-key
// violations become gorm.ErrForeignKeyViolated, and every other error
// (including errors that are not *gosqlite.Error) is returned unchanged.
func (dialector Dialector) Translate(err error) error {
	sqliteErr, ok := err.(*gosqlite.Error)
	if !ok {
		return err
	}
	switch sqliteErr.Code() {
	case sqlite3.SQLITE_CONSTRAINT_UNIQUE, sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
		return gorm.ErrDuplicatedKey
	case sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY:
		return gorm.ErrForeignKeyViolated
	}
	return err
}
// compareVersion compares two dot-separated numeric version strings segment
// by segment. It returns 1 if version1 is greater, -1 if it is smaller and 0
// if they are equal. A missing segment counts as 0, so "1" equals "1.0.0",
// and leading zeros are insignificant.
func compareVersion(version1, version2 string) int {
	// segment reads the decimal number of s starting at pos up to the next
	// '.' (or the end) and returns it with the position just past that dot.
	segment := func(s string, pos int) (int, int) {
		val := 0
		for pos < len(s) && s[pos] != '.' {
			val = val*10 + int(s[pos]-'0')
			pos++
		}
		return val, pos + 1
	}

	p1, p2 := 0, 0
	for p1 < len(version1) || p2 < len(version2) {
		var a, b int
		a, p1 = segment(version1, p1)
		b, p2 = segment(version2, p2)
		switch {
		case a > b:
			return 1
		case a < b:
			return -1
		}
	}
	return 0
}
================================================
FILE: gossh/sys/cpu/asm_aix_ppc64.s
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for ppc64, AIX are implemented in runtime/syscall_aix.go
//
// syscall6 is a trampoline: it declares no local frame ($0), is marked
// NOSPLIT, and jumps straight to the syscall package's syscall6.
TEXT ·syscall6(SB),NOSPLIT,$0-88
JMP syscall·syscall6(SB)
// rawSyscall6 is a trampoline: it declares no local frame ($0), is marked
// NOSPLIT, and jumps straight to the syscall package's rawSyscall6.
TEXT ·rawSyscall6(SB),NOSPLIT,$0-88
JMP syscall·rawSyscall6(SB)
================================================
FILE: gossh/sys/cpu/byteorder.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"runtime"
)
// byteOrder is a subset of encoding/binary.ByteOrder. It covers only the
// decoding of fixed-width unsigned integers from a byte slice.
type byteOrder interface {
// Uint32 decodes a uint32 from the first 4 bytes of the slice.
Uint32([]byte) uint32
// Uint64 decodes a uint64 from the first 8 bytes of the slice.
Uint64([]byte) uint64
}
// littleEndian decodes integers whose least significant byte comes first.
type littleEndian struct{}

// bigEndian decodes integers whose most significant byte comes first.
type bigEndian struct{}

// Uint32 decodes the first 4 bytes of b as a little-endian uint32.
// It panics if b is shorter than 4 bytes.
func (littleEndian) Uint32(b []byte) uint32 {
	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
	var v uint32
	for i := 3; i >= 0; i-- {
		v = v<<8 | uint32(b[i])
	}
	return v
}

// Uint64 decodes the first 8 bytes of b as a little-endian uint64.
// It panics if b is shorter than 8 bytes.
func (littleEndian) Uint64(b []byte) uint64 {
	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
	var v uint64
	for i := 7; i >= 0; i-- {
		v = v<<8 | uint64(b[i])
	}
	return v
}

// Uint32 decodes the first 4 bytes of b as a big-endian uint32.
// It panics if b is shorter than 4 bytes.
func (bigEndian) Uint32(b []byte) uint32 {
	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
	var v uint32
	for i := 0; i < 4; i++ {
		v = v<<8 | uint32(b[i])
	}
	return v
}

// Uint64 decodes the first 8 bytes of b as a big-endian uint64.
// It panics if b is shorter than 8 bytes.
func (bigEndian) Uint64(b []byte) uint64 {
	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
	var v uint64
	for i := 0; i < 8; i++ {
		v = v<<8 | uint64(b[i])
	}
	return v
}
// hostByteOrder reports the byte order of the architecture named by
// runtime.GOARCH: littleEndian for the little-endian architectures listed
// below, bigEndian for the big-endian ones. It panics for any architecture
// that appears in neither list.
func hostByteOrder() byteOrder {
	littleArchs := [...]string{
		"386", "amd64", "amd64p32",
		"alpha",
		"arm", "arm64",
		"loong64",
		"mipsle", "mips64le", "mips64p32le",
		"nios2",
		"ppc64le",
		"riscv", "riscv64",
		"sh",
	}
	bigArchs := [...]string{
		"armbe", "arm64be",
		"m68k",
		"mips", "mips64", "mips64p32",
		"ppc", "ppc64",
		"s390", "s390x",
		"shbe",
		"sparc", "sparc64",
	}

	for _, arch := range littleArchs {
		if arch == runtime.GOARCH {
			return littleEndian{}
		}
	}
	for _, arch := range bigArchs {
		if arch == runtime.GOARCH {
			return bigEndian{}
		}
	}
	panic("unknown architecture")
}
================================================
FILE: gossh/sys/cpu/cpu.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cpu implements processor feature detection for
// various CPU architectures.
package cpu
import (
"os"
"strings"
)
// Initialized reports whether the CPU features were initialized.
//
// For some GOOS/GOARCH combinations initialization of the CPU features depends
// on reading an operating-system-specific file, e.g. /proc/self/auxv on linux/arm.
// Initialized will report false if reading the file fails.
var Initialized bool
// CacheLinePad is used to pad structs to avoid false sharing.
// It consists of cacheLineSize bytes of unnamed storage.
type CacheLinePad struct{ _ [cacheLineSize]byte }
// X86 contains the supported CPU features of the
// current X86/AMD64 platform. If the current platform
// is not X86/AMD64 then all feature flags are false.
//
// X86 is padded to avoid false sharing: the flags sit between a leading and
// a trailing CacheLinePad. Further the HasAVX
// and HasAVX2 are only set if the OS supports XMM and YMM
// registers in addition to the CPUID feature bit being set.
var X86 struct {
_ CacheLinePad
HasAES bool // AES hardware implementation (AES NI)
HasADX bool // Multi-precision add-carry instruction extensions
HasAVX bool // Advanced vector extension
HasAVX2 bool // Advanced vector extension 2
HasAVX512 bool // Advanced vector extension 512
HasAVX512F bool // Advanced vector extension 512 Foundation Instructions
HasAVX512CD bool // Advanced vector extension 512 Conflict Detection Instructions
HasAVX512ER bool // Advanced vector extension 512 Exponential and Reciprocal Instructions
HasAVX512PF bool // Advanced vector extension 512 Prefetch Instructions
HasAVX512VL bool // Advanced vector extension 512 Vector Length Extensions
HasAVX512BW bool // Advanced vector extension 512 Byte and Word Instructions
HasAVX512DQ bool // Advanced vector extension 512 Doubleword and Quadword Instructions
HasAVX512IFMA bool // Advanced vector extension 512 Integer Fused Multiply Add
HasAVX512VBMI bool // Advanced vector extension 512 Vector Byte Manipulation Instructions
HasAVX5124VNNIW bool // Advanced vector extension 512 Vector Neural Network Instructions Word variable precision
HasAVX5124FMAPS bool // Advanced vector extension 512 Fused Multiply Accumulation Packed Single precision
HasAVX512VPOPCNTDQ bool // Advanced vector extension 512 Double and quad word population count instructions
HasAVX512VPCLMULQDQ bool // Advanced vector extension 512 Vector carry-less multiply operations
HasAVX512VNNI bool // Advanced vector extension 512 Vector Neural Network Instructions
HasAVX512GFNI bool // Advanced vector extension 512 Galois field New Instructions
HasAVX512VAES bool // Advanced vector extension 512 Vector AES instructions
HasAVX512VBMI2 bool // Advanced vector extension 512 Vector Byte Manipulation Instructions 2
HasAVX512BITALG bool // Advanced vector extension 512 Bit Algorithms
HasAVX512BF16 bool // Advanced vector extension 512 BFloat16 Instructions
HasAMXTile bool // Advanced Matrix Extension Tile instructions
HasAMXInt8 bool // Advanced Matrix Extension Int8 instructions
HasAMXBF16 bool // Advanced Matrix Extension BFloat16 instructions
HasBMI1 bool // Bit manipulation instruction set 1
HasBMI2 bool // Bit manipulation instruction set 2
HasCX16 bool // Compare and exchange 16 Bytes
HasERMS bool // Enhanced REP for MOVSB and STOSB
HasFMA bool // Fused-multiply-add instructions
HasOSXSAVE bool // OS supports XSAVE/XRSTOR for saving/restoring XMM registers.
HasPCLMULQDQ bool // PCLMULQDQ instruction - most often used for AES-GCM
HasPOPCNT bool // Hamming weight instruction POPCNT.
HasRDRAND bool // RDRAND instruction (on-chip random number generator)
HasRDSEED bool // RDSEED instruction (on-chip random number generator)
HasSSE2 bool // Streaming SIMD extension 2 (always available on amd64)
HasSSE3 bool // Streaming SIMD extension 3
HasSSSE3 bool // Supplemental streaming SIMD extension 3
HasSSE41 bool // Streaming SIMD extension 4 and 4.1
HasSSE42 bool // Streaming SIMD extension 4 and 4.2
_ CacheLinePad
}
// ARM64 contains the supported CPU features of the
// current ARMv8(aarch64) platform. If the current platform
// is not arm64 then all feature flags are false.
//
// The flags sit between a leading and a trailing CacheLinePad, the type this
// package uses to pad structs against false sharing.
var ARM64 struct {
_ CacheLinePad
HasFP bool // Floating-point instruction set (always available)
HasASIMD bool // Advanced SIMD (always available)
HasEVTSTRM bool // Event stream support
HasAES bool // AES hardware implementation
HasPMULL bool // Polynomial multiplication instruction set
HasSHA1 bool // SHA1 hardware implementation
HasSHA2 bool // SHA2 hardware implementation
HasCRC32 bool // CRC32 hardware implementation
HasATOMICS bool // Atomic memory operation instruction set
HasFPHP bool // Half precision floating-point instruction set
HasASIMDHP bool // Advanced SIMD half precision instruction set
HasCPUID bool // CPUID identification scheme registers
HasASIMDRDM bool // Rounding double multiply add/subtract instruction set
HasJSCVT bool // Javascript conversion from floating-point to integer
HasFCMA bool // Floating-point multiplication and addition of complex numbers
HasLRCPC bool // Release Consistent processor consistent support
HasDCPOP bool // Persistent memory support
HasSHA3 bool // SHA3 hardware implementation
HasSM3 bool // SM3 hardware implementation
HasSM4 bool // SM4 hardware implementation
HasASIMDDP bool // Advanced SIMD double precision instruction set
HasSHA512 bool // SHA512 hardware implementation
HasSVE bool // Scalable Vector Extensions
HasASIMDFHM bool // Advanced SIMD multiplication FP16 to FP32
_ CacheLinePad
}
// ARM contains the supported CPU features of the current ARM (32-bit) platform.
// All feature flags are false if:
// 1. the current platform is not arm, or
// 2. the current operating system is not Linux.
//
// The flags sit between a leading and a trailing CacheLinePad, the type this
// package uses to pad structs against false sharing.
var ARM struct {
_ CacheLinePad
HasSWP bool // SWP instruction support
HasHALF bool // Half-word load and store support
HasTHUMB bool // ARM Thumb instruction set
Has26BIT bool // Address space limited to 26-bits
HasFASTMUL bool // 32-bit operand, 64-bit result multiplication support
HasFPA bool // Floating point arithmetic support
HasVFP bool // Vector floating point support
HasEDSP bool // DSP Extensions support
HasJAVA bool // Java instruction set
HasIWMMXT bool // Intel Wireless MMX technology support
HasCRUNCH bool // MaverickCrunch context switching and handling
HasTHUMBEE bool // Thumb EE instruction set
HasNEON bool // NEON instruction set
HasVFPv3 bool // Vector floating point version 3 support
HasVFPv3D16 bool // Vector floating point version 3 D8-D15
HasTLS bool // Thread local storage support
HasVFPv4 bool // Vector floating point version 4 support
HasIDIVA bool // Integer divide instruction support in ARM mode
HasIDIVT bool // Integer divide instruction support in Thumb mode
HasVFPD32 bool // Vector floating point version 3 D15-D31
HasLPAE bool // Large Physical Address Extensions
HasEVTSTRM bool // Event stream support
HasAES bool // AES hardware implementation
HasPMULL bool // Polynomial multiplication instruction set
HasSHA1 bool // SHA1 hardware implementation
HasSHA2 bool // SHA2 hardware implementation
HasCRC32 bool // CRC32 hardware implementation
_ CacheLinePad
}
// MIPS64X contains the supported CPU features of the current mips64/mips64le
// platforms. If the current platform is not mips64/mips64le or the current
// operating system is not Linux then all feature flags are false.
//
// The flag sits between a leading and a trailing CacheLinePad, the type this
// package uses to pad structs against false sharing.
var MIPS64X struct {
_ CacheLinePad
HasMSA bool // MIPS SIMD architecture
_ CacheLinePad
}
// PPC64 contains the supported CPU features of the current ppc64/ppc64le platforms.
// If the current platform is not ppc64/ppc64le then all feature flags are false.
//
// For ppc64/ppc64le, it is safe to check only for ISA level starting on ISA v3.00,
// since there are no optional categories. There are some exceptions that also
// require kernel support to work (DARN, SCV), so there are feature bits for
// those as well. The struct is padded to avoid false sharing: the flags sit
// between a leading and a trailing CacheLinePad.
var PPC64 struct {
_ CacheLinePad
HasDARN bool // Hardware random number generator (requires kernel enablement)
HasSCV bool // Syscall vectored (requires kernel enablement)
IsPOWER8 bool // ISA v2.07 (POWER8)
IsPOWER9 bool // ISA v3.00 (POWER9), implies IsPOWER8
_ CacheLinePad
}
// S390X contains the supported CPU features of the current IBM Z
// (s390x) platform. If the current platform is not IBM Z then all
// feature flags are false.
//
// S390X is padded to avoid false sharing: the flags sit between a leading
// and a trailing CacheLinePad. Further HasVX is only set
// if the OS supports vector registers in addition to the STFLE
// feature bit being set.
var S390X struct {
_ CacheLinePad
HasZARCH bool // z/Architecture mode is active [mandatory]
HasSTFLE bool // store facility list extended
HasLDISP bool // long (20-bit) displacements
HasEIMM bool // 32-bit immediates
HasDFP bool // decimal floating point
HasETF3EH bool // ETF-3 enhanced
HasMSA bool // message security assist (CPACF)
HasAES bool // KM-AES{128,192,256} functions
HasAESCBC bool // KMC-AES{128,192,256} functions
HasAESCTR bool // KMCTR-AES{128,192,256} functions
HasAESGCM bool // KMA-GCM-AES{128,192,256} functions
HasGHASH bool // KIMD-GHASH function
HasSHA1 bool // K{I,L}MD-SHA-1 functions
HasSHA256 bool // K{I,L}MD-SHA-256 functions
HasSHA512 bool // K{I,L}MD-SHA-512 functions
HasSHA3 bool // K{I,L}MD-SHA3-{224,256,384,512} and K{I,L}MD-SHAKE-{128,256} functions
HasVX bool // vector facility
HasVXE bool // vector-enhancements facility 1
_ CacheLinePad
}
// init sets up the package's feature flags in three ordered steps: archInit
// runs the architecture-specific detection, initOptions fills the options
// table, and processOptions applies any cpu.* settings found in GODEBUG.
func init() {
archInit()
initOptions()
// Must run after initOptions: it looks settings up in the options table.
processOptions()
}
// options contains the cpu debug options that can be used in GODEBUG.
// Options are arch dependent and are added by the arch specific initOptions functions.
// Features that are mandatory for the specific GOARCH should have the Required field set
// (e.g. SSE2 on amd64).
var options []option
// option describes one CPU feature that can be switched through a
// "cpu.<name>=on|off" entry in GODEBUG.
//
// Option names should be lower case. e.g. avx instead of AVX.
type option struct {
Name string // key matched against the text after "cpu." in GODEBUG
Feature *bool // points at the feature flag that the option controls
Specified bool // whether feature value was specified in GODEBUG
Enable bool // whether feature should be enabled
Required bool // whether feature is mandatory and can not be disabled
}
// processOptions applies the cpu.* entries of the GODEBUG environment
// variable to the options table and then to the feature flags themselves.
//
// GODEBUG is a comma-separated list; only entries of the form
// "cpu.<name>=on" or "cpu.<name>=off" are considered, and "cpu.all" addresses
// every option at once (required features stay enabled). Malformed or
// unknown entries are reported with print and skipped. A feature is never
// enabled when the CPU lacks it and never disabled when it is required;
// both cases are reported as well.
func processOptions() {
	remaining := os.Getenv("GODEBUG")
nextEntry:
	for remaining != "" {
		// Split off the next comma-separated entry.
		var entry string
		if comma := strings.IndexByte(remaining, ','); comma >= 0 {
			entry, remaining = remaining[:comma], remaining[comma+1:]
		} else {
			entry, remaining = remaining, ""
		}
		if len(entry) < 4 || entry[:4] != "cpu." {
			continue
		}
		eq := strings.IndexByte(entry, '=')
		if eq < 0 {
			print("GODEBUG sys/cpu: no value specified for \"", entry, "\"\n")
			continue
		}
		key, value := entry[4:eq], entry[eq+1:] // e.g. "SSE2", "on"

		var enable bool
		switch value {
		case "on":
			enable = true
		case "off":
			enable = false
		default:
			print("GODEBUG sys/cpu: value \"", value, "\" not supported for cpu option \"", key, "\"\n")
			continue nextEntry
		}

		if key == "all" {
			for idx := range options {
				opt := &options[idx]
				opt.Specified = true
				// Required features cannot be switched off by "all".
				opt.Enable = enable || opt.Required
			}
			continue nextEntry
		}

		for idx := range options {
			opt := &options[idx]
			if opt.Name == key {
				opt.Specified = true
				opt.Enable = enable
				continue nextEntry
			}
		}

		print("GODEBUG sys/cpu: unknown cpu feature \"", key, "\"\n")
	}

	// Transfer the requested settings onto the feature flags.
	for _, o := range options {
		switch {
		case !o.Specified:
			// Not mentioned in GODEBUG: leave the detected value alone.
		case o.Enable && !*o.Feature:
			print("GODEBUG sys/cpu: can not enable \"", o.Name, "\", missing CPU support\n")
		case !o.Enable && o.Required:
			print("GODEBUG sys/cpu: can not disable \"", o.Name, "\", required CPU feature\n")
		default:
			*o.Feature = o.Enable
		}
	}
}
================================================
FILE: gossh/sys/cpu/cpu_aix.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix
package cpu
const (
// getsystemcfg constants
// _SC_IMPL is the label archInit passes to getsystemcfg.
_SC_IMPL = 2
// _IMPL_POWER8 and _IMPL_POWER9 are the bits archInit tests in the value
// returned for _SC_IMPL.
_IMPL_POWER8 = 0x10000
_IMPL_POWER9 = 0x20000
)
// archInit queries getsystemcfg(_SC_IMPL) and sets the PPC64 POWER8/POWER9
// flags from the returned bits. A POWER9 result also sets IsPOWER8. It then
// marks the package as Initialized.
func archInit() {
	impl := getsystemcfg(_SC_IMPL)
	isPower9 := impl&_IMPL_POWER9 != 0
	if isPower9 || impl&_IMPL_POWER8 != 0 {
		PPC64.IsPOWER8 = true
	}
	if isPower9 {
		PPC64.IsPOWER9 = true
	}
	Initialized = true
}
// getsystemcfg returns the first result of callgetsystemcfg(label) converted
// to uint64. The second result is discarded.
func getsystemcfg(label int) (n uint64) {
	value, _ := callgetsystemcfg(label)
	return uint64(value)
}
================================================
FILE: gossh/sys/cpu/cpu_arm.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
const cacheLineSize = 32
// HWCAP/HWCAP2 bits.
// These are specific to Linux.
// Each constant is a single-bit mask: the hwcap_* names are bits of HWCAP
// and the hwcap2_* names are bits of HWCAP2.
const (
hwcap_SWP = 1 << 0
hwcap_HALF = 1 << 1
hwcap_THUMB = 1 << 2
hwcap_26BIT = 1 << 3
hwcap_FAST_MULT = 1 << 4
hwcap_FPA = 1 << 5
hwcap_VFP = 1 << 6
hwcap_EDSP = 1 << 7
hwcap_JAVA = 1 << 8
hwcap_IWMMXT = 1 << 9
hwcap_CRUNCH = 1 << 10
hwcap_THUMBEE = 1 << 11
hwcap_NEON = 1 << 12
hwcap_VFPv3 = 1 << 13
hwcap_VFPv3D16 = 1 << 14
hwcap_TLS = 1 << 15
hwcap_VFPv4 = 1 << 16
hwcap_IDIVA = 1 << 17
hwcap_IDIVT = 1 << 18
hwcap_VFPD32 = 1 << 19
hwcap_LPAE = 1 << 20
hwcap_EVTSTRM = 1 << 21
hwcap2_AES = 1 << 0
hwcap2_PMULL = 1 << 1
hwcap2_SHA1 = 1 << 2
hwcap2_SHA2 = 1 << 3
hwcap2_CRC32 = 1 << 4
)
// initOptions fills the package-level options table with one entry per ARM
// feature flag. Each entry pairs a lower-case option name with a pointer to
// the ARM field it controls; none of them is marked Required.
func initOptions() {
options = []option{
{Name: "pmull", Feature: &ARM.HasPMULL},
{Name: "sha1", Feature: &ARM.HasSHA1},
{Name: "sha2", Feature: &ARM.HasSHA2},
{Name: "swp", Feature: &ARM.HasSWP},
{Name: "thumb", Feature: &ARM.HasTHUMB},
{Name: "thumbee", Feature: &ARM.HasTHUMBEE},
{Name: "tls", Feature: &ARM.HasTLS},
{Name: "vfp", Feature: &ARM.HasVFP},
{Name: "vfpd32", Feature: &ARM.HasVFPD32},
{Name: "vfpv3", Feature: &ARM.HasVFPv3},
{Name: "vfpv3d16", Feature: &ARM.HasVFPv3D16},
{Name: "vfpv4", Feature: &ARM.HasVFPv4},
{Name: "half", Feature: &ARM.HasHALF},
{Name: "26bit", Feature: &ARM.Has26BIT},
{Name: "fastmul", Feature: &ARM.HasFASTMUL},
{Name: "fpa", Feature: &ARM.HasFPA},
{Name: "edsp", Feature: &ARM.HasEDSP},
{Name: "java", Feature: &ARM.HasJAVA},
{Name: "iwmmxt", Feature: &ARM.HasIWMMXT},
{Name: "crunch", Feature: &ARM.HasCRUNCH},
{Name: "neon", Feature: &ARM.HasNEON},
{Name: "idivt", Feature: &ARM.HasIDIVT},
{Name: "idiva", Feature: &ARM.HasIDIVA},
{Name: "lpae", Feature: &ARM.HasLPAE},
{Name: "evtstrm", Feature: &ARM.HasEVTSTRM},
{Name: "aes", Feature: &ARM.HasAES},
{Name: "crc32", Feature: &ARM.HasCRC32},
}
}
================================================
FILE: gossh/sys/cpu/cpu_arm64.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import "runtime"
// cacheLineSize is used to prevent false sharing of cache lines.
// We choose 128 because Apple Silicon, a.k.a. M1, has 128-byte cache line size.
// It doesn't cost much and is much more future-proof.
const cacheLineSize = 128
// initOptions fills the package-level options table with one entry per ARM64
// feature flag. Each entry pairs a lower-case option name with a pointer to
// the ARM64 field it controls; none of them is marked Required.
//
// NOTE(review): two names do not match their fields letter for letter:
// "evstrm" controls HasEVTSTRM and "asimrdm" controls HasASIMDRDM. They are
// the strings users write after "cpu." in GODEBUG, so confirm before renaming.
func initOptions() {
options = []option{
{Name: "fp", Feature: &ARM64.HasFP},
{Name: "asimd", Feature: &ARM64.HasASIMD},
{Name: "evstrm", Feature: &ARM64.HasEVTSTRM},
{Name: "aes", Feature: &ARM64.HasAES},
{Name: "fphp", Feature: &ARM64.HasFPHP},
{Name: "jscvt", Feature: &ARM64.HasJSCVT},
{Name: "lrcpc", Feature: &ARM64.HasLRCPC},
{Name: "pmull", Feature: &ARM64.HasPMULL},
{Name: "sha1", Feature: &ARM64.HasSHA1},
{Name: "sha2", Feature: &ARM64.HasSHA2},
{Name: "sha3", Feature: &ARM64.HasSHA3},
{Name: "sha512", Feature: &ARM64.HasSHA512},
{Name: "sm3", Feature: &ARM64.HasSM3},
{Name: "sm4", Feature: &ARM64.HasSM4},
{Name: "sve", Feature: &ARM64.HasSVE},
{Name: "crc32", Feature: &ARM64.HasCRC32},
{Name: "atomics", Feature: &ARM64.HasATOMICS},
{Name: "asimdhp", Feature: &ARM64.HasASIMDHP},
{Name: "cpuid", Feature: &ARM64.HasCPUID},
{Name: "asimrdm", Feature: &ARM64.HasASIMDRDM},
{Name: "fcma", Feature: &ARM64.HasFCMA},
{Name: "dcpop", Feature: &ARM64.HasDCPOP},
{Name: "asimddp", Feature: &ARM64.HasASIMDDP},
{Name: "asimdfhm", Feature: &ARM64.HasASIMDFHM},
}
}
// archInit selects the ARM64 feature-detection strategy by operating system:
// FreeBSD reads the ID registers directly, Linux/NetBSD/OpenBSD go through
// doinit, and every other OS only gets the minimal feature set.
func archInit() {
	goos := runtime.GOOS
	if goos == "freebsd" {
		readARM64Registers()
		return
	}
	if goos == "linux" || goos == "netbsd" || goos == "openbsd" {
		doinit()
		return
	}
	// Many platforms don't seem to allow reading these registers.
	setMinimalFeatures()
}
// setMinimalFeatures marks floating point and Advanced SIMD as present,
// the minimal ARM64 feature set expected by TestARM64minimalFeatures.
func setMinimalFeatures() {
	ARM64.HasFP, ARM64.HasASIMD = true, true
}
// readARM64Registers marks the package as Initialized, reads the three ID
// registers through getisar0, getisar1 and getpfr0 (in that order), and hands
// their values to parseARM64SystemRegisters.
func readARM64Registers() {
	Initialized = true
	isar0 := getisar0()
	isar1 := getisar1()
	pfr0 := getpfr0()
	parseARM64SystemRegisters(isar0, isar1, pfr0)
}
// parseARM64SystemRegisters decodes 4-bit feature fields from the values of
// ID_AA64ISAR0_EL1 (isar0), ID_AA64ISAR1_EL1 (isar1) and ID_AA64PFR0_EL1
// (pfr0) and sets the matching ARM64 flags.
//
// Flags are only ever set to true here. A field value that is not listed for
// a feature leaves the corresponding flags untouched.
func parseARM64SystemRegisters(isar0, isar1, pfr0 uint64) {
	// nibble returns the 4-bit field of reg whose lowest bit is lo.
	nibble := func(reg uint64, lo uint) uint {
		return extractBits(reg, lo, lo+3)
	}

	// ID_AA64ISAR0_EL1
	if aes := nibble(isar0, 4); aes == 1 || aes == 2 {
		ARM64.HasAES = true
		if aes == 2 {
			ARM64.HasPMULL = true
		}
	}
	if nibble(isar0, 8) == 1 {
		ARM64.HasSHA1 = true
	}
	if sha2 := nibble(isar0, 12); sha2 == 1 || sha2 == 2 {
		ARM64.HasSHA2 = true
		if sha2 == 2 {
			ARM64.HasSHA512 = true
		}
	}
	if nibble(isar0, 16) == 1 {
		ARM64.HasCRC32 = true
	}
	if nibble(isar0, 20) == 2 {
		ARM64.HasATOMICS = true
	}
	if nibble(isar0, 28) == 1 {
		ARM64.HasASIMDRDM = true
	}
	if nibble(isar0, 32) == 1 {
		ARM64.HasSHA3 = true
	}
	if nibble(isar0, 36) == 1 {
		ARM64.HasSM3 = true
	}
	if nibble(isar0, 40) == 1 {
		ARM64.HasSM4 = true
	}
	if nibble(isar0, 44) == 1 {
		ARM64.HasASIMDDP = true
	}

	// ID_AA64ISAR1_EL1
	if nibble(isar1, 0) == 1 {
		ARM64.HasDCPOP = true
	}
	if nibble(isar1, 12) == 1 {
		ARM64.HasJSCVT = true
	}
	if nibble(isar1, 16) == 1 {
		ARM64.HasFCMA = true
	}
	if nibble(isar1, 20) == 1 {
		ARM64.HasLRCPC = true
	}

	// ID_AA64PFR0_EL1
	if fp := nibble(pfr0, 16); fp == 0 || fp == 1 {
		ARM64.HasFP = true
		if fp == 1 {
			ARM64.HasFPHP = true
		}
	}
	if asimd := nibble(pfr0, 20); asimd == 0 || asimd == 1 {
		ARM64.HasASIMD = true
		if asimd == 1 {
			ARM64.HasASIMDHP = true
		}
	}
	if nibble(pfr0, 32) == 1 {
		ARM64.HasSVE = true
	}
}
// extractBits returns bits start through end (inclusive, bit 0 being the
// least significant) of data, shifted down so that bit start becomes bit 0.
func extractBits(data uint64, start, end uint) uint {
	width := end - start + 1
	mask := uint(1)<<width - 1
	shifted := uint(data >> start)
	return shifted & mask
}
================================================
FILE: gossh/sys/cpu/cpu_arm64.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
// func getisar0() uint64
//
// getisar0 returns the value of ID_AA64ISAR0_EL1. The MRS instruction is
// emitted as a raw WORD, and its result is then moved from R0 into the
// return slot.
TEXT ·getisar0(SB),NOSPLIT,$0-8
// get Instruction Set Attributes 0 into x0
// mrs x0, ID_AA64ISAR0_EL1 = d5380600
WORD $0xd5380600
MOVD R0, ret+0(FP)
RET
// func getisar1() uint64
//
// getisar1 returns the value of ID_AA64ISAR1_EL1. The MRS instruction is
// emitted as a raw WORD, and its result is then moved from R0 into the
// return slot.
TEXT ·getisar1(SB),NOSPLIT,$0-8
// get Instruction Set Attributes 1 into x0
// mrs x0, ID_AA64ISAR1_EL1 = d5380620
WORD $0xd5380620
MOVD R0, ret+0(FP)
RET
// func getpfr0() uint64
//
// getpfr0 returns the value of ID_AA64PFR0_EL1. The MRS instruction is
// emitted as a raw WORD, and its result is then moved from R0 into the
// return slot.
TEXT ·getpfr0(SB),NOSPLIT,$0-8
// get Processor Feature Register 0 into x0
// mrs x0, ID_AA64PFR0_EL1 = d5380400
WORD $0xd5380400
MOVD R0, ret+0(FP)
RET
================================================
FILE: gossh/sys/cpu/cpu_gc_arm64.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
package cpu
// getisar0, getisar1 and getpfr0 are declared here without bodies.
// NOTE(review): presumably implemented in cpu_arm64.s, which defines TEXT
// symbols with the same names; confirm the build tags of both files agree.
func getisar0() uint64
func getisar1() uint64
func getpfr0() uint64
================================================
FILE: gossh/sys/cpu/cpu_gc_s390x.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
package cpu
// haveAsmFunctions reports whether the other functions declared in this file
// are safe to call. For the gc toolchain the answer is always true.
func haveAsmFunctions() bool {
	const available = true
	return available
}
// The following feature detection functions are defined in cpu_s390x.s.
// They are likely to be expensive to call so the results should be cached.
// They are declared here without bodies; stfle returns a facilityList and
// each *Query function returns a queryResult.
func stfle() facilityList
func kmQuery() queryResult
func kmcQuery() queryResult
func kmctrQuery() queryResult
func kmaQuery() queryResult
func kimdQuery() queryResult
func klmdQuery() queryResult
================================================
FILE: gossh/sys/cpu/cpu_gc_x86.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (386 || amd64 || amd64p32) && gc
package cpu
// cpuid is implemented in cpu_x86.s for gc compiler
// and in cpu_gccgo.c for gccgo.
// It takes the EAX and ECX input values and returns the four result registers.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
// xgetbv with ecx = 0 is implemented in cpu_x86.s for gc compiler
// and in cpu_gccgo.c for gccgo.
// It returns the result as its EAX and EDX halves.
func xgetbv() (eax, edx uint32)
================================================
FILE: gossh/sys/cpu/cpu_gccgo_arm64.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo
package cpu
// getisar0 is the gccgo stub for reading ID_AA64ISAR0_EL1; it always
// returns 0.
func getisar0() uint64 {
	var none uint64
	return none
}

// getisar1 is the gccgo stub for reading ID_AA64ISAR1_EL1; it always
// returns 0.
func getisar1() uint64 {
	var none uint64
	return none
}

// getpfr0 is the gccgo stub for reading ID_AA64PFR0_EL1; it always
// returns 0.
func getpfr0() uint64 {
	var none uint64
	return none
}
================================================
FILE: gossh/sys/cpu/cpu_gccgo_s390x.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo
package cpu
// haveAsmFunctions reports whether the other functions declared in this file
// are safe to call. For gccgo they are panicking stubs, so the answer is
// always false.
func haveAsmFunctions() bool {
	const available = false
	return available
}
// TODO(mundaym): the following feature detection functions are currently
// stubs. See https://golang.org/cl/162887 for how to fix this.
// They are likely to be expensive to call so the results should be cached.
//
// Every stub below panics when called; haveAsmFunctions reports false so
// callers can avoid them.

func stfle() facilityList {
	panic("not implemented for gccgo")
}

func kmQuery() queryResult {
	panic("not implemented for gccgo")
}

func kmcQuery() queryResult {
	panic("not implemented for gccgo")
}

func kmctrQuery() queryResult {
	panic("not implemented for gccgo")
}

func kmaQuery() queryResult {
	panic("not implemented for gccgo")
}

func kimdQuery() queryResult {
	panic("not implemented for gccgo")
}

func klmdQuery() queryResult {
	panic("not implemented for gccgo")
}
================================================
FILE: gossh/sys/cpu/cpu_gccgo_x86.c
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (386 || amd64 || amd64p32) && gccgo
#include <cpuid.h>
#include <stdint.h>
#include <x86gprintrin.h>
// gccgoGetCpuidCount gives __get_cpuid_count an externally visible symbol;
// the wrapper is needed because __get_cpuid_count is declared as static.
// All arguments are forwarded unchanged and its result is returned as-is.
int gccgoGetCpuidCount(uint32_t leaf, uint32_t subleaf,
                       uint32_t *eax, uint32_t *ebx,
                       uint32_t *ecx, uint32_t *edx)
{
	int rc = __get_cpuid_count(leaf, subleaf, eax, ebx, ecx, edx);
	return rc;
}
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#pragma GCC push_options
#pragma GCC target("xsave")
#pragma clang attribute push (__attribute__((target("xsave"))), apply_to=function)
// xgetbv reads the contents of an XCR (Extended Control Register)
// specified in the ECX register into registers EDX:EAX.
// Currently, the only supported value for XCR is 0.
//
// gccgoXgetbv reads XCR 0 through _xgetbv and splits the 64-bit result:
// the low 32 bits are stored in *eax and the high 32 bits in *edx.
void
gccgoXgetbv(uint32_t *eax, uint32_t *edx)
{
uint64_t v = _xgetbv(0);
// low half
*eax = v & 0xffffffff;
// high half
*edx = v >> 32;
}
#pragma clang attribute pop
#pragma GCC pop_options
================================================
FILE: gossh/sys/cpu/cpu_gccgo_x86.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (386 || amd64 || amd64p32) && gccgo
package cpu
//extern gccgoGetCpuidCount
func gccgoGetCpuidCount(eaxArg, ecxArg uint32, eax, ebx, ecx, edx *uint32)

// cpuid calls the external gccgoGetCpuidCount with eaxArg and ecxArg and
// returns the four values it writes through its pointer arguments.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) {
	gccgoGetCpuidCount(eaxArg, ecxArg, &eax, &ebx, &ecx, &edx)
	return eax, ebx, ecx, edx
}
//extern gccgoXgetbv
func gccgoXgetbv(eax, edx *uint32)

// xgetbv calls the external gccgoXgetbv and returns the two values it writes
// through its pointer arguments.
func xgetbv() (eax, edx uint32) {
	gccgoXgetbv(&eax, &edx)
	return eax, edx
}
// darwinSupportsAVX512 always reports false under gccgo, because
// gccgo doesn't build on Darwin, per:
// https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/gcc.rb#L76
func darwinSupportsAVX512() bool {
	const supported = false
	return supported
}
================================================
FILE: gossh/sys/cpu/cpu_linux.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !386 && !amd64 && !amd64p32 && !arm64
package cpu
// archInit reads the hardware capability words via readHWCAP and, if that
// succeeds, lets doinit derive the feature flags and marks the package as
// Initialized. If readHWCAP fails, nothing is changed.
func archInit() {
	if readHWCAP() != nil {
		return
	}
	doinit()
	Initialized = true
}
================================================
FILE: gossh/sys/cpu/cpu_linux_arm.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
func doinit() {
ARM.HasSWP = isSet(hwCap, hwcap_SWP)
ARM.HasHALF = isSet(hwCap, hwcap_HALF)
ARM.HasTHUMB = isSet(hwCap, hwcap_THUMB)
ARM.Has26BIT = isSet(hwCap, hwcap_26BIT)
ARM.HasFASTMUL = isSet(hwCap, hwcap_FAST_MULT)
ARM.HasFPA = isSet(hwCap, hwcap_FPA)
ARM.HasVFP = isSet(hwCap, hwcap_VFP)
ARM.HasEDSP = isSet(hwCap, hwcap_EDSP)
ARM.HasJAVA = isSet(hwCap, hwcap_JAVA)
ARM.HasIWMMXT = isSet(hwCap, hwcap_IWMMXT)
ARM.HasCRUNCH = isSet(hwCap, hwcap_CRUNCH)
ARM.HasTHUMBEE = isSet(hwCap, hwcap_THUMBEE)
ARM.HasNEON = isSet(hwCap, hwcap_NEON)
ARM.HasVFPv3 = isSet(hwCap, hwcap_VFPv3)
ARM.HasVFPv3D16 = isSet(hwCap, hwcap_VFPv3D16)
ARM.HasTLS = isSet(hwCap, hwcap_TLS)
ARM.HasVFPv4 = isSet(hwCap, hwcap_VFPv4)
ARM.HasIDIVA = isSet(hwCap, hwcap_IDIVA)
ARM.HasIDIVT = isSet(hwCap, hwcap_IDIVT)
ARM.HasVFPD32 = isSet(hwCap, hwcap_VFPD32)
ARM.HasLPAE = isSet(hwCap, hwcap_LPAE)
ARM.HasEVTSTRM = isSet(hwCap, hwcap_EVTSTRM)
ARM.HasAES = isSet(hwCap2, hwcap2_AES)
ARM.HasPMULL = isSet(hwCap2, hwcap2_PMULL)
ARM.HasSHA1 = isSet(hwCap2, hwcap2_SHA1)
ARM.HasSHA2 = isSet(hwCap2, hwcap2_SHA2)
ARM.HasCRC32 = isSet(hwCap2, hwcap2_CRC32)
}
// isSet reports whether hwc and value have at least one set bit in common.
func isSet(hwc uint, value uint) bool {
	common := hwc & value
	return common != 0
}
================================================
FILE: gossh/sys/cpu/cpu_linux_arm64.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"strings"
"syscall"
)
// HWCAP/HWCAP2 bits. These are exposed by Linux.
// Each constant is a single-bit mask; doinit tests every one of them
// against hwCap.
const (
hwcap_FP = 1 << 0
hwcap_ASIMD = 1 << 1
hwcap_EVTSTRM = 1 << 2
hwcap_AES = 1 << 3
hwcap_PMULL = 1 << 4
hwcap_SHA1 = 1 << 5
hwcap_SHA2 = 1 << 6
hwcap_CRC32 = 1 << 7
hwcap_ATOMICS = 1 << 8
hwcap_FPHP = 1 << 9
hwcap_ASIMDHP = 1 << 10
hwcap_CPUID = 1 << 11
hwcap_ASIMDRDM = 1 << 12
hwcap_JSCVT = 1 << 13
hwcap_FCMA = 1 << 14
hwcap_LRCPC = 1 << 15
hwcap_DCPOP = 1 << 16
hwcap_SHA3 = 1 << 17
hwcap_SM3 = 1 << 18
hwcap_SM4 = 1 << 19
hwcap_ASIMDDP = 1 << 20
hwcap_SHA512 = 1 << 21
hwcap_SVE = 1 << 22
hwcap_ASIMDFHM = 1 << 23
)
// linuxKernelCanEmulateCPUID reports whether the running kernel is
// Linux 4.11 or newer, judged by the release string from uname.
//
// Ideally the question would be whether the kernel contains
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=77c97b4ee21290f5f083173d957843b615abbff2
// but the version number will have to do. A release string that cannot be
// parsed yields false.
func linuxKernelCanEmulateCPUID() bool {
	var un syscall.Utsname
	syscall.Uname(&un)

	// Copy the NUL-terminated release field into a string.
	var release strings.Builder
	for i := 0; i < len(un.Release) && un.Release[i] != 0; i++ {
		release.WriteByte(byte(un.Release[i]))
	}

	major, minor, _, ok := parseRelease(release.String())
	if !ok {
		return false
	}
	if major != 4 {
		return major > 4
	}
	return minor >= 11
}
// doinit populates the ARM64 feature flags. When the HWCAP words can be read
// they are the source of truth; otherwise a kernel-version-dependent fallback
// is used.
func doinit() {
	if err := readHWCAP(); err != nil {
		// We failed to read /proc/self/auxv. This can happen if the binary has
		// been given extra capabilities(7) with /bin/setcap.
		//
		// When this happens, we have two options. If the Linux kernel is new
		// enough (4.11+), we can read the arm64 registers directly which'll
		// trap into the kernel and then return back to userspace.
		//
		// But on older kernels, such as Linux 4.4.180 as used on many Synology
		// devices, calling readARM64Registers (specifically getisar0) will
		// cause a SIGILL and we'll die. So for older kernels, parse /proc/cpuinfo
		// instead.
		//
		// See golang/go#57336.
		if !linuxKernelCanEmulateCPUID() {
			readLinuxProcCPUInfo()
			return
		}
		readARM64Registers()
		return
	}

	// HWCAP feature bits: each flag mirrors exactly one bit of hwCap.
	features := []struct {
		flag *bool
		mask uint
	}{
		{&ARM64.HasFP, hwcap_FP},
		{&ARM64.HasASIMD, hwcap_ASIMD},
		{&ARM64.HasEVTSTRM, hwcap_EVTSTRM},
		{&ARM64.HasAES, hwcap_AES},
		{&ARM64.HasPMULL, hwcap_PMULL},
		{&ARM64.HasSHA1, hwcap_SHA1},
		{&ARM64.HasSHA2, hwcap_SHA2},
		{&ARM64.HasCRC32, hwcap_CRC32},
		{&ARM64.HasATOMICS, hwcap_ATOMICS},
		{&ARM64.HasFPHP, hwcap_FPHP},
		{&ARM64.HasASIMDHP, hwcap_ASIMDHP},
		{&ARM64.HasCPUID, hwcap_CPUID},
		{&ARM64.HasASIMDRDM, hwcap_ASIMDRDM},
		{&ARM64.HasJSCVT, hwcap_JSCVT},
		{&ARM64.HasFCMA, hwcap_FCMA},
		{&ARM64.HasLRCPC, hwcap_LRCPC},
		{&ARM64.HasDCPOP, hwcap_DCPOP},
		{&ARM64.HasSHA3, hwcap_SHA3},
		{&ARM64.HasSM3, hwcap_SM3},
		{&ARM64.HasSM4, hwcap_SM4},
		{&ARM64.HasASIMDDP, hwcap_ASIMDDP},
		{&ARM64.HasSHA512, hwcap_SHA512},
		{&ARM64.HasSVE, hwcap_SVE},
		{&ARM64.HasASIMDFHM, hwcap_ASIMDFHM},
	}
	for _, f := range features {
		*f.flag = isSet(hwCap, f.mask)
	}
}
// isSet reports whether any bit of value is also set in hwc.
func isSet(hwc uint, value uint) bool {
	if hwc&value == 0 {
		return false
	}
	return true
}
================================================
FILE: gossh/sys/cpu/cpu_linux_mips64x.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (mips64 || mips64le)
package cpu
// HWCAP bits. These are exposed by the Linux kernel 5.4.
const (
	// CPU features
	hwcap_MIPS_MSA = 1 << 1 // mask tested by doinit to set MIPS64X.HasMSA
)
func doinit() {
// HWCAP feature bits
MIPS64X.HasMSA = isSet(hwCap, hwcap_MIPS_MSA)
}
// isSet reports whether the bitwise AND of hwc and value is non-zero.
func isSet(hwc uint, value uint) bool {
	return (hwc & value) > 0
}
================================================
FILE: gossh/sys/cpu/cpu_linux_noinit.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && !arm && !arm64 && !mips64 && !mips64le && !ppc64 && !ppc64le && !s390x
package cpu
// doinit is intentionally empty: no HWCAP-based feature detection is
// performed in the build configuration selected by this file's constraint.
func doinit() {}
================================================
FILE: gossh/sys/cpu/cpu_linux_ppc64x.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (ppc64 || ppc64le)
package cpu
// HWCAP/HWCAP2 bits. These are exposed by the kernel.
//
// All masks below are tested against hwCap2 by doinit.
const (
	// ISA Level
	_PPC_FEATURE2_ARCH_2_07 = 0x80000000 // sets PPC64.IsPOWER8
	_PPC_FEATURE2_ARCH_3_00 = 0x00800000 // sets PPC64.IsPOWER9
	// CPU features
	_PPC_FEATURE2_DARN = 0x00200000 // sets PPC64.HasDARN
	_PPC_FEATURE2_SCV  = 0x00100000 // sets PPC64.HasSCV
)
// doinit derives the PPC64 feature flags from the HWCAP2 word.
func doinit() {
	// Every bit consulted here lives in HWCAP2.
	has := func(mask uint) bool {
		return isSet(hwCap2, mask)
	}

	PPC64.IsPOWER8 = has(_PPC_FEATURE2_ARCH_2_07)
	PPC64.IsPOWER9 = has(_PPC_FEATURE2_ARCH_3_00)
	PPC64.HasDARN = has(_PPC_FEATURE2_DARN)
	PPC64.HasSCV = has(_PPC_FEATURE2_SCV)
}
// isSet reports whether hwc contains any of the bits present in value.
func isSet(hwc uint, value uint) bool {
	overlap := value & hwc
	return overlap != 0
}
================================================
FILE: gossh/sys/cpu/cpu_linux_s390x.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
// HWCAP masks tested against hwCap by initS390Xbase.
const (
	// bit mask values from /usr/include/bits/hwcap.h
	hwcap_ZARCH  = 2
	hwcap_STFLE  = 4
	hwcap_MSA    = 8
	hwcap_LDISP  = 16
	hwcap_EIMM   = 32
	hwcap_DFP    = 64
	hwcap_ETF3EH = 256
	hwcap_VX     = 2048
	hwcap_VXE    = 8192
)
func initS390Xbase() {
// test HWCAP bit vector
has := func(featureMask uint) bool {
return hwCap&featureMask == featureMask
}
// mandatory
S390X.HasZARCH = has(hwcap_ZARCH)
// optional
S390X.HasSTFLE = has(hwcap_STFLE)
S390X.HasLDISP = has(hwcap_LDISP)
S390X.HasEIMM = has(hwcap_EIMM)
S390X.HasETF3EH = has(hwcap_ETF3EH)
S390X.HasDFP = has(hwcap_DFP)
S390X.HasMSA = has(hwcap_MSA)
S390X.HasVX = has(hwcap_VX)
if S390X.HasVX {
S390X.HasVXE = has(hwcap_VXE)
}
}
================================================
FILE: gossh/sys/cpu/cpu_loong64.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build loong64
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for loong64.
const cacheLineSize = 64

// initOptions registers no options on loong64.
func initOptions() {
}
================================================
FILE: gossh/sys/cpu/cpu_mips64x.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build mips64 || mips64le
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for mips64/mips64le.
const cacheLineSize = 32

// initOptions installs the table that maps option names to the MIPS64X
// feature flags they refer to.
func initOptions() {
	options = []option{
		{Name: "msa", Feature: &MIPS64X.HasMSA},
	}
}
================================================
FILE: gossh/sys/cpu/cpu_mipsx.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build mips || mipsle
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for mips/mipsle.
const cacheLineSize = 32

// initOptions registers no options on mips/mipsle.
func initOptions() {}
================================================
FILE: gossh/sys/cpu/cpu_netbsd_arm64.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"syscall"
"unsafe"
)
// Minimal copy of functionality from x/sys/unix so the cpu package can call
// sysctl without depending on x/sys/unix.
const (
	_CTL_QUERY     = -2        // appended to a MIB by sysctlNodes to list child nodes
	_SYSCTL_VERS_1 = 0x1000000 // stored in sysctlNode.Flags of the query node
)

// _zero provides a valid address to pass to the kernel when the MIB slice
// handed to sysctl is empty.
var _zero uintptr
// sysctl issues the raw __sysctl system call for the given MIB. old/oldlen
// receive the current value and its length; new/newlen supply a new value.
// A non-zero errno from the kernel is returned as the error.
func sysctl(mib []int32, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) {
	// Always hand the kernel a valid pointer, even for an empty MIB.
	mibPtr := unsafe.Pointer(&_zero)
	if len(mib) > 0 {
		mibPtr = unsafe.Pointer(&mib[0])
	}

	_, _, e := syscall.Syscall6(
		syscall.SYS___SYSCTL,
		uintptr(mibPtr),
		uintptr(len(mib)),
		uintptr(unsafe.Pointer(old)),
		uintptr(unsafe.Pointer(oldlen)),
		uintptr(unsafe.Pointer(new)),
		uintptr(newlen))
	if e == 0 {
		return nil
	}
	return e
}
// sysctlNode is the record exchanged with the kernel when enumerating sysctl
// nodes: sysctlNodes passes one as the query and receives a slice of them
// back. Only Flags, Num and Name are read or written by this file; the
// remaining fields exist to keep the in-memory size and layout.
// NOTE(review): presumably mirrors NetBSD's struct sysctlnode — confirm
// against the system headers before changing any field.
type sysctlNode struct {
	Flags          uint32
	Num            int32    // numeric identifier appended to the MIB by nametomib
	Name           [32]int8 // node name; zero bytes are skipped when compared
	Ver            uint32
	__rsvd         uint32
	Un             [16]byte
	_sysctl_size   [8]byte
	_sysctl_func   [8]byte
	_sysctl_parent [8]byte
	_sysctl_desc   [8]byte
}
// sysctlNodes returns the sysctl nodes directly below mib.
//
// It performs two sysctl calls with CTL_QUERY appended to mib: the first
// obtains the size of the result, the second fetches the nodes. When the
// kernel reports that there are no nodes, a nil slice and a nil error are
// returned.
func sysctlNodes(mib []int32) ([]sysctlNode, error) {
	var olen uintptr

	// Get a list of all sysctl nodes below the given MIB by performing
	// a sysctl for the given MIB with CTL_QUERY appended.
	mib = append(mib, _CTL_QUERY)
	qnode := sysctlNode{Flags: _SYSCTL_VERS_1}
	qp := (*byte)(unsafe.Pointer(&qnode))
	sz := unsafe.Sizeof(qnode)
	if err := sysctl(mib, nil, &olen, qp, sz); err != nil {
		return nil, err
	}

	// Taking the address of element 0 of an empty slice would panic, so
	// report "no nodes" explicitly instead.
	count := olen / sz
	if count == 0 {
		return nil, nil
	}

	// Now that we know the size, get the actual nodes.
	nodes := make([]sysctlNode, count)
	np := (*byte)(unsafe.Pointer(&nodes[0]))
	if err := sysctl(mib, np, &olen, qp, sz); err != nil {
		return nil, err
	}

	return nodes, nil
}
// nametomib translates a dotted sysctl name such as "machdep.cpu0.cpu_id"
// into its numeric MIB by walking the sysctl tree one component at a time.
//
// It returns the error from sysctlNodes if enumeration fails, and
// syscall.EINVAL if a component of name does not exist below the nodes
// resolved so far. On error the returned slice is nil.
func nametomib(name string) ([]int32, error) {
	// Split name into components.
	var parts []string
	last := 0
	for i := 0; i < len(name); i++ {
		if name[i] == '.' {
			parts = append(parts, name[last:i])
			last = i + 1
		}
	}
	parts = append(parts, name[last:])

	mib := []int32{}
	// Discover the nodes and construct the MIB OID.
	for partno, part := range parts {
		nodes, err := sysctlNodes(mib)
		if err != nil {
			return nil, err
		}
		for _, node := range nodes {
			// Rebuild the node name from its non-zero bytes.
			n := make([]byte, 0, len(node.Name))
			for i := range node.Name {
				if node.Name[i] != 0 {
					n = append(n, byte(node.Name[i]))
				}
			}
			if string(n) == part {
				mib = append(mib, int32(node.Num))
				break
			}
		}
		if len(mib) != partno+1 {
			// No node matched this component. err is nil at this point, so
			// an explicit error is required; otherwise callers would receive
			// a nil MIB together with a nil error.
			return nil, syscall.EINVAL
		}
	}

	return mib, nil
}
// aarch64SysctlCPUID is struct aarch64_sysctl_cpu_id from NetBSD's
// aarch64/armreg.h header. sysctlCPUID has the kernel fill it in directly,
// so the field order and sizes must not be changed.
type aarch64SysctlCPUID struct {
	midr      uint64 /* Main ID Register */
	revidr    uint64 /* Revision ID Register */
	mpidr     uint64 /* Multiprocessor Affinity Register */
	aa64dfr0  uint64 /* A64 Debug Feature Register 0 */
	aa64dfr1  uint64 /* A64 Debug Feature Register 1 */
	aa64isar0 uint64 /* A64 Instruction Set Attribute Register 0 */
	aa64isar1 uint64 /* A64 Instruction Set Attribute Register 1 */
	aa64mmfr0 uint64 /* A64 Memory Model Feature Register 0 */
	aa64mmfr1 uint64 /* A64 Memory Model Feature Register 1 */
	aa64mmfr2 uint64 /* A64 Memory Model Feature Register 2 */
	aa64pfr0  uint64 /* A64 Processor Feature Register 0 */
	aa64pfr1  uint64 /* A64 Processor Feature Register 1 */
	aa64zfr0  uint64 /* A64 SVE Feature ID Register 0 */
	mvfr0     uint32 /* Media and VFP Feature Register 0 */
	mvfr1     uint32 /* Media and VFP Feature Register 1 */
	mvfr2     uint32 /* Media and VFP Feature Register 2 */
	pad       uint32
	clidr     uint64 /* Cache Level ID Register */
	ctr       uint64 /* Cache Type Register */
}
// sysctlCPUID resolves the dotted sysctl name and reads the node's value
// into an aarch64SysctlCPUID.
//
// It returns an error if the name cannot be resolved, if it resolves to an
// empty MIB (syscall.EINVAL), or if the sysctl call itself fails.
func sysctlCPUID(name string) (*aarch64SysctlCPUID, error) {
	mib, err := nametomib(name)
	if err != nil {
		return nil, err
	}
	// &mib[0] below would panic on an empty MIB, so reject it up front.
	if len(mib) == 0 {
		return nil, syscall.EINVAL
	}

	out := aarch64SysctlCPUID{}
	n := unsafe.Sizeof(out)
	_, _, errno := syscall.Syscall6(
		syscall.SYS___SYSCTL,
		uintptr(unsafe.Pointer(&mib[0])),
		uintptr(len(mib)),
		uintptr(unsafe.Pointer(&out)),
		uintptr(unsafe.Pointer(&n)),
		uintptr(0),
		uintptr(0))
	if errno != 0 {
		return nil, errno
	}
	return &out, nil
}
// doinit detects ARM64 features from the ID registers published under the
// machdep.cpu0.cpu_id sysctl. If that lookup fails, only the minimal feature
// set is enabled and Initialized is left unchanged.
func doinit() {
	cpuid, err := sysctlCPUID("machdep.cpu0.cpu_id")
	if err == nil {
		parseARM64SystemRegisters(cpuid.aa64isar0, cpuid.aa64isar1, cpuid.aa64pfr0)
		Initialized = true
		return
	}
	setMinimalFeatures()
}
================================================
FILE: gossh/sys/cpu/cpu_openbsd_arm64.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"syscall"
"unsafe"
)
// Minimal copy of functionality from x/sys/unix so the cpu package can call
// sysctl without depending on x/sys/unix.
//
// doinit combines these into two-element MIBs: {_CTL_MACHDEP, _CPU_ID_*}.
const (
	// From OpenBSD's sys/sysctl.h.
	_CTL_MACHDEP = 7
	// From OpenBSD's machine/cpu.h.
	_CPU_ID_AA64ISAR0 = 2
	_CPU_ID_AA64ISAR1 = 3
)
// syscall_syscall6 has no body here: it is implemented in the runtime
// package (runtime/sys_openbsd3.go) and bound to syscall.syscall6 by the
// linkname directive below. sysctl uses it to call through the libc
// trampoline.
func syscall_syscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno)
//go:linkname syscall_syscall6 syscall.syscall6
// sysctl calls libc's sysctl through the trampoline. mib must contain at
// least one element. old/oldlen receive the current value and its length;
// new/newlen supply a new value. A non-zero errno is returned as the error.
func sysctl(mib []uint32, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) {
	_, _, e := syscall_syscall6(
		libc_sysctl_trampoline_addr,
		uintptr(unsafe.Pointer(&mib[0])),
		uintptr(len(mib)),
		uintptr(unsafe.Pointer(old)),
		uintptr(unsafe.Pointer(oldlen)),
		uintptr(unsafe.Pointer(new)),
		uintptr(newlen),
	)
	if e == 0 {
		return nil
	}
	return e
}
// libc_sysctl_trampoline_addr is the function address that sysctl passes to
// syscall_syscall6. Its value is provided by a DATA directive in
// cpu_openbsd_arm64.s, which points it at the trampoline that jumps to the
// dynamically imported libc sysctl.
var libc_sysctl_trampoline_addr uintptr
//go:cgo_import_dynamic libc_sysctl sysctl "libc.so"
// sysctlUint64 reads the sysctl value identified by mib into a uint64.
// The boolean result is false (and the value zero) if the call fails.
func sysctlUint64(mib []uint32) (uint64, bool) {
	var value uint64
	size := unsafe.Sizeof(value)
	err := sysctl(mib, (*byte)(unsafe.Pointer(&value)), &size, nil, 0)
	if err == nil {
		return value, true
	}
	return 0, false
}
// doinit enables the minimal ARM64 feature set and then, if both instruction
// set attribute registers can be read via sysctl, refines the flags from
// them and marks the package as initialized.
func doinit() {
	setMinimalFeatures()

	// Get ID_AA64ISAR0 and ID_AA64ISAR1 from sysctl. If either read fails,
	// keep the minimal feature set and leave Initialized untouched.
	if isar0, ok := sysctlUint64([]uint32{_CTL_MACHDEP, _CPU_ID_AA64ISAR0}); ok {
		if isar1, ok := sysctlUint64([]uint32{_CTL_MACHDEP, _CPU_ID_AA64ISAR1}); ok {
			parseARM64SystemRegisters(isar0, isar1, 0)
			Initialized = true
		}
	}
}
================================================
FILE: gossh/sys/cpu/cpu_openbsd_arm64.s
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// libc_sysctl_trampoline tail-jumps to the libc_sysctl symbol.
TEXT libc_sysctl_trampoline<>(SB),NOSPLIT,$0-0
	JMP libc_sysctl(SB)
// libc_sysctl_trampoline_addr is an 8-byte read-only datum holding the
// address of the trampoline above; the Go side reads it as a uintptr.
GLOBL ·libc_sysctl_trampoline_addr(SB), RODATA, $8
DATA ·libc_sysctl_trampoline_addr(SB)/8, $libc_sysctl_trampoline<>(SB)
================================================
FILE: gossh/sys/cpu/cpu_other_arm.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux && arm
package cpu
// archInit is a no-op for arm on non-Linux systems: no feature detection is
// performed and Initialized is not set.
func archInit() {}
================================================
FILE: gossh/sys/cpu/cpu_other_arm64.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux && !netbsd && !openbsd && arm64
package cpu
// doinit is a no-op for arm64 systems other than Linux, NetBSD and OpenBSD,
// as selected by this file's build constraint.
func doinit() {}
================================================
FILE: gossh/sys/cpu/cpu_other_mips64x.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux && (mips64 || mips64le)
package cpu
// archInit marks the package as initialized without detecting any
// MIPS64X features (non-Linux builds).
func archInit() {
	Initialized = true
}
================================================
FILE: gossh/sys/cpu/cpu_other_ppc64x.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !aix && !linux && (ppc64 || ppc64le)
package cpu
// archInit unconditionally reports POWER8 and marks the package as
// initialized; no runtime detection is performed in this build
// configuration (ppc64/ppc64le on systems other than AIX and Linux).
func archInit() {
	PPC64.IsPOWER8 = true
	Initialized = true
}
================================================
FILE: gossh/sys/cpu/cpu_other_riscv64.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux && riscv64
package cpu
// archInit marks the package as initialized without detecting any
// riscv64 features (non-Linux builds).
func archInit() {
	Initialized = true
}
================================================
FILE: gossh/sys/cpu/cpu_ppc64x.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ppc64 || ppc64le
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for ppc64/ppc64le.
const cacheLineSize = 128

// initOptions installs the table that maps option names to the PPC64
// feature flags they refer to.
func initOptions() {
	options = []option{
		{Name: "darn", Feature: &PPC64.HasDARN},
		{Name: "scv", Feature: &PPC64.HasSCV},
	}
}
================================================
FILE: gossh/sys/cpu/cpu_riscv64.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build riscv64
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for riscv64.
const cacheLineSize = 64

// initOptions registers no options on riscv64.
func initOptions() {}
================================================
FILE: gossh/sys/cpu/cpu_s390x.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
// cacheLineSize is the cache line size, in bytes, assumed for s390x.
const cacheLineSize = 256

// initOptions installs the table that maps option names to the S390X
// feature flags they refer to. The first four entries are marked Required.
func initOptions() {
	options = []option{
		{Name: "zarch", Feature: &S390X.HasZARCH, Required: true},
		{Name: "stfle", Feature: &S390X.HasSTFLE, Required: true},
		{Name: "ldisp", Feature: &S390X.HasLDISP, Required: true},
		{Name: "eimm", Feature: &S390X.HasEIMM, Required: true},
		{Name: "dfp", Feature: &S390X.HasDFP},
		{Name: "etf3eh", Feature: &S390X.HasETF3EH},
		{Name: "msa", Feature: &S390X.HasMSA},
		{Name: "aes", Feature: &S390X.HasAES},
		{Name: "aescbc", Feature: &S390X.HasAESCBC},
		{Name: "aesctr", Feature: &S390X.HasAESCTR},
		{Name: "aesgcm", Feature: &S390X.HasAESGCM},
		{Name: "ghash", Feature: &S390X.HasGHASH},
		{Name: "sha1", Feature: &S390X.HasSHA1},
		{Name: "sha256", Feature: &S390X.HasSHA256},
		{Name: "sha3", Feature: &S390X.HasSHA3},
		{Name: "sha512", Feature: &S390X.HasSHA512},
		{Name: "vx", Feature: &S390X.HasVX},
		{Name: "vxe", Feature: &S390X.HasVXE},
	}
}
// bitIsSet reports whether the bit at index is set in bits. Indices follow
// big endian numbering: index 0 is the most significant bit of bits[0],
// index 63 its least significant bit, index 64 the most significant bit of
// bits[1], and so on.
func bitIsSet(bits []uint64, index uint) bool {
	word := bits[index/64]
	mask := uint64(1) << (63 - index%64)
	return word&mask != 0
}
// facility is a bit index for the named facility. The index is interpreted
// in big endian bit order by facilityList.Has.
type facility uint8

const (
	// mandatory facilities
	zarch  facility = 1  // z architecture mode is active
	stflef facility = 7  // store-facility-list-extended
	ldisp  facility = 18 // long-displacement
	eimm   facility = 21 // extended-immediate
	// miscellaneous facilities
	dfp    facility = 42 // decimal-floating-point
	etf3eh facility = 30 // extended-translation 3 enhancement
	// cryptography facilities
	msa  facility = 17  // message-security-assist
	msa3 facility = 76  // message-security-assist extension 3
	msa4 facility = 77  // message-security-assist extension 4
	msa5 facility = 57  // message-security-assist extension 5
	msa8 facility = 146 // message-security-assist extension 8
	msa9 facility = 155 // message-security-assist extension 9
	// vector facilities
	vx   facility = 129 // vector facility
	vxe  facility = 135 // vector-enhancements 1
	vxe2 facility = 148 // vector-enhancements 2
)
// facilityList holds the bits stored by an STFLE call. Bit numbering is
// big endian: index 0 is the leftmost bit (the MSB) of the first word.
type facilityList struct {
	bits [4]uint64
}

// Has reports whether every one of the given facilities is present.
// It panics if called without any facility.
func (s *facilityList) Has(fs ...facility) bool {
	if len(fs) == 0 {
		panic("no facility bits provided")
	}
	words := s.bits[:]
	for i := 0; i < len(fs); i++ {
		if present := bitIsSet(words, uint(fs[i])); !present {
			return false
		}
	}
	return true
}
// function is the code for the named cryptographic function. The code is
// used as a big endian bit index by queryResult.Has.
type function uint8

const (
	// KM{,A,C,CTR} function codes
	aes128 function = 18 // AES-128
	aes192 function = 19 // AES-192
	aes256 function = 20 // AES-256
	// K{I,L}MD function codes
	sha1     function = 1  // SHA-1
	sha256   function = 2  // SHA-256
	sha512   function = 3  // SHA-512
	sha3_224 function = 32 // SHA3-224
	sha3_256 function = 33 // SHA3-256
	sha3_384 function = 34 // SHA3-384
	sha3_512 function = 35 // SHA3-512
	shake128 function = 36 // SHAKE-128
	shake256 function = 37 // SHAKE-256
	// KLMD function codes
	ghash function = 65 // GHASH
)
// queryResult holds the bits returned by a Query function call. Bit
// numbering is big endian: index 0 is the leftmost bit (the MSB) of the
// first word.
type queryResult struct {
	bits [2]uint64
}

// Has reports whether every one of the given function codes is present.
// It panics if called without any function code.
func (q *queryResult) Has(fns ...function) bool {
	if len(fns) == 0 {
		panic("no function codes provided")
	}
	words := q.bits[:]
	for i := 0; i < len(fns); i++ {
		if present := bitIsSet(words, uint(fns[i])); !present {
			return false
		}
	}
	return true
}
// doinit populates the S390X feature flags: first the HWCAP-derived base
// set, then, when assembly helpers and the message-security-assist facility
// are available, the cryptographic functions reported by the query
// instructions.
func doinit() {
	initS390Xbase()

	// We need implementations of stfle, km and so on
	// to detect cryptographic features.
	if !haveAsmFunctions() {
		return
	}

	// Everything below concerns optional cryptographic functions, all of
	// which are gated on the message-security-assist facility.
	if !S390X.HasMSA {
		return
	}

	aesFns := []function{aes128, aes192, aes256}

	// cipher message
	km := kmQuery()
	kmc := kmcQuery()
	S390X.HasAES = km.Has(aesFns...)
	S390X.HasAESCBC = kmc.Has(aesFns...)
	if S390X.HasSTFLE {
		installed := stfle()
		if installed.Has(msa4) {
			ctr := kmctrQuery()
			S390X.HasAESCTR = ctr.Has(aesFns...)
		}
		if installed.Has(msa8) {
			gcm := kmaQuery()
			S390X.HasAESGCM = gcm.Has(aesFns...)
		}
	}

	// compute message digest
	kimd := kimdQuery() // intermediate (no padding)
	klmd := klmdQuery() // last (padding)
	S390X.HasSHA1 = kimd.Has(sha1) && klmd.Has(sha1)
	S390X.HasSHA256 = kimd.Has(sha256) && klmd.Has(sha256)
	S390X.HasSHA512 = kimd.Has(sha512) && klmd.Has(sha512)
	S390X.HasGHASH = kimd.Has(ghash) // KLMD-GHASH does not exist

	sha3Fns := []function{
		sha3_224, sha3_256, sha3_384, sha3_512,
		shake128, shake256,
	}
	S390X.HasSHA3 = kimd.Has(sha3Fns...) && klmd.Has(sha3Fns...)
}
================================================
FILE: gossh/sys/cpu/cpu_s390x.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
// func stfle() facilityList
//
// Zeroes the 32-byte result and then executes STFLE with R0 = 3 so that up
// to four doublewords of the facility list are stored into it.
TEXT ·stfle(SB), NOSPLIT|NOFRAME, $0-32
	MOVD $ret+0(FP), R1
	MOVD $3, R0 // last doubleword index to store
	XC $32, (R1), (R1) // clear 4 doublewords (32 bytes)
	WORD $0xb2b01000 // store facility list extended (STFLE)
	RET
// func kmQuery() queryResult
//
// Runs the KM instruction with function code 0 (query); the 16-byte result
// is written to the return slot addressed by R1.
TEXT ·kmQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KM-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xB92E0024 // cipher message (KM)
	RET
// func kmcQuery() queryResult
//
// Runs the KMC instruction with function code 0 (query); the 16-byte result
// is written to the return slot addressed by R1.
TEXT ·kmcQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KMC-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xB92F0024 // cipher message with chaining (KMC)
	RET
// func kmctrQuery() queryResult
//
// Runs the KMCTR instruction with function code 0 (query); the 16-byte
// result is written to the return slot addressed by R1.
TEXT ·kmctrQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KMCTR-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xB92D4024 // cipher message with counter (KMCTR)
	RET
// func kmaQuery() queryResult
//
// Runs the KMA instruction with function code 0 (query); the 16-byte result
// is written to the return slot addressed by R1.
TEXT ·kmaQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KMA-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xb9296024 // cipher message with authentication (KMA)
	RET
// func kimdQuery() queryResult
//
// Runs the KIMD instruction with function code 0 (query); the 16-byte
// result is written to the return slot addressed by R1.
TEXT ·kimdQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KIMD-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xB93E0024 // compute intermediate message digest (KIMD)
	RET
// func klmdQuery() queryResult
//
// Runs the KLMD instruction with function code 0 (query); the 16-byte
// result is written to the return slot addressed by R1.
TEXT ·klmdQuery(SB), NOSPLIT|NOFRAME, $0-16
	MOVD $0, R0 // set function code to 0 (KLMD-Query)
	MOVD $ret+0(FP), R1 // address of 16-byte return value
	WORD $0xB93F0024 // compute last message digest (KLMD)
	RET
================================================
FILE: gossh/sys/cpu/cpu_wasm.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build wasm
package cpu
// We're compiling the cpu package for an unknown (software-abstracted) CPU.
// Make CacheLinePad an empty struct and hope that the usual struct alignment
// rules are good enough.
const cacheLineSize = 0

// initOptions registers no options on wasm.
func initOptions() {}

// archInit performs no feature detection on wasm.
func archInit() {}
================================================
FILE: gossh/sys/cpu/cpu_x86.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64 || amd64p32
package cpu
import "runtime"
// cacheLineSize is the cache line size, in bytes, assumed for x86.
const cacheLineSize = 64

// initOptions installs the table that maps option names to the X86 feature
// flags they refer to. Only "sse2" can be marked Required, and only when
// building for amd64.
func initOptions() {
	options = []option{
		{Name: "adx", Feature: &X86.HasADX},
		{Name: "aes", Feature: &X86.HasAES},
		{Name: "avx", Feature: &X86.HasAVX},
		{Name: "avx2", Feature: &X86.HasAVX2},
		{Name: "avx512", Feature: &X86.HasAVX512},
		{Name: "avx512f", Feature: &X86.HasAVX512F},
		{Name: "avx512cd", Feature: &X86.HasAVX512CD},
		{Name: "avx512er", Feature: &X86.HasAVX512ER},
		{Name: "avx512pf", Feature: &X86.HasAVX512PF},
		{Name: "avx512vl", Feature: &X86.HasAVX512VL},
		{Name: "avx512bw", Feature: &X86.HasAVX512BW},
		{Name: "avx512dq", Feature: &X86.HasAVX512DQ},
		{Name: "avx512ifma", Feature: &X86.HasAVX512IFMA},
		{Name: "avx512vbmi", Feature: &X86.HasAVX512VBMI},
		// NOTE(review): the option name lacks the "4" present in the field
		// name (HasAVX5124VNNIW); renaming it would change the accepted
		// option string, so it is left as is.
		{Name: "avx512vnniw", Feature: &X86.HasAVX5124VNNIW},
		{Name: "avx5124fmaps", Feature: &X86.HasAVX5124FMAPS},
		{Name: "avx512vpopcntdq", Feature: &X86.HasAVX512VPOPCNTDQ},
		{Name: "avx512vpclmulqdq", Feature: &X86.HasAVX512VPCLMULQDQ},
		{Name: "avx512vnni", Feature: &X86.HasAVX512VNNI},
		{Name: "avx512gfni", Feature: &X86.HasAVX512GFNI},
		{Name: "avx512vaes", Feature: &X86.HasAVX512VAES},
		{Name: "avx512vbmi2", Feature: &X86.HasAVX512VBMI2},
		{Name: "avx512bitalg", Feature: &X86.HasAVX512BITALG},
		{Name: "avx512bf16", Feature: &X86.HasAVX512BF16},
		{Name: "amxtile", Feature: &X86.HasAMXTile},
		{Name: "amxint8", Feature: &X86.HasAMXInt8},
		{Name: "amxbf16", Feature: &X86.HasAMXBF16},
		{Name: "bmi1", Feature: &X86.HasBMI1},
		{Name: "bmi2", Feature: &X86.HasBMI2},
		{Name: "cx16", Feature: &X86.HasCX16},
		{Name: "erms", Feature: &X86.HasERMS},
		{Name: "fma", Feature: &X86.HasFMA},
		{Name: "osxsave", Feature: &X86.HasOSXSAVE},
		{Name: "pclmulqdq", Feature: &X86.HasPCLMULQDQ},
		{Name: "popcnt", Feature: &X86.HasPOPCNT},
		{Name: "rdrand", Feature: &X86.HasRDRAND},
		{Name: "rdseed", Feature: &X86.HasRDSEED},
		{Name: "sse3", Feature: &X86.HasSSE3},
		{Name: "sse41", Feature: &X86.HasSSE41},
		{Name: "sse42", Feature: &X86.HasSSE42},
		{Name: "ssse3", Feature: &X86.HasSSSE3},
		// These capabilities should always be enabled on amd64:
		{Name: "sse2", Feature: &X86.HasSSE2, Required: runtime.GOARCH == "amd64"},
	}
}
// archInit populates the X86 feature flags from CPUID and XGETBV.
//
// Initialized is set first and stays true even when the function returns
// early because the CPU does not report the CPUID leaves queried below.
// AVX-family flags additionally require that the operating system has
// enabled the corresponding register state.
func archInit() {
	Initialized = true
	// Leaf 0 reports the highest supported basic CPUID leaf.
	maxID, _, _, _ := cpuid(0, 0)
	if maxID < 1 {
		return
	}
	// Leaf 1: basic feature bits in ECX and EDX.
	_, _, ecx1, edx1 := cpuid(1, 0)
	X86.HasSSE2 = isSet(26, edx1)
	X86.HasSSE3 = isSet(0, ecx1)
	X86.HasPCLMULQDQ = isSet(1, ecx1)
	X86.HasSSSE3 = isSet(9, ecx1)
	X86.HasFMA = isSet(12, ecx1)
	X86.HasCX16 = isSet(13, ecx1)
	X86.HasSSE41 = isSet(19, ecx1)
	X86.HasSSE42 = isSet(20, ecx1)
	X86.HasPOPCNT = isSet(23, ecx1)
	X86.HasAES = isSet(25, ecx1)
	X86.HasOSXSAVE = isSet(27, ecx1)
	X86.HasRDRAND = isSet(30, ecx1)
	var osSupportsAVX, osSupportsAVX512 bool
	// For XGETBV, OSXSAVE bit is required and sufficient.
	if X86.HasOSXSAVE {
		eax, _ := xgetbv()
		// Check if XMM and YMM registers have OS support.
		osSupportsAVX = isSet(1, eax) && isSet(2, eax)
		if runtime.GOOS == "darwin" {
			// Darwin doesn't save/restore AVX-512 mask registers correctly across signal handlers.
			// Since users can't rely on mask register contents, let's not advertise AVX-512 support.
			// See issue 49233.
			osSupportsAVX512 = false
		} else {
			// Check if OPMASK and ZMM registers have OS support.
			osSupportsAVX512 = osSupportsAVX && isSet(5, eax) && isSet(6, eax) && isSet(7, eax)
		}
	}
	X86.HasAVX = isSet(28, ecx1) && osSupportsAVX
	// Everything below comes from leaf 7; stop if the CPU lacks it.
	if maxID < 7 {
		return
	}
	// Leaf 7, subleaf 0: extended feature bits in EBX, ECX and EDX.
	_, ebx7, ecx7, edx7 := cpuid(7, 0)
	X86.HasBMI1 = isSet(3, ebx7)
	X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX
	X86.HasBMI2 = isSet(8, ebx7)
	X86.HasERMS = isSet(9, ebx7)
	X86.HasRDSEED = isSet(18, ebx7)
	X86.HasADX = isSet(19, ebx7)
	X86.HasAVX512 = isSet(16, ebx7) && osSupportsAVX512 // Because avx-512 foundation is the core required extension
	// The individual AVX-512 extensions are only examined when the
	// foundation is usable; otherwise their flags are left untouched.
	if X86.HasAVX512 {
		X86.HasAVX512F = true
		X86.HasAVX512CD = isSet(28, ebx7)
		X86.HasAVX512ER = isSet(27, ebx7)
		X86.HasAVX512PF = isSet(26, ebx7)
		X86.HasAVX512VL = isSet(31, ebx7)
		X86.HasAVX512BW = isSet(30, ebx7)
		X86.HasAVX512DQ = isSet(17, ebx7)
		X86.HasAVX512IFMA = isSet(21, ebx7)
		X86.HasAVX512VBMI = isSet(1, ecx7)
		X86.HasAVX5124VNNIW = isSet(2, edx7)
		X86.HasAVX5124FMAPS = isSet(3, edx7)
		X86.HasAVX512VPOPCNTDQ = isSet(14, ecx7)
		X86.HasAVX512VPCLMULQDQ = isSet(10, ecx7)
		X86.HasAVX512VNNI = isSet(11, ecx7)
		X86.HasAVX512GFNI = isSet(8, ecx7)
		X86.HasAVX512VAES = isSet(9, ecx7)
		X86.HasAVX512VBMI2 = isSet(6, ecx7)
		X86.HasAVX512BITALG = isSet(12, ecx7)
		// Leaf 7, subleaf 1: BF16 is reported in EAX.
		eax71, _, _, _ := cpuid(7, 1)
		X86.HasAVX512BF16 = isSet(5, eax71)
	}
	// AMX bits are taken from leaf 7 EDX as reported by the CPU.
	X86.HasAMXTile = isSet(24, edx7)
	X86.HasAMXInt8 = isSet(25, edx7)
	X86.HasAMXBF16 = isSet(22, edx7)
}
func isSet(bitpos uint, value uint32) bool {
return value&(1<> 63))
)
// For platforms that don't have a 'cpuid' equivalent we use HWCAP/HWCAP2.
// Both words are written by readHWCAP from the auxiliary vector, are
// consumed by the per-architecture initializers in cpu_$GOARCH.go,
// and should not be changed after they are initialized.
var hwCap uint
var hwCap2 uint
// readHWCAP stores the AT_HWCAP and AT_HWCAP2 entries of the auxiliary
// vector in hwCap and hwCap2. The vector supplied by the Go runtime is
// preferred; when that is empty the file named by procAuxv is parsed
// instead, and an error reading that file is returned to the caller.
func readHWCAP() error {
	// For Go 1.21+, get auxv from the Go runtime.
	if auxv := getAuxv(); len(auxv) > 0 {
		// The vector is a flat sequence of (tag, value) pairs; a trailing
		// unpaired element is ignored.
		for i := 0; i+1 < len(auxv); i += 2 {
			value := uint(auxv[i+1])
			switch auxv[i] {
			case _AT_HWCAP:
				hwCap = value
			case _AT_HWCAP2:
				hwCap2 = value
			}
		}
		return nil
	}

	raw, err := os.ReadFile(procAuxv)
	if err != nil {
		// e.g. on android /proc/self/auxv is not accessible, so silently
		// ignore the error and leave Initialized = false. On some
		// architectures (e.g. arm64) doinit() implements a fallback
		// readout and will set Initialized = true again.
		return err
	}

	order := hostByteOrder()
	// Each entry is two native words: the tag followed by its value.
	for len(raw) >= 2*(uintSize/8) {
		var tag, value uint
		if uintSize == 32 {
			tag = uint(order.Uint32(raw[0:]))
			value = uint(order.Uint32(raw[4:]))
			raw = raw[8:]
		} else if uintSize == 64 {
			tag = uint(order.Uint64(raw[0:]))
			value = uint(order.Uint64(raw[8:]))
			raw = raw[16:]
		}
		switch tag {
		case _AT_HWCAP:
			hwCap = value
		case _AT_HWCAP2:
			hwCap2 = value
		}
	}
	return nil
}
================================================
FILE: gossh/sys/cpu/parse.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import "strconv"
// parseRelease parses a dot-separated version number. It follows the semver
// syntax, but allows the minor and patch versions to be elided. Anything
// from the first '-' or '+' onwards is ignored, as are components after the
// patch version.
//
// ok is false as soon as a component fails to parse as an integer; the
// components parsed before it keep their values and the rest stay zero.
//
// This is a copy of the Go runtime's parseRelease from
// https://golang.org/cl/209597.
func parseRelease(rel string) (major, minor, patch int, ok bool) {
	// Drop any suffix introduced by a dash or plus.
	end := 0
	for end < len(rel) && rel[end] != '-' && rel[end] != '+' {
		end++
	}
	rel = rel[:end]

	// take consumes the leading dot-delimited component of rel and converts
	// it to an integer. rel is left empty once the last component is taken.
	take := func() (int, bool) {
		dot := 0
		for dot < len(rel) && rel[dot] != '.' {
			dot++
		}
		field := rel[:dot]
		if dot < len(rel) {
			rel = rel[dot+1:]
		} else {
			rel = ""
		}
		n, err := strconv.Atoi(field)
		return n, err == nil
	}

	major, ok = take()
	if !ok || rel == "" {
		return
	}
	minor, ok = take()
	if !ok || rel == "" {
		return
	}
	patch, ok = take()
	return
}
================================================
FILE: gossh/sys/cpu/proc_cpuinfo_linux.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && arm64
package cpu
import (
"errors"
"io"
"os"
"strings"
)
// readLinuxProcCPUInfo sets ARM64 feature flags from the "Features" line of
// /proc/cpuinfo. Only the first KiB of the file is examined, which is
// expected to cover the entry of the first CPU.
//
// Every whitespace-separated word on that line that matches a registered
// option name (or "evtstrm") turns the corresponding flag on; flags are
// never turned off. An error is returned if the file cannot be opened or
// read, or if no "Features" line is found.
func readLinuxProcCPUInfo() error {
	f, err := os.Open("/proc/cpuinfo")
	if err != nil {
		return err
	}
	defer f.Close()

	var buf [1 << 10]byte // enough for first CPU
	n, err := io.ReadFull(f, buf[:])
	if err != nil && err != io.ErrUnexpectedEOF {
		return err
	}
	in := string(buf[:n])

	// The kernel separates the key from the colon with a tab
	// ("Features\t: fp asimd ..."). Match the key and the colon
	// independently so that the exact whitespace between them, which is
	// easily mangled in a string literal, does not matter.
	const key = "\nFeatures"
	i := strings.Index(in, key)
	if i == -1 {
		return errors.New("no CPU features found")
	}
	in = in[i+len(key):]
	if eol := strings.Index(in, "\n"); eol != -1 {
		in = in[:eol]
	}
	colon := strings.Index(in, ":")
	if colon == -1 || strings.TrimSpace(in[:colon]) != "" {
		// Either the line has no value part or the key merely starts
		// with "Features".
		return errors.New("no CPU features found")
	}
	in = in[colon+1:]

	m := map[string]*bool{}

	initOptions() // need it early here; it's harmless to call twice
	for _, o := range options {
		m[o.Name] = o.Feature
	}
	// The EVTSTRM field has alias "evstrm" in Go, but Linux calls it "evtstrm".
	m["evtstrm"] = &ARM64.HasEVTSTRM

	for _, name := range strings.Fields(in) {
		if p, ok := m[name]; ok {
			*p = true
		}
	}
	return nil
}
================================================
FILE: gossh/sys/cpu/runtime_auxv.go
================================================
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
// getAuxvFn is the hook used to obtain the auxiliary vector. It is non-nil
// on Go 1.21+ (installed by the init function in runtime_auxv_go121.go)
// on platforms that use auxv.
var getAuxvFn func() []uintptr

// getAuxv returns the auxiliary vector supplied by getAuxvFn, or nil when
// no hook has been installed.
func getAuxv() []uintptr {
	if getAuxvFn != nil {
		return getAuxvFn()
	}
	return nil
}
================================================
FILE: gossh/sys/cpu/runtime_auxv_go121.go
================================================
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package cpu
import (
_ "unsafe" // for linkname
)
// runtime_getAuxv has no body here; the linkname directive binds it to
// runtime.getAuxv.
//
//go:linkname runtime_getAuxv runtime.getAuxv
func runtime_getAuxv() []uintptr

// init installs the runtime-provided accessor as the package's auxv hook.
func init() {
	getAuxvFn = runtime_getAuxv
}
================================================
FILE: gossh/sys/cpu/syscall_aix_gccgo.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Recreate a getsystemcfg syscall handler instead of
// using the one provided by x/sys/unix to avoid having
// the dependency between them. (See golang.org/issue/32102)
// Moreover, this file will be used during the building of
// gccgo's libgo and thus must not use a CGo method.
//go:build aix && gccgo
package cpu
import (
"syscall"
)
// gccgoGetsystemcfg is declared without a body; the extern directive names
// the C symbol getsystemcfg for it.

//extern getsystemcfg
func gccgoGetsystemcfg(label uint32) (r uint64)

// callgetsystemcfg calls getsystemcfg with label converted to uint32 and
// returns the result as r1, together with the errno value read via
// syscall.GetErrno immediately after the call.
func callgetsystemcfg(label int) (r1 uintptr, e1 syscall.Errno) {
r1 = uintptr(gccgoGetsystemcfg(uint32(label)))
// Read errno right after the call, before anything else can run.
e1 = syscall.GetErrno()
return
}
================================================
FILE: gossh/sys/cpu/syscall_aix_ppc64_gc.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Minimal copy of x/sys/unix so the cpu package can make a
// system call on AIX without depending on x/sys/unix.
// (See golang.org/issue/32102)
//go:build aix && ppc64 && gc
package cpu
import (
"syscall"
"unsafe"
)
//go:cgo_import_dynamic libc_getsystemcfg getsystemcfg "libc.a/shr_64.o"
//go:linkname libc_getsystemcfg libc_getsystemcfg
// syscallFunc is the type of the libc_getsystemcfg variable whose address is
// handed to syscall6 as the trap argument.
type syscallFunc uintptr

// libc_getsystemcfg is the variable named by the cgo_import_dynamic and
// linkname directives above.
var libc_getsystemcfg syscallFunc

// errno is a local alias for syscall.Errno.
type errno = syscall.Errno

// Implemented in runtime/syscall_aix.go.
func rawSyscall6(trap, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err errno)
func syscall6(trap, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err errno)

// callgetsystemcfg invokes getsystemcfg through syscall6 with a single
// argument (label) and returns the first result and the error number; the
// second result of syscall6 is discarded.
func callgetsystemcfg(label int) (r1 uintptr, e1 errno) {
r1, _, e1 = syscall6(uintptr(unsafe.Pointer(&libc_getsystemcfg)), 1, uintptr(label), 0, 0, 0, 0, 0)
return
}
================================================
FILE: gossh/sys/execabs/execabs.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package execabs is a drop-in replacement for os/exec
// that requires PATH lookups to find absolute paths.
// That is, execabs.Command("cmd") runs the same PATH lookup
// as exec.Command("cmd"), but if the result is a path
// which is relative, the Run and Start methods will report
// an error instead of running the executable.
//
// See https://blog.golang.org/path-security for more information
// about when it may be necessary or appropriate to use this package.
package execabs
import (
"context"
"fmt"
"os/exec"
"path/filepath"
"reflect"
"unsafe"
)
// ErrNotFound is the error resulting if a path search failed to find an executable file.
// It is an alias for exec.ErrNotFound.
var ErrNotFound = exec.ErrNotFound
// Cmd represents an external command being prepared or run.
// It is an alias for exec.Cmd.
type Cmd = exec.Cmd
// Error is returned by LookPath when it fails to classify a file as an executable.
// It is an alias for exec.Error.
type Error = exec.Error
// An ExitError reports an unsuccessful exit by a command.
// It is an alias for exec.ExitError.
type ExitError = exec.ExitError
// relError builds the error reported when file resolves to path, an
// executable relative to the current directory.
func relError(file, path string) error {
	const format = "%s resolves to executable in current directory (.%c%s)"
	return fmt.Errorf(format, file, filepath.Separator, path)
}
// LookPath searches for an executable named file in the directories
// listed in the PATH environment variable. When file contains a slash it
// is tried directly and PATH is not consulted. The result is an absolute
// path.
//
// Unlike exec.LookPath, a PATH lookup (used for names without slashes)
// that would have resolved to an executable in the current directory is
// reported as an error instead.
func LookPath(file string) (string, error) {
	path, err := exec.LookPath(file)
	switch {
	case err != nil && !isGo119ErrDot(err):
		// Genuine lookup failure; an ErrDot result is handled below.
		return "", err
	case filepath.Base(file) == file && !filepath.IsAbs(path):
		// Bare name that resolved to a relative path.
		return "", relError(file, path)
	}
	return path, nil
}
// fixCmd disables cmd when it was created from a bare program name (one with
// no directory component) whose resolved cmd.Path is not absolute and for
// which isGo119ErrFieldSet reports false. In that case it stores a relError
// in cmd's unexported lookPathErr field, unless an error is already there,
// and clears cmd.Path.
func fixCmd(name string, cmd *exec.Cmd) {
if filepath.Base(name) == name && !filepath.IsAbs(cmd.Path) && !isGo119ErrFieldSet(cmd) {
// exec.Command was called with a bare binary name and
// exec.LookPath returned a path which is not absolute.
// Set cmd.lookPathErr and clear cmd.Path so that it
// cannot be run.
//
// The field is unexported, so it is reached through reflect and
// unsafe rather than by name.
// NOTE(review): FieldByName yields a zero Value when exec.Cmd has no
// lookPathErr field, and Addr would then panic — confirm this branch
// cannot be reached on toolchains where the field does not exist.
lookPathErr := (*error)(unsafe.Pointer(reflect.ValueOf(cmd).Elem().FieldByName("lookPathErr").Addr().Pointer()))
if *lookPathErr == nil {
*lookPathErr = relError(name, cmd.Path)
}
cmd.Path = ""
}
}
// CommandContext behaves like Command but additionally takes a context.
//
// The context is used to kill the process (by calling os.Process.Kill)
// if it becomes done before the command finishes on its own.
func CommandContext(ctx context.Context, name string, arg ...string) *exec.Cmd {
	c := exec.CommandContext(ctx, name, arg...)
	// Refuse to run a bare name that resolved relative to the current directory.
	fixCmd(name, c)
	return c
}
// Command returns the Cmd struct used to run the named program with the
// given arguments. See exec.Command for most details.
//
// It differs from exec.Command in how PATH lookups are treated, which
// apply when the program name contains no slashes: where exec.Command
// would have produced an exec.Cmd set up to run an executable from the
// current directory, Command produces an exec.Cmd whose Start or Run
// returns an error.
func Command(name string, arg ...string) *exec.Cmd {
	c := exec.Command(name, arg...)
	// Refuse to run a bare name that resolved relative to the current directory.
	fixCmd(name, c)
	return c
}
================================================
FILE: gossh/sys/execabs/execabs_go118.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.19
package execabs
import "os/exec"
// isGo119ErrDot always reports false in this file, which is built only for
// toolchains older than Go 1.19 (see the go:build constraint above).
func isGo119ErrDot(err error) bool {
return false
}
// isGo119ErrFieldSet always reports false in this file, which is built only
// for toolchains older than Go 1.19; cmd is not inspected.
func isGo119ErrFieldSet(cmd *exec.Cmd) bool {
return false
}
================================================
FILE: gossh/sys/execabs/execabs_go119.go
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.19
package execabs
import (
"errors"
"os/exec"
)
// isGo119ErrDot reports whether err is, or wraps, exec.ErrDot.
func isGo119ErrDot(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, exec.ErrDot)
}
// isGo119ErrFieldSet reports whether cmd already carries an error in its
// Err field.
func isGo119ErrFieldSet(cmd *exec.Cmd) bool {
	recorded := cmd.Err
	return recorded != nil
}
================================================
FILE: gossh/sys/plan9/asm.s
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// use is an empty routine: it returns immediately without touching anything.
TEXT ·use(SB),NOSPLIT,$0
RET
================================================
FILE: gossh/sys/plan9/asm_plan9_386.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
//
// System call support for 386, Plan 9
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each stub below is a single jump to the identically named symbol in
// package syscall; no arguments are rearranged here.
TEXT ·Syscall(SB),NOSPLIT,$0-32
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-44
JMP syscall·Syscall6(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
JMP syscall·RawSyscall6(SB)
TEXT ·seek(SB),NOSPLIT,$0-36
JMP syscall·seek(SB)
// exit is the only stub in this file declared with a non-zero frame size.
TEXT ·exit(SB),NOSPLIT,$4-4
JMP syscall·exit(SB)
================================================
FILE: gossh/sys/plan9/asm_plan9_amd64.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
//
// System call support for amd64, Plan 9
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each stub below is a single jump to the identically named symbol in
// package syscall; no arguments are rearranged here.
TEXT ·Syscall(SB),NOSPLIT,$0-64
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-88
JMP syscall·Syscall6(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
TEXT ·seek(SB),NOSPLIT,$0-56
JMP syscall·seek(SB)
// exit is the only stub in this file declared with a non-zero frame size.
TEXT ·exit(SB),NOSPLIT,$8-8
JMP syscall·exit(SB)
================================================
FILE: gossh/sys/plan9/asm_plan9_arm.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// System call support for plan9 on arm
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each stub below is a single jump to the identically named symbol in
// package syscall; no arguments are rearranged here.
TEXT ·Syscall(SB),NOSPLIT,$0-32
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-44
JMP syscall·Syscall6(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
JMP syscall·RawSyscall6(SB)
// seek forwards to package syscall's seek implementation, as the 386 and
// amd64 stubs do. It previously jumped to syscall·exit, which would have
// terminated the process instead of seeking.
TEXT ·seek(SB),NOSPLIT,$0-36
	JMP	syscall·seek(SB)
================================================
FILE: gossh/sys/plan9/const_plan9.go
================================================
package plan9
// Plan 9 Constants
// Open modes
const (
O_RDONLY = 0
O_WRONLY = 1
O_RDWR = 2
O_TRUNC = 16
O_CLOEXEC = 32
O_EXCL = 0x1000
)
// Rfork flags
const (
RFNAMEG = 1 << 0
RFENVG = 1 << 1
RFFDG = 1 << 2
RFNOTEG = 1 << 3
RFPROC = 1 << 4
RFMEM = 1 << 5
RFNOWAIT = 1 << 6
RFCNAMEG = 1 << 10
RFCENVG = 1 << 11
RFCFDG = 1 << 12
RFREND = 1 << 13
RFNOMNT = 1 << 14
)
// Qid.Type bits
const (
QTDIR = 0x80
QTAPPEND = 0x40
QTEXCL = 0x20
QTMOUNT = 0x10
QTAUTH = 0x08
QTTMP = 0x04
QTFILE = 0x00
)
// Dir.Mode bits
const (
DMDIR = 0x80000000
DMAPPEND = 0x40000000
DMEXCL = 0x20000000
DMMOUNT = 0x10000000
DMAUTH = 0x08000000
DMTMP = 0x04000000
DMREAD = 0x4
DMWRITE = 0x2
DMEXEC = 0x1
)
// Sizes used when handling stat messages and error strings.
// NOTE(review): STATMAX and ERRMAX are not referenced in the nearby code;
// presumably the maximum stat-message and error-string sizes — confirm.
const (
STATMAX = 65535
ERRMAX = 128
// STATFIXLEN is the encoded size of a Dir whose four strings are all
// empty: Dir.Marshal adds the string lengths to it, and UnmarshalDir
// rejects any buffer shorter than it.
STATFIXLEN = 49
)
// Mount and bind flags
const (
MREPL = 0x0000
MBEFORE = 0x0001
MAFTER = 0x0002
MORDER = 0x0003
MCREATE = 0x0004
MCACHE = 0x0010
MMASK = 0x0017
)
================================================
FILE: gossh/sys/plan9/dir_plan9.go
================================================
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Plan 9 directory marshalling. See intro(5).
package plan9
import "errors"
var (
ErrShortStat = errors.New("stat buffer too short")
ErrBadStat = errors.New("malformed stat buffer")
ErrBadName = errors.New("bad character in file name")
)
// A Qid represents a 9P server's unique identification for a file.
type Qid struct {
Path uint64 // the file server's unique identification for the file
Vers uint32 // version number for given Path
Type uint8 // the type of the file (plan9.QTDIR for example)
}
// A Dir contains the metadata for a file.
type Dir struct {
// system-modified data
Type uint16 // server type
Dev uint32 // server subtype
// file data
Qid Qid // unique id from server
Mode uint32 // permissions
Atime uint32 // last read time
Mtime uint32 // last write time
Length int64 // file length
Name string // last element of path
Uid string // owner name
Gid string // group name
Muid string // last modifier name
}
// nullDir is the template copied by (*Dir).Null: every numeric field has all
// bits set, and the string fields are left empty.
var nullDir = Dir{
Type: ^uint16(0),
Dev: ^uint32(0),
Qid: Qid{
Path: ^uint64(0),
Vers: ^uint32(0),
Type: ^uint8(0),
},
Mode: ^uint32(0),
Atime: ^uint32(0),
Mtime: ^uint32(0),
Length: ^int64(0),
}
// Null overwrites every member of d with the special "don't touch" values,
// so that plan9.Wstat does not modify them.
func (d *Dir) Null() {
	*d = nullDir
}
// Marshal writes the 9P stat message for d into b and returns the number of
// bytes the message occupies.
//
// ErrShortStat is returned when b cannot hold the message, and ErrBadName
// when d.Name contains a '/'. In both cases n is still the required size.
func (d *Dir) Marshal(b []byte) (n int, err error) {
	n = STATFIXLEN + len(d.Name) + len(d.Uid) + len(d.Gid) + len(d.Muid)
	if len(b) < n {
		return n, ErrShortStat
	}

	// '/' is a single ASCII byte, so scanning bytes finds exactly the
	// same occurrences as scanning runes.
	for i := 0; i < len(d.Name); i++ {
		if d.Name[i] == '/' {
			return n, ErrBadName
		}
	}

	// The leading size field does not count its own two bytes.
	b = pbit16(b, uint16(n)-2)
	b = pbit16(b, d.Type)
	b = pbit32(b, d.Dev)
	b = pbit8(b, d.Qid.Type)
	b = pbit32(b, d.Qid.Vers)
	b = pbit64(b, d.Qid.Path)
	b = pbit32(b, d.Mode)
	b = pbit32(b, d.Atime)
	b = pbit32(b, d.Mtime)
	b = pbit64(b, uint64(d.Length))
	for _, s := range [...]string{d.Name, d.Uid, d.Gid, d.Muid} {
		b = pstring(b, s)
	}

	return n, nil
}
// UnmarshalDir decodes one 9P stat message from b into a new Dir.
//
// ErrShortStat is returned when b is shorter than the fixed part of a stat
// message. ErrBadStat is returned when the embedded size does not match
// len(b) or when one of the strings cannot be decoded.
func UnmarshalDir(b []byte) (*Dir, error) {
	if len(b) < STATFIXLEN {
		return nil, ErrShortStat
	}
	// The size field counts everything after itself.
	size, rest := gbit16(b)
	if int(size)+2 != len(b) {
		return nil, ErrBadStat
	}
	b = rest

	var d Dir
	d.Type, b = gbit16(b)
	d.Dev, b = gbit32(b)
	d.Qid.Type, b = gbit8(b)
	d.Qid.Vers, b = gbit32(b)
	d.Qid.Path, b = gbit64(b)
	d.Mode, b = gbit32(b)
	d.Atime, b = gbit32(b)
	d.Mtime, b = gbit32(b)

	var length uint64
	length, b = gbit64(b)
	d.Length = int64(length)

	// The four strings follow in this fixed order.
	for _, field := range []*string{&d.Name, &d.Uid, &d.Gid, &d.Muid} {
		var ok bool
		if *field, b, ok = gstring(b); !ok {
			return nil, ErrBadStat
		}
	}

	return &d, nil
}
// pbit8 stores v in the first byte of b and returns b without that byte.
func pbit8(b []byte, v uint8) []byte {
	b[0] = v
	rest := b[1:]
	return rest
}
// pbit16 stores v in the first two bytes of b, least significant byte
// first, and returns b without those bytes.
func pbit16(b []byte, v uint16) []byte {
	for i := 0; i < 2; i++ {
		b[i] = byte(v >> (8 * uint(i)))
	}
	return b[2:]
}
// pbit32 stores v in the first four bytes of b, least significant byte
// first, and returns b without those bytes.
func pbit32(b []byte, v uint32) []byte {
	for i := 0; i < 4; i++ {
		b[i] = byte(v >> (8 * uint(i)))
	}
	return b[4:]
}
// pbit64 stores v in the first eight bytes of b, least significant byte
// first, and returns b without those bytes.
func pbit64(b []byte, v uint64) []byte {
	for i := 0; i < 8; i++ {
		b[i] = byte(v >> (8 * uint(i)))
	}
	return b[8:]
}
// pstring writes s to b preceded by its length as a 16-bit little-endian
// number, and returns the part of b that follows the copied bytes.
func pstring(b []byte, s string) []byte {
	size := uint16(len(s))
	b[0] = byte(size)
	b[1] = byte(size >> 8)
	payload := b[2:]
	written := copy(payload, s)
	return payload[written:]
}
// gbit8 returns the first byte of b and the slice that follows it.
func gbit8(b []byte) (uint8, []byte) {
	v := b[0]
	return v, b[1:]
}
// gbit16 decodes the first two bytes of b as a little-endian 16-bit number
// and returns it with the slice that follows those bytes.
func gbit16(b []byte) (uint16, []byte) {
	var v uint16
	for i := 0; i < 2; i++ {
		v |= uint16(b[i]) << (8 * uint(i))
	}
	return v, b[2:]
}
// gbit32 decodes the first four bytes of b as a little-endian 32-bit number
// and returns it with the slice that follows those bytes.
func gbit32(b []byte) (uint32, []byte) {
	var v uint32
	for i := 0; i < 4; i++ {
		v |= uint32(b[i]) << (8 * uint(i))
	}
	return v, b[4:]
}
// gbit64 decodes the first eight bytes of b as a little-endian 64-bit
// number and returns it with the slice that follows those bytes.
func gbit64(b []byte) (uint64, []byte) {
	var v uint64
	for i := 0; i < 8; i++ {
		v |= uint64(b[i]) << (8 * uint(i))
	}
	return v, b[8:]
}
// gstring reads a string from b, prefixed with a 16-bit length in
// little-endian order. It returns the string, the slice of b that follows
// it, and a boolean.
//
// The boolean is false, and the string empty, when b is too short to hold
// the two-byte length prefix or when the length exceeds the bytes that
// remain after the prefix. Malformed input is reported this way rather than
// by panicking.
func gstring(b []byte) (string, []byte, bool) {
	// Without this guard a buffer that ends before the length prefix
	// would cause an out-of-range index.
	if len(b) < 2 {
		return "", b, false
	}
	n := int(b[0]) | int(b[1])<<8
	b = b[2:]
	if n > len(b) {
		return "", b, false
	}
	return string(b[:n]), b[n:], true
}
================================================
FILE: gossh/sys/plan9/env_plan9.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Plan 9 environment variables.
package plan9
import (
"syscall"
)
// Getenv returns the value of the environment variable key and whether it
// was found, as reported by syscall.Getenv.
func Getenv(key string) (value string, found bool) {
	value, found = syscall.Getenv(key)
	return value, found
}

// Setenv sets the environment variable key to value via syscall.Setenv.
func Setenv(key, value string) error {
	err := syscall.Setenv(key, value)
	return err
}

// Clearenv removes all environment variables via syscall.Clearenv.
func Clearenv() {
	syscall.Clearenv()
}

// Environ returns the environment as reported by syscall.Environ.
func Environ() []string {
	env := syscall.Environ()
	return env
}

// Unsetenv removes the environment variable key via syscall.Unsetenv.
func Unsetenv(key string) error {
	err := syscall.Unsetenv(key)
	return err
}
================================================
FILE: gossh/sys/plan9/errors_plan9.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plan9
import "syscall"
// Constants
const (
// Invented values to support what package os expects.
O_CREAT = 0x02000
O_APPEND = 0x00400
// The next four flags are all zero, so OR-ing them into a mode has no
// effect.
O_NOCTTY = 0x00000
O_NONBLOCK = 0x00000
O_SYNC = 0x00000
O_ASYNC = 0x00000
// S_IFMT has every bit used by the S_IF* values below set.
S_IFMT = 0x1f000
S_IFIFO = 0x1000
S_IFCHR = 0x2000
S_IFDIR = 0x4000
S_IFBLK = 0x6000
S_IFREG = 0x8000
S_IFLNK = 0xa000
S_IFSOCK = 0xc000
)
// Errors
var (
EINVAL = syscall.NewError("bad arg in system call")
ENOTDIR = syscall.NewError("not a directory")
EISDIR = syscall.NewError("file is a directory")
ENOENT = syscall.NewError("file does not exist")
EEXIST = syscall.NewError("file already exists")
EMFILE = syscall.NewError("no free file descriptors")
EIO = syscall.NewError("i/o error")
ENAMETOOLONG = syscall.NewError("file name too long")
EINTR = syscall.NewError("interrupted")
EPERM = syscall.NewError("permission denied")
EBUSY = syscall.NewError("no free devices")
ETIMEDOUT = syscall.NewError("connection timed out")
EPLAN9 = syscall.NewError("not supported by plan 9")
// The following errors do not correspond to any
// Plan 9 system messages. Invented to support
// what package os and others expect.
EACCES = syscall.NewError("access permission denied")
EAFNOSUPPORT = syscall.NewError("address family not supported by protocol")
)
================================================
FILE: gossh/sys/plan9/mkall.sh
================================================
#!/usr/bin/env bash
# Copyright 2009 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
# The plan9 package provides access to the raw system call
# interface of the underlying operating system. Porting Go to
# a new architecture/operating system combination requires
# some manual effort, though there are tools that automate
# much of the process. The auto-generated files have names
# beginning with z.
#
# This script runs or (given -n) prints suggested commands to generate z files
# for the current system. Running those commands is not automatic.
# This script is documentation more than anything else.
#
# * asm_${GOOS}_${GOARCH}.s
#
# This hand-written assembly file implements system call dispatch.
# There are three entry points:
#
# func Syscall(trap, a1, a2, a3 uintptr) (r1, r2, err uintptr);
# func Syscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr);
# func RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2, err uintptr);
#
# The first and second are the standard ones; they differ only in
# how many arguments can be passed to the kernel.
# The third is for low-level use by the ForkExec wrapper;
# unlike the first two, it does not call into the scheduler to
# let it know that a system call is running.
#
# * syscall_${GOOS}.go
#
# This hand-written Go file implements system calls that need
# special handling and lists "//sys" comments giving prototypes
# for ones that can be auto-generated. Mksyscall reads those
# comments to generate the stubs.
#
# * syscall_${GOOS}_${GOARCH}.go
#
# Same as syscall_${GOOS}.go except that it contains code specific
# to ${GOOS} on one particular architecture.
#
# * types_${GOOS}.c
#
# This hand-written C file includes standard C headers and then
# creates typedef or enum names beginning with a dollar sign
# (use of $ in variable names is a gcc extension). The hardest
# part about preparing this file is figuring out which headers to
# include and which symbols need to be #defined to get the
# actual data structures that pass through to the kernel system calls.
# Some C libraries present alternate versions for binary compatibility
# and translate them on the way in and out of system calls, but
# there is almost always a #define that can get the real ones.
# See types_darwin.c and types_linux.c for examples.
#
# * zerror_${GOOS}_${GOARCH}.go
#
# This machine-generated file defines the system's error numbers,
# error strings, and signal numbers. The generator is "mkerrors.sh".
# Usually no arguments are needed, but mkerrors.sh will pass its
# arguments on to godefs.
#
# * zsyscall_${GOOS}_${GOARCH}.go
#
# Generated by mksyscall.pl; see syscall_${GOOS}.go above.
#
# * zsysnum_${GOOS}_${GOARCH}.go
#
# Generated by mksysnum_${GOOS}.
#
# * ztypes_${GOOS}_${GOARCH}.go
#
# Generated by godefs; see types_${GOOS}.c above.
# Target identifier built from the GOOS and GOARCH environment variables;
# an empty half is rejected further down.
GOOSARCH="${GOOS}_${GOARCH}"
# defaults
mksyscall="go run mksyscall.go"
mkerrors="./mkerrors.sh"
zerrors="zerrors_$GOOSARCH.go"
mksysctl=""
zsysctl="zsysctl_$GOOSARCH.go"
mksysnum=
mktypes=
run="sh"
# -syscalls re-runs the command recorded on the first line of every
# zsyscall*go file and replaces the file with its gofmt'ed output, then exits.
# -n makes the script print the generated commands instead of running them.
case "$1" in
-syscalls)
for i in zsyscall*go
do
sed 1q $i | sed 's;^// ;;' | sh > _$i && gofmt < _$i > $i
rm _$i
done
exit 0
;;
-n)
run="cat"
shift
esac
# No positional arguments are accepted beyond the option handled above.
case "$#" in
0)
;;
*)
echo 'usage: mkall.sh [-n]' 1>&2
exit 2
esac
# Per-target generator settings. An empty variable disables that generator.
case "$GOOSARCH" in
_* | *_ | _)
echo 'undefined $GOOS_$GOARCH:' "$GOOSARCH" 1>&2
exit 1
;;
plan9_386)
mkerrors=
mksyscall="go run mksyscall.go -l32 -plan9 -tags plan9,386"
mksysnum="./mksysnum_plan9.sh /n/sources/plan9/sys/src/libc/9syscall/sys.h"
mktypes="XXX"
;;
plan9_amd64)
mkerrors=
mksyscall="go run mksyscall.go -l32 -plan9 -tags plan9,amd64"
mksysnum="./mksysnum_plan9.sh /n/sources/plan9/sys/src/libc/9syscall/sys.h"
mktypes="XXX"
;;
plan9_arm)
mkerrors=
mksyscall="go run mksyscall.go -l32 -plan9 -tags plan9,arm"
mksysnum="./mksysnum_plan9.sh /n/sources/plan9/sys/src/libc/9syscall/sys.h"
mktypes="XXX"
;;
*)
echo 'unrecognized $GOOS_$GOARCH: ' "$GOOSARCH" 1>&2
exit 1
;;
esac
# Emit one command line per enabled generator and pipe them to $run
# (sh by default, cat under -n).
(
if [ -n "$mkerrors" ]; then echo "$mkerrors |gofmt >$zerrors"; fi
case "$GOOS" in
plan9)
syscall_goos="syscall_$GOOS.go"
if [ -n "$mksyscall" ]; then echo "$mksyscall $syscall_goos |gofmt >zsyscall_$GOOSARCH.go"; fi
;;
esac
if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi
if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi
if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go |gofmt >ztypes_$GOOSARCH.go"; fi
) | $run
================================================
FILE: gossh/sys/plan9/mkerrors.sh
================================================
#!/usr/bin/env bash
# Copyright 2009 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
# Generate Go code listing errors and other #defined constant
# values (ENAMETOOLONG etc.), by asking the preprocessor
# about the definitions.
unset LANG
export LC_ALL=C
export LC_CTYPE=C
CC=${CC:-gcc}
uname=$(uname)
# Headers whose #defines are scanned for constants. Every #include needs
# its <header> operand; a bare "#include" is a preprocessor error.
includes='
#include <sys/types.h>
#include <sys/file.h>
#include <fcntl.h>
#include <dirent.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/tcp.h>
#include <errno.h>
#include <sys/signal.h>
#include <signal.h>
#include <sys/resource.h>
'
ccflags="$@"
# Write go tool cgo -godefs input.
# The subshell's output is collected in _const.go: a package clause, the
# include list inside a cgo comment, and a const block with one
# "NAME = C.NAME" line per #define accepted by the awk filter.
# ${!indirect} expands the variable named includes_<uname>, if any.
(
echo package plan9
echo
echo '/*'
indirect="includes_$(uname)"
echo "${!indirect} $includes"
echo '*/'
echo 'import "C"'
echo
echo 'const ('
# The gcc command line prints all the #defines
# it encounters while processing the input
echo "${!indirect} $includes" | $CC -x c - -E -dM $ccflags |
awk '
$1 != "#define" || $2 ~ /\(/ || $3 == "" {next}
$2 ~ /^E([ABCD]X|[BIS]P|[SD]I|S|FL)$/ {next} # 386 registers
$2 ~ /^(SIGEV_|SIGSTKSZ|SIGRT(MIN|MAX))/ {next}
$2 ~ /^(SCM_SRCRT)$/ {next}
$2 ~ /^(MAP_FAILED)$/ {next}
$2 !~ /^ETH_/ &&
$2 !~ /^EPROC_/ &&
$2 !~ /^EQUIV_/ &&
$2 !~ /^EXPR_/ &&
$2 ~ /^E[A-Z0-9_]+$/ ||
$2 ~ /^B[0-9_]+$/ ||
$2 ~ /^V[A-Z0-9]+$/ ||
$2 ~ /^CS[A-Z0-9]/ ||
$2 ~ /^I(SIG|CANON|CRNL|EXTEN|MAXBEL|STRIP|UTF8)$/ ||
$2 ~ /^IGN/ ||
$2 ~ /^IX(ON|ANY|OFF)$/ ||
$2 ~ /^IN(LCR|PCK)$/ ||
$2 ~ /(^FLU?SH)|(FLU?SH$)/ ||
$2 ~ /^C(LOCAL|READ)$/ ||
$2 == "BRKINT" ||
$2 == "HUPCL" ||
$2 == "PENDIN" ||
$2 == "TOSTOP" ||
$2 ~ /^PAR/ ||
$2 ~ /^SIG[^_]/ ||
$2 ~ /^O[CNPFP][A-Z]+[^_][A-Z]+$/ ||
$2 ~ /^IN_/ ||
$2 ~ /^LOCK_(SH|EX|NB|UN)$/ ||
$2 ~ /^(AF|SOCK|SO|SOL|IPPROTO|IP|IPV6|ICMP6|TCP|EVFILT|NOTE|EV|SHUT|PROT|MAP|PACKET|MSG|SCM|MCL|DT|MADV|PR)_/ ||
$2 == "ICMPV6_FILTER" ||
$2 == "SOMAXCONN" ||
$2 == "NAME_MAX" ||
$2 == "IFNAMSIZ" ||
$2 ~ /^CTL_(MAXNAME|NET|QUERY)$/ ||
$2 ~ /^SYSCTL_VERS/ ||
$2 ~ /^(MS|MNT)_/ ||
$2 ~ /^TUN(SET|GET|ATTACH|DETACH)/ ||
$2 ~ /^(O|F|FD|NAME|S|PTRACE|PT)_/ ||
$2 ~ /^LINUX_REBOOT_CMD_/ ||
$2 ~ /^LINUX_REBOOT_MAGIC[12]$/ ||
$2 !~ "NLA_TYPE_MASK" &&
$2 ~ /^(NETLINK|NLM|NLMSG|NLA|IFA|IFAN|RT|RTCF|RTN|RTPROT|RTNH|ARPHRD|ETH_P)_/ ||
$2 ~ /^SIOC/ ||
$2 ~ /^TIOC/ ||
$2 !~ "RTF_BITS" &&
$2 ~ /^(IFF|IFT|NET_RT|RTM|RTF|RTV|RTA|RTAX)_/ ||
$2 ~ /^BIOC/ ||
$2 ~ /^RUSAGE_(SELF|CHILDREN|THREAD)/ ||
$2 ~ /^RLIMIT_(AS|CORE|CPU|DATA|FSIZE|NOFILE|STACK)|RLIM_INFINITY/ ||
$2 ~ /^PRIO_(PROCESS|PGRP|USER)/ ||
$2 ~ /^CLONE_[A-Z_]+/ ||
$2 !~ /^(BPF_TIMEVAL)$/ &&
$2 ~ /^(BPF|DLT)_/ ||
$2 !~ "WMESGLEN" &&
$2 ~ /^W[A-Z0-9]+$/ {printf("\t%s = C.%s\n", $2, $2)}
$2 ~ /^__WCOREFLAG$/ {next}
$2 ~ /^__W[A-Z0-9]+$/ {printf("\t%s = C.%s\n", substr($2,3), $2)}
{next}
' | sort
echo ')'
) >_const.go
# Pull out the error names for later.
# The E* macros come from <errno.h>; the header name must be present for the
# preprocessor to have anything to expand.
errors=$(
	echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
	awk '$1=="#define" && $2 ~ /^E[A-Z0-9_]+$/ { print $2 }' |
	sort
)

# Pull out the signal names for later.
# The SIG* macros come from <signal.h>.
signals=$(
	echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
	awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
	grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT' |
	sort
)

# Again, writing regexps to a file.
# Each pattern matches the "<TAB>NAME =" line for one error or signal, so
# those lines can later be separated from the other constants with grep -f.
echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
	awk '$1=="#define" && $2 ~ /^E[A-Z0-9_]+$/ { print "^\t" $2 "[ \t]*=" }' |
	sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
	awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
	grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT' |
	sort >_signal.grep
# Header of the generated file: the command line that produced it.
echo '// mkerrors.sh' "$@"
echo '// Code generated by the command above; DO NOT EDIT.'
echo
# Let cgo resolve the constants, then print everything that is neither an
# error nor a signal line first.
go tool cgo -godefs -- "$@" _const.go >_error.out
cat _error.out | grep -vf _error.grep | grep -vf _signal.grep
echo
# Error lines get their value wrapped in Errno(...).
echo '// Errors'
echo 'const ('
cat _error.out | grep -f _error.grep | sed 's/=\(.*\)/= Errno(\1)/'
echo ')'
echo
# Signal lines get their value wrapped in Signal(...).
echo '// Signals'
echo 'const ('
cat _error.out | grep -f _signal.grep | sed 's/=\(.*\)/= Signal(\1)/'
echo ')'
# Run C program to print error and syscall strings.
# The subshell writes _errors.c: the collected error and signal numbers as
# two int arrays, plus a main that sorts each array, skips duplicates and
# prints "number: message" tables as Go source.
(
	echo -E "
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <ctype.h>
#include <string.h>
#include <signal.h>

#define nelem(x) (sizeof(x)/sizeof((x)[0]))

enum { A = 'A', Z = 'Z', a = 'a', z = 'z' }; // avoid need for single quotes below

int errors[] = {
"
	for i in $errors
	do
		echo -E '	'$i,
	done

	echo -E "
};

int signals[] = {
"
	for i in $signals
	do
		echo -E '	'$i,
	done

	# Use -E because on some systems bash builtin interprets \n itself.
	echo -E '
};

static int
intcmp(const void *a, const void *b)
{
	return *(int*)a - *(int*)b;
}

int
main(void)
{
	int i, j, e;
	char buf[1024], *p;

	printf("\n\n// Error table\n");
	printf("var errors = [...]string {\n");
	qsort(errors, nelem(errors), sizeof errors[0], intcmp);
	for(i=0; i<nelem(errors); i++) {
		e = errors[i];
		if(i>0 && errors[i-1] == e)
			continue;
		strcpy(buf, strerror(e));
		// lowercase first letter: Bad -> bad, but STREAM -> STREAM.
		if(A <= buf[0] && buf[0] <= Z && a <= buf[1] && buf[1] <= z)
			buf[0] += a - A;
		printf("\t%d: \"%s\",\n", e, buf);
	}
	printf("}\n\n");

	printf("\n\n// Signal table\n");
	printf("var signals = [...]string {\n");
	qsort(signals, nelem(signals), sizeof signals[0], intcmp);
	for(i=0; i<nelem(signals); i++) {
		e = signals[i];
		if(i>0 && signals[i-1] == e)
			continue;
		strcpy(buf, strsignal(e));
		// lowercase first letter: Bad -> bad, but STREAM -> STREAM.
		if(A <= buf[0] && buf[0] <= Z && a <= buf[1] && buf[1] <= z)
			buf[0] += a - A;
		// cut trailing : number.
		p = strrchr(buf, ":"[0]);
		if(p)
			*p = '\0';
		printf("\t%d: \"%s\",\n", e, buf);
	}
	printf("}\n\n");

	return 0;
}

'
) >_errors.c

# Build and run the helper, then remove every intermediate file.
$CC $ccflags -o _errors _errors.c && $GORUN ./_errors && rm -f _errors.c _errors _const.go _error.grep _signal.grep _error.out
================================================
FILE: gossh/sys/plan9/mksyscall.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
/*
This program reads a file containing function prototypes
(like syscall_plan9.go) and generates system call bodies.
The prototypes are marked by lines beginning with "//sys"
and read like func declarations if //sys is replaced by func, but:
- The parameter lists must give a name for each argument.
This includes return parameters.
- The parameter lists must give a type for each argument:
the (x, y, z int) shorthand is not allowed.
- If the return parameter is an error number, it must be named errno.
A line beginning with //sysnb is like //sys, except that the
goroutine will not be suspended during the execution of the system
call. This must only be used for system calls which can never
block, as otherwise the system call could cause all goroutines to
hang.
*/
package main
import (
"bufio"
"flag"
"fmt"
"os"
"regexp"
"strings"
)
// Command-line flags. b32 and l32 select the endianness label used by main,
// tags feeds goBuildTags, and filename is registered under the flag name
// "output".
var (
b32 = flag.Bool("b32", false, "32bit big-endian")
l32 = flag.Bool("l32", false, "32bit little-endian")
plan9 = flag.Bool("plan9", false, "plan9")
openbsd = flag.Bool("openbsd", false, "openbsd")
netbsd = flag.Bool("netbsd", false, "netbsd")
dragonfly = flag.Bool("dragonfly", false, "dragonfly")
arm = flag.Bool("arm", false, "arm") // 64-bit value should use (even, odd)-pair
tags = flag.String("tags", "", "build tags")
filename = flag.String("output", "", "output file name (standard output if omitted)")
)
// cmdLine returns the command line that reproduces this invocation: the
// fixed "go run mksyscall.go " prefix followed by the program's arguments
// joined with single spaces.
func cmdLine() string {
	args := strings.Join(os.Args[1:], " ")
	return "go run mksyscall.go " + args
}
// goBuildTags converts the comma-separated -tags value into a go:build
// expression in which the tags are joined with " && ".
func goBuildTags() string {
	parts := strings.Split(*tags, ",")
	return strings.Join(parts, " && ")
}
// Param is a single function parameter, split into its name and its type
// as produced by parseParam.
type Param struct {
Name string
Type string
}
// usage writes the usage line to standard error and exits with status 1.
func usage() {
	const msg = "usage: go run mksyscall.go [-b32 | -l32] [-tags x,y] [file ...]\n"
	fmt.Fprint(os.Stderr, msg)
	os.Exit(1)
}
// parseParamList splits a comma-separated parameter list into its
// parameters, dropping the whitespace around each comma and around the
// list as a whole. A blank list yields an empty, non-nil slice.
func parseParamList(list string) []string {
	trimmed := strings.TrimSpace(list)
	if len(trimmed) == 0 {
		return []string{}
	}
	separator := regexp.MustCompile(`\s*,\s*`)
	return separator.Split(trimmed, -1)
}
// parseParam splits a parameter of the form "name type" (two runs of
// non-space characters separated by one space) into a Param. Any other
// shape is reported on standard error and ends the program with status 1.
func parseParam(p string) Param {
	pattern := regexp.MustCompile(`^(\S*) (\S*)$`)
	if m := pattern.FindStringSubmatch(p); m != nil {
		return Param{Name: m[1], Type: m[2]}
	}
	fmt.Fprintf(os.Stderr, "malformed parameter: %s\n", p)
	os.Exit(1)
	return Param{} // not reached; os.Exit does not return
}
// main reads every file named on the command line, finds the //sys and
// //sysnb declarations in it and prints a Go source file (built from
// srcTemplate) that contains one system call stub per declaration.
//
// The target OS is taken from GOOS. On linux the program refuses to run
// unless GOLANG_SYS_BUILD is "docker". The -b32/-l32 flags select how int64
// values are split across two 32-bit words; the per-OS flags (openbsd,
// netbsd, dragonfly, arm, plan9) adjust argument padding and error handling.
//
// Any I/O error or malformed declaration is reported on standard error and
// terminates the process with status 1.
func main() {
	// Get the OS and architecture (using GOARCH_TARGET if it exists)
	goos := os.Getenv("GOOS")
	goarch := os.Getenv("GOARCH_TARGET")
	if goarch == "" {
		goarch = os.Getenv("GOARCH")
	}

	// Check that we are using the Docker-based build system if we should
	if goos == "linux" {
		if os.Getenv("GOLANG_SYS_BUILD") != "docker" {
			fmt.Fprintf(os.Stderr, "In the Docker-based build system, mksyscall should not be called directly.\n")
			fmt.Fprintf(os.Stderr, "See README.md\n")
			os.Exit(1)
		}
	}

	flag.Usage = usage
	flag.Parse()
	if len(flag.Args()) <= 0 {
		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
		usage()
	}

	endianness := ""
	if *b32 {
		endianness = "big-endian"
	} else if *l32 {
		endianness = "little-endian"
	}

	libc := false
	if goos == "darwin" && strings.Contains(goBuildTags(), " && go1.12") {
		libc = true
	}
	trampolines := map[string]bool{}

	// Every pattern is loop-invariant, so compile each one exactly once
	// instead of once per input line or per parameter.
	spaceRE := regexp.MustCompile(`\s+`)
	sysnbRE := regexp.MustCompile(`^\/\/sysnb `)
	sysRE := regexp.MustCompile(`^\/\/sys `)
	declRE := regexp.MustCompile(`^\/\/sys(nb)? (\w+)\(([^()]*)\)\s*(?:\(([^()]+)\))?\s*(?:=\s*((?i)SYS_[A-Z0-9_]+))?$`)
	pointerRE := regexp.MustCompile(`^\*`)
	sliceRE := regexp.MustCompile(`^\[\](.*)`)
	extpRE := regexp.MustCompile(`^(?i)extp(read|write)`)
	camelRE := regexp.MustCompile(`([a-z])([A-Z])`)

	text := ""
	for _, path := range flag.Args() {
		file, err := os.Open(path)
		if err != nil {
			// Print the message verbatim: it may contain '%' (for example
			// from the file name) and must not be used as a format string.
			fmt.Fprint(os.Stderr, err.Error())
			os.Exit(1)
		}
		s := bufio.NewScanner(file)
		for s.Scan() {
			t := s.Text()
			t = strings.TrimSpace(t)
			t = spaceRE.ReplaceAllString(t, ` `)
			nonblock := sysnbRE.MatchString(t)
			if !sysRE.MatchString(t) && !nonblock {
				continue
			}

			// Line must be of the form
			//	func Open(path string, mode int, perm int) (fd int, errno error)
			// Split into name, in params, out params.
			f := declRE.FindStringSubmatch(t)
			if f == nil {
				fmt.Fprintf(os.Stderr, "%s:%s\nmalformed //sys declaration\n", path, t)
				os.Exit(1)
			}
			funct, inps, outps, sysname := f[2], f[3], f[4], f[5]

			// Split argument lists on comma.
			in := parseParamList(inps)
			out := parseParamList(outps)

			// Try in vain to keep people from editing this file.
			// The theory is that they jump into the middle of the file
			// without reading the header.
			text += "// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT\n\n"

			// Go function header.
			outDecl := ""
			if len(out) > 0 {
				outDecl = fmt.Sprintf(" (%s)", strings.Join(out, ", "))
			}
			text += fmt.Sprintf("func %s(%s)%s {\n", funct, strings.Join(in, ", "), outDecl)

			// Check if err return available
			errvar := ""
			for _, param := range out {
				p := parseParam(param)
				if p.Type == "error" {
					errvar = p.Name
					break
				}
			}

			// Prepare arguments to Syscall.
			var args []string
			n := 0
			for _, param := range in {
				p := parseParam(param)
				if pointerRE.MatchString(p.Type) {
					args = append(args, "uintptr(unsafe.Pointer("+p.Name+"))")
				} else if p.Type == "string" && errvar != "" {
					text += fmt.Sprintf("\tvar _p%d *byte\n", n)
					text += fmt.Sprintf("\t_p%d, %s = BytePtrFromString(%s)\n", n, errvar, p.Name)
					text += fmt.Sprintf("\tif %s != nil {\n\t\treturn\n\t}\n", errvar)
					args = append(args, fmt.Sprintf("uintptr(unsafe.Pointer(_p%d))", n))
					n++
				} else if p.Type == "string" {
					// path and funct are data, not format text.
					fmt.Fprintf(os.Stderr, "%s:%s uses string arguments, but has no error return\n", path, funct)
					text += fmt.Sprintf("\tvar _p%d *byte\n", n)
					text += fmt.Sprintf("\t_p%d, _ = BytePtrFromString(%s)\n", n, p.Name)
					args = append(args, fmt.Sprintf("uintptr(unsafe.Pointer(_p%d))", n))
					n++
				} else if sliceRE.MatchString(p.Type) {
					// Convert slice into pointer, length.
					// Have to be careful not to take address of &a[0] if len == 0:
					// pass dummy pointer in that case.
					// Used to pass nil, but some OSes or simulators reject write(fd, nil, 0).
					text += fmt.Sprintf("\tvar _p%d unsafe.Pointer\n", n)
					text += fmt.Sprintf("\tif len(%s) > 0 {\n\t\t_p%d = unsafe.Pointer(&%s[0])\n\t}", p.Name, n, p.Name)
					text += fmt.Sprintf(" else {\n\t\t_p%d = unsafe.Pointer(&_zero)\n\t}\n", n)
					args = append(args, fmt.Sprintf("uintptr(_p%d)", n), fmt.Sprintf("uintptr(len(%s))", p.Name))
					n++
				} else if p.Type == "int64" && (*openbsd || *netbsd) {
					args = append(args, "0")
					if endianness == "big-endian" {
						args = append(args, fmt.Sprintf("uintptr(%s>>32)", p.Name), fmt.Sprintf("uintptr(%s)", p.Name))
					} else if endianness == "little-endian" {
						args = append(args, fmt.Sprintf("uintptr(%s)", p.Name), fmt.Sprintf("uintptr(%s>>32)", p.Name))
					} else {
						args = append(args, fmt.Sprintf("uintptr(%s)", p.Name))
					}
				} else if p.Type == "int64" && *dragonfly {
					if !extpRE.MatchString(funct) {
						args = append(args, "0")
					}
					if endianness == "big-endian" {
						args = append(args, fmt.Sprintf("uintptr(%s>>32)", p.Name), fmt.Sprintf("uintptr(%s)", p.Name))
					} else if endianness == "little-endian" {
						args = append(args, fmt.Sprintf("uintptr(%s)", p.Name), fmt.Sprintf("uintptr(%s>>32)", p.Name))
					} else {
						args = append(args, fmt.Sprintf("uintptr(%s)", p.Name))
					}
				} else if p.Type == "int64" && endianness != "" {
					if len(args)%2 == 1 && *arm {
						// arm abi specifies 64-bit argument uses
						// (even, odd) pair
						args = append(args, "0")
					}
					if endianness == "big-endian" {
						args = append(args, fmt.Sprintf("uintptr(%s>>32)", p.Name), fmt.Sprintf("uintptr(%s)", p.Name))
					} else {
						args = append(args, fmt.Sprintf("uintptr(%s)", p.Name), fmt.Sprintf("uintptr(%s>>32)", p.Name))
					}
				} else {
					args = append(args, fmt.Sprintf("uintptr(%s)", p.Name))
				}
			}

			// Determine which form to use; pad args with zeros.
			asm := "Syscall"
			if nonblock {
				if errvar == "" && goos == "linux" {
					asm = "RawSyscallNoError"
				} else {
					asm = "RawSyscall"
				}
			} else {
				if errvar == "" && goos == "linux" {
					asm = "SyscallNoError"
				}
			}
			if len(args) <= 3 {
				for len(args) < 3 {
					args = append(args, "0")
				}
			} else if len(args) <= 6 {
				asm += "6"
				for len(args) < 6 {
					args = append(args, "0")
				}
			} else if len(args) <= 9 {
				asm += "9"
				for len(args) < 9 {
					args = append(args, "0")
				}
			} else {
				fmt.Fprintf(os.Stderr, "%s:%s too many arguments to system call\n", path, funct)
			}

			// System call number.
			if sysname == "" {
				sysname = "SYS_" + funct
				sysname = camelRE.ReplaceAllString(sysname, `${1}_$2`)
				sysname = strings.ToUpper(sysname)
			}

			var libcFn string
			if libc {
				asm = "syscall_" + strings.ToLower(asm[:1]) + asm[1:] // internal syscall call
				sysname = strings.TrimPrefix(sysname, "SYS_")         // remove SYS_
				sysname = strings.ToLower(sysname)                    // lowercase
				if sysname == "getdirentries64" {
					// Special case - libSystem name and
					// raw syscall name don't match.
					sysname = "__getdirentries64"
				}
				libcFn = sysname
				sysname = "funcPC(libc_" + sysname + "_trampoline)"
			}

			// Actual call.
			arglist := strings.Join(args, ", ")
			call := fmt.Sprintf("%s(%s, %s)", asm, sysname, arglist)

			// Assign return values.
			body := ""
			ret := []string{"_", "_", "_"}
			doErrno := false
			for i := 0; i < len(out); i++ {
				p := parseParam(out[i])
				reg := ""
				if p.Name == "err" && !*plan9 {
					reg = "e1"
					ret[2] = reg
					doErrno = true
				} else if p.Name == "err" && *plan9 {
					ret[0] = "r0"
					ret[2] = "e1"
					break
				} else {
					reg = fmt.Sprintf("r%d", i)
					ret[i] = reg
				}
				if p.Type == "bool" {
					reg = fmt.Sprintf("%s != 0", reg)
				}
				if p.Type == "int64" && endianness != "" {
					// 64-bit number in r1:r0 or r0:r1.
					if i+2 > len(out) {
						fmt.Fprintf(os.Stderr, "%s:%s not enough registers for int64 return\n", path, funct)
					}
					if endianness == "big-endian" {
						reg = fmt.Sprintf("int64(r%d)<<32 | int64(r%d)", i, i+1)
					} else {
						reg = fmt.Sprintf("int64(r%d)<<32 | int64(r%d)", i+1, i)
					}
					ret[i] = fmt.Sprintf("r%d", i)
					ret[i+1] = fmt.Sprintf("r%d", i+1)
				}
				if reg != "e1" || *plan9 {
					body += fmt.Sprintf("\t%s = %s(%s)\n", p.Name, p.Type, reg)
				}
			}
			if ret[0] == "_" && ret[1] == "_" && ret[2] == "_" {
				text += fmt.Sprintf("\t%s\n", call)
			} else {
				if errvar == "" && goos == "linux" {
					// raw syscall without error on Linux, see golang.org/issue/22924
					text += fmt.Sprintf("\t%s, %s := %s\n", ret[0], ret[1], call)
				} else {
					text += fmt.Sprintf("\t%s, %s, %s := %s\n", ret[0], ret[1], ret[2], call)
				}
			}
			text += body

			if *plan9 && ret[2] == "e1" {
				text += "\tif int32(r0) == -1 {\n"
				text += "\t\terr = e1\n"
				text += "\t}\n"
			} else if doErrno {
				text += "\tif e1 != 0 {\n"
				text += "\t\terr = errnoErr(e1)\n"
				text += "\t}\n"
			}
			text += "\treturn\n"
			text += "}\n\n"

			if libc && !trampolines[libcFn] {
				// some system calls share a trampoline, like read and readlen.
				trampolines[libcFn] = true
				// Declare assembly trampoline.
				text += fmt.Sprintf("func libc_%s_trampoline()\n", libcFn)
				// Assembly trampoline calls the libc_* function, which this magic
				// redirects to use the function from libSystem.
				text += fmt.Sprintf("//go:linkname libc_%s libc_%s\n", libcFn, libcFn)
				text += fmt.Sprintf("//go:cgo_import_dynamic libc_%s %s \"/usr/lib/libSystem.B.dylib\"\n", libcFn, libcFn)
				text += "\n"
			}
		}
		if err := s.Err(); err != nil {
			// Verbatim output, see the os.Open error path above.
			fmt.Fprint(os.Stderr, err.Error())
			os.Exit(1)
		}
		file.Close()
	}
	fmt.Printf(srcTemplate, cmdLine(), goBuildTags(), text)
}
// srcTemplate is the fmt layout of the generated file. main fills its three
// %s verbs, in order, with cmdLine(), goBuildTags() and the accumulated stub
// text. The package clause is fixed to plan9 and the only import is "unsafe".
const srcTemplate = `// %s
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build %s
package plan9
import "unsafe"
%s
`
================================================
FILE: gossh/sys/plan9/mksysnum_plan9.sh
================================================
#!/bin/sh
# Copyright 2009 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
COMMAND="mksysnum_plan9.sh $@"
cat <= 10 {
buf[i] = byte(val%10 + '0')
i--
val /= 10
}
buf[i] = byte(val + '0')
return string(buf[i:])
}
================================================
FILE: gossh/sys/plan9/syscall.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build plan9
// Package plan9 contains an interface to the low-level operating system
// primitives. OS details vary depending on the underlying system, and
// by default, godoc will display the OS-specific documentation for the current
// system. If you want godoc to display documentation for another
// system, set $GOOS and $GOARCH to the desired system. For example, if
// you want to view documentation for freebsd/arm on linux/amd64, set $GOOS
// to freebsd and $GOARCH to arm.
//
// The primary use of this package is inside other packages that provide a more
// portable interface to the system, such as "os", "time" and "net". Use
// those packages rather than this one if you can.
//
// For details of the functions and data types in this package consult
// the manuals for the appropriate operating system.
//
// These calls return err == nil to indicate success; otherwise
// err represents an operating system error describing the failure and
// holds a value of type syscall.ErrorString.
package plan9 // import "gossh/sys/plan9"
import (
"bytes"
"strings"
"unsafe"
)
// ByteSliceFromString copies s into a freshly allocated byte slice that is
// one byte longer than s, so the result always ends in a NUL byte.
// A string that already contains a NUL byte is rejected with (nil, EINVAL).
func ByteSliceFromString(s string) ([]byte, error) {
	if strings.IndexByte(s, 0) >= 0 {
		return nil, EINVAL
	}
	// make zero-fills the slice, so the extra final byte is the terminator.
	buf := make([]byte, len(s)+1)
	copy(buf, s)
	return buf, nil
}
// BytePtrFromString converts s with ByteSliceFromString and returns the
// address of the first byte of the resulting NUL-terminated slice.
// Any error from ByteSliceFromString is returned together with a nil pointer.
func BytePtrFromString(s string) (*byte, error) {
	buf, err := ByteSliceFromString(s)
	if err == nil {
		return &buf[0], nil
	}
	return nil, err
}
// ByteSliceToString converts s to a string, stopping before the first NUL
// byte. If s holds no NUL byte the whole slice is converted.
func ByteSliceToString(s []byte) string {
	end := bytes.IndexByte(s, 0)
	if end < 0 {
		return string(s)
	}
	return string(s[:end])
}
// BytePtrToString returns the bytes starting at p, up to but not including
// the first zero byte, as a string. A nil pointer, or a pointer to a zero
// byte, yields "". The caller must guarantee that a zero byte is reachable
// from p; otherwise the scan runs past the end of the underlying memory and
// the program may crash.
func BytePtrToString(p *byte) string {
	if p == nil || *p == 0 {
		return ""
	}

	// Walk forward one byte at a time until the NUL terminator is found.
	length := 0
	cur := unsafe.Pointer(p)
	for *(*byte)(cur) != 0 {
		cur = unsafe.Pointer(uintptr(cur) + 1)
		length++
	}

	return string(unsafe.Slice(p, length))
}
// Single-word zero for use when we need a valid pointer to 0 bytes.
// The stubs emitted by mksyscall.go pass &_zero in place of &s[0] whenever a
// slice argument is empty.
var _zero uintptr
// Unix returns the Sec and Nsec fields of ts widened to int64.
func (ts *Timespec) Unix() (sec int64, nsec int64) {
	return int64(ts.Sec), int64(ts.Nsec)
}

// Unix returns tv as seconds and nanoseconds; the Usec field is multiplied
// by 1000 to obtain the nanosecond part.
func (tv *Timeval) Unix() (sec int64, nsec int64) {
	return int64(tv.Sec), int64(tv.Usec) * 1000
}

// Nano returns ts as a single nanosecond count: Sec*1e9 + Nsec.
func (ts *Timespec) Nano() int64 {
	return int64(ts.Sec)*1e9 + int64(ts.Nsec)
}

// Nano returns tv as a single nanosecond count: Sec*1e9 + Usec*1000.
func (tv *Timeval) Nano() int64 {
	return int64(tv.Sec)*1e9 + int64(tv.Usec)*1000
}
// use is a no-op, but the compiler cannot see that it is.
// Calling use(p) ensures that p is kept live until that point.
// It is declared without a body, so the implementation is supplied outside
// this Go file.
//
//go:noescape
func use(p unsafe.Pointer)
================================================
FILE: gossh/sys/plan9/syscall_plan9.go
================================================
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Plan 9 system calls.
// This file is compiled as ordinary Go code,
// but it is also input to mksyscall,
// which parses the //sys lines and generates system call stubs.
// Note that sometimes we use a lowercase //sys name and
// wrap it in our own nicer implementation.
package plan9
import (
	"bytes"
	"errors"
	"syscall"
	"unsafe"
)
// A Note is a string describing a process note.
// It implements the os.Signal interface.
type Note string

// Signal does nothing; together with String it gives Note the method set
// of os.Signal.
func (n Note) Signal() {}

// String returns the note text unchanged.
func (n Note) String() string {
	return string(n)
}
// File descriptor numbers for the three standard streams.
var (
	Stdin  = 0
	Stdout = 1
	Stderr = 2
)

// For testing: clients can set this flag to force
// creation of IPv6 sockets to return EAFNOSUPPORT.
var SocketDisableIPv6 bool
// Syscall, Syscall6, RawSyscall and RawSyscall6 are declared without bodies;
// their implementations are supplied outside this file (presumably
// per-architecture assembly; confirm against the build).

// Syscall issues the system call numbered trap with up to three arguments.
// The error result is a syscall.ErrorString.
func Syscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err syscall.ErrorString)

// Syscall6 is the six-argument form of Syscall.
func Syscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.ErrorString)

// RawSyscall takes the same arguments as Syscall but returns its third
// result as a plain uintptr rather than a syscall.ErrorString.
func RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2, err uintptr)

// RawSyscall6 is the six-argument form of RawSyscall.
func RawSyscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr)
// atoi folds the bytes of b into an unsigned decimal number, treating every
// byte as a digit by subtracting '0'. No validation is performed: bytes
// outside '0'..'9' still contribute their (wrapped) byte difference, and an
// empty slice yields 0.
func atoi(b []byte) (n uint) {
	for _, c := range b {
		n = n*10 + uint(c-'0')
	}
	return n
}
// cstring returns the bytes of s that precede the first NUL byte as a
// string, or all of s when it contains no NUL byte.
func cstring(s []byte) string {
	if end := bytes.IndexByte(s, 0); end >= 0 {
		return string(s[:end])
	}
	return string(s)
}
// errstr issues the SYS_ERRSTR system call with an ERRMAX-byte buffer and
// returns the buffer contents up to the first NUL byte. The results of the
// system call itself are ignored.
func errstr() string {
	var buf [ERRMAX]byte
	RawSyscall(SYS_ERRSTR, uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)), 0)
	// Force termination so cstring never returns more than ERRMAX-1 bytes.
	buf[len(buf)-1] = 0
	return cstring(buf[:])
}
// Implemented in assembly to import from runtime.
func exit(code int)

// Exit forwards code to the assembly-backed exit.
func Exit(code int) { exit(code) }
// readnum reads the decimal number stored in the file at path.
//
// It opens the file read-only, reads at most 12 bytes from offset 0, skips
// leading spaces, drops the last byte that was read (presumably a trailing
// separator; confirm against the device's output format) and converts the
// rest with atoi.
//
// Errors from Open and Pread are returned unchanged. A read that is empty or
// contains nothing but spaces is reported as an error.
func readnum(path string) (uint, error) {
	var b [12]byte

	fd, e := Open(path, O_RDONLY)
	if e != nil {
		return 0, e
	}
	defer Close(fd)

	n, e := Pread(fd, b[:], 0)
	if e != nil {
		return 0, e
	}

	m := 0
	for ; m < n && b[m] == ' '; m++ {
	}

	// b[m:n-1] needs m <= n-1. After the loop m == n exactly when nothing was
	// read or every byte was a space; slicing would then panic, so fail
	// cleanly instead.
	if m >= n {
		return 0, errors.New("readnum: no number in " + path)
	}

	return atoi(b[m : n-1]), nil
}
// Getpid returns the number read from "#c/pid".
// The error from readnum is discarded; readnum returns 0 on failure, so
// Getpid returns 0 in that case.
func Getpid() (pid int) {
	n, _ := readnum("#c/pid")
	return int(n)
}

// Getppid returns the number read from "#c/ppid".
// The error from readnum is discarded; readnum returns 0 on failure, so
// Getppid returns 0 in that case.
func Getppid() (ppid int) {
	n, _ := readnum("#c/ppid")
	return int(n)
}
// Read is Pread with an offset of -1.
// NOTE(review): -1 presumably selects the descriptor's current offset;
// confirm against the pread(2) manual page.
func Read(fd int, p []byte) (n int, err error) {
	return Pread(fd, p, -1)
}

// Write is Pwrite with an offset of -1.
// NOTE(review): -1 presumably selects the descriptor's current offset;
// confirm against the pwrite(2) manual page.
func Write(fd int, p []byte) (n int, err error) {
	return Pwrite(fd, p, -1)
}
// NOTE(review): ioSync is not referenced anywhere in the visible part of
// this package; confirm whether it is still needed.
var ioSync int64
//sys fd2path(fd int, buf []byte) (err error)

// Fd2path asks the fd2path system call for the path associated with fd,
// using a 512-byte buffer, and returns the buffer contents up to the first
// NUL byte. On failure it returns "" and the system call's error.
func Fd2path(fd int) (path string, err error) {
	var name [512]byte
	if err = fd2path(fd, name[:]); err != nil {
		return "", err
	}
	return cstring(name[:]), nil
}
//sys pipe(p *[2]int32) (err error)

// Pipe creates a pipe and stores its two file descriptors in p[0] and p[1].
// p must have length exactly 2, otherwise "bad arg in system call" is
// returned. p is left untouched when the system call fails.
func Pipe(p []int) (err error) {
	if len(p) != 2 {
		return syscall.ErrorString("bad arg in system call")
	}
	// The system call works on int32 descriptors; widen them afterwards.
	var fds [2]int32
	if err = pipe(&fds); err != nil {
		return err
	}
	p[0], p[1] = int(fds[0]), int(fds[1])
	return nil
}
// Underlying system call writes to newoffset via pointer.
// Implemented in assembly to avoid allocation.
// The error is returned as a plain string; Seek converts it to a
// syscall.ErrorString.
func seek(placeholder uintptr, fd int, offset int64, whence int) (newoffset int64, err string)
// Seek calls the assembly seek routine (with a zero placeholder argument)
// and returns the resulting offset. A resulting offset of -1 signals
// failure, in which case the routine's message is returned as a
// syscall.ErrorString.
func Seek(fd int, offset int64, whence int) (newoffset int64, err error) {
	pos, msg := seek(0, fd, offset, whence)
	if pos == -1 {
		return pos, syscall.ErrorString(msg)
	}
	return pos, nil
}
// Mkdir creates a directory by calling Create with O_RDONLY and the DMDIR
// bit added to mode. Any descriptor other than -1 returned by Create is
// closed again; the error from Create is returned.
func Mkdir(path string, mode uint32) (err error) {
	fd, createErr := Create(path, O_RDONLY, DMDIR|mode)
	if fd != -1 {
		Close(fd)
	}
	return createErr
}
// Waitmsg holds the fields that Await parses out of a wait message: the
// process id, three numeric time values and the exit message.
type Waitmsg struct {
	Pid  int
	Time [3]uint32
	Msg  string
}

// Exited always reports true.
func (w Waitmsg) Exited() bool {
	return true
}

// Signaled always reports false.
func (w Waitmsg) Signaled() bool {
	return false
}

// ExitStatus maps the exit message onto a numeric status: 0 when the
// message is empty (a normal exit returns no message), 1 otherwise.
func (w Waitmsg) ExitStatus() int {
	if w.Msg == "" {
		return 0
	}
	return 1
}
//sys await(s []byte) (n int, err error)

// Await calls the await system call with a 512-byte buffer and, when w is
// non-nil, parses the reply into w.
//
// The reply is expected to consist of five space-separated fields: the pid,
// three time values and the exit message. Anything else yields the error
// "invalid wait message". An exit message of '' is stored as the empty
// string. If the system call fails its error is returned; if w is nil the
// reply is discarded.
func Await(w *Waitmsg) (err error) {
	var buf [512]byte
	var f [5][]byte

	n, err := await(buf[:])

	if err != nil || w == nil {
		return
	}

	// Split the first n bytes on single spaces, but stop after four
	// separators so the last field keeps any spaces it contains.
	nf := 0
	p := 0
	for i := 0; i < n && nf < len(f)-1; i++ {
		if buf[i] == ' ' {
			f[nf] = buf[p:i]
			p = i + 1
			nf++
		}
	}
	// The final field runs to the end of the buffer, not to n; cstring
	// below cuts it at the first NUL byte.
	f[nf] = buf[p:]
	nf++

	if nf != len(f) {
		return syscall.ErrorString("invalid wait message")
	}
	w.Pid = int(atoi(f[0]))
	w.Time[0] = uint32(atoi(f[1]))
	w.Time[1] = uint32(atoi(f[2]))
	w.Time[2] = uint32(atoi(f[3]))
	w.Msg = cstring(f[4])
	if w.Msg == "''" {
		// await() returns '' for no error
		w.Msg = ""
	}
	return
}
// Unmount issues the SYS_UNMOUNT system call for name and old, after
// calling fixwd. Both strings are converted with BytePtrFromString, whose
// error is returned if a conversion fails. An empty name is passed to the
// system call as a zero argument rather than as a pointer to an empty
// string. The call has failed when its first result, truncated to int32,
// is -1; the system call's error is returned in that case.
func Unmount(name, old string) (err error) {
	fixwd()
	oldp, err := BytePtrFromString(old)
	if err != nil {
		return err
	}
	oldptr := uintptr(unsafe.Pointer(oldp))

	var r0 uintptr
	var e syscall.ErrorString

	// bind(2) man page: If name is zero, everything bound or mounted upon old is unbound or unmounted.
	if name == "" {
		r0, _, e = Syscall(SYS_UNMOUNT, _zero, oldptr, 0)
	} else {
		// This err shadows the named result inside the else branch only.
		namep, err := BytePtrFromString(name)
		if err != nil {
			return err
		}
		r0, _, e = Syscall(SYS_UNMOUNT, uintptr(unsafe.Pointer(namep)), oldptr, 0)
	}

	if int32(r0) == -1 {
		err = e
	}
	return
}
// Fchdir resolves fd to a path with Fd2path and passes that path to Chdir.
// A failure of Fd2path is returned without calling Chdir.
func Fchdir(fd int) (err error) {
	dir, lookupErr := Fd2path(fd)
	if lookupErr != nil {
		return lookupErr
	}
	return Chdir(dir)
}
// Timespec is a time value split into seconds and nanoseconds, each held
// in an int32.
type Timespec struct {
	Sec  int32
	Nsec int32
}
// Timeval is a time value split into seconds and microseconds, each held
// in an int32.
type Timeval struct {
	Sec  int32
	Usec int32
}

// NsecToTimeval converts a nanosecond count into a Timeval. 999 is added
// first so that a partial microsecond rounds up; the sum is then split into
// whole seconds and the microseconds within that second.
func NsecToTimeval(nsec int64) (tv Timeval) {
	rounded := nsec + 999
	return Timeval{
		Sec:  int32(rounded / 1e9),
		Usec: int32(rounded % 1e9 / 1e3),
	}
}
// nsec issues the SYS_NSEC system call, passing the address of a local
// int64. If the call's first result is zero the value stored in that int64
// is returned; otherwise the first result itself is returned.
func nsec() int64 {
	var scratch int64

	r0, _, _ := Syscall(SYS_NSEC, uintptr(unsafe.Pointer(&scratch)), 0, 0)
	// TODO(aram): remove hack after I fix _nsec in the pc64 kernel.
	if r0 == 0 {
		return scratch
	}
	return int64(r0)
}
// Gettimeofday stores the value of nsec(), converted with NsecToTimeval,
// in *tv. It always returns nil.
func Gettimeofday(tv *Timeval) error {
	*tv = NsecToTimeval(nsec())
	return nil
}
// Getpagesize returns the constant 0x1000 (4096).
func Getpagesize() int { return 0x1000 }

// Getegid always returns -1.
func Getegid() (egid int) { return -1 }

// Geteuid always returns -1.
func Geteuid() (euid int) { return -1 }

// Getgid always returns -1.
func Getgid() (gid int) { return -1 }

// Getuid always returns -1.
func Getuid() (uid int) { return -1 }

// Getgroups always returns an empty, non-nil slice and a nil error.
func Getgroups() (gids []int, err error) {
	return make([]int, 0), nil
}
//sys open(path string, mode int) (fd int, err error)

// Open calls fixwd and then the open stub generated from the //sys line
// above, returning the stub's results unchanged.
func Open(path string, mode int) (fd int, err error) {
	fixwd()
	return open(path, mode)
}
//sys create(path string, mode int, perm uint32) (fd int, err error)

// Create calls fixwd and then the create stub generated from the //sys line
// above, returning the stub's results unchanged.
func Create(path string, mode int, perm uint32) (fd int, err error) {
	fixwd()
	return create(path, mode, perm)
}
//sys remove(path string) (err error)

// Remove calls fixwd and then the remove stub generated from the //sys line
// above, returning the stub's error unchanged.
func Remove(path string) error {
	fixwd()
	return remove(path)
}
//sys stat(path string, edir []byte) (n int, err error)

// Stat calls fixwd and then the stat stub generated from the //sys line
// above, passing edir as the system call's buffer and returning the stub's
// results unchanged.
func Stat(path string, edir []byte) (n int, err error) {
	fixwd()
	return stat(path, edir)
}
//sys bind(name string, old string, flag int) (err error)

// Bind calls fixwd and then the bind stub generated from the //sys line
// above, returning the stub's error unchanged.
func Bind(name string, old string, flag int) (err error) {
	fixwd()
	return bind(name, old, flag)
}
//sys mount(fd int, afd int, old string, flag int, aname string) (err error)

// Mount calls fixwd and then the mount stub generated from the //sys line
// above, returning the stub's error unchanged.
func Mount(fd int, afd int, old string, flag int, aname string) (err error) {
	fixwd()
	return mount(fd, afd, old, flag, aname)
}
//sys wstat(path string, edir []byte) (err error)

// Wstat calls fixwd and then the wstat stub generated from the //sys line
// above, passing edir as the system call's buffer and returning the stub's
// error unchanged.
func Wstat(path string, edir []byte) (err error) {
	fixwd()
	return wstat(path, edir)
}
//sys chdir(path string) (err error)
//sys Dup(oldfd int, newfd int) (fd int, err error)
//sys Pread(fd int, p []byte, offset int64) (n int, err error)
//sys Pwrite(fd int, p []byte, offset int64) (n int, err error)
//sys Close(fd int) (err error)
//sys Fstat(fd int, edir []byte) (n int, err error)
//sys Fwstat(fd int, edir []byte) (err error)
================================================
FILE: gossh/sys/plan9/zsyscall_plan9_386.go
================================================
// go run mksyscall.go -l32 -plan9 -tags plan9,386 syscall_plan9.go
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build plan9 && 386
package plan9
import "unsafe"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fd2path(fd int, buf []byte) (err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FD2PATH, uintptr(fd), uintptr(_p0), uintptr(len(buf)))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe(p *[2]int32) (err error) {
r0, _, e1 := Syscall(SYS_PIPE, uintptr(unsafe.Pointer(p)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func await(s []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(s) > 0 {
_p0 = unsafe.Pointer(&s[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_AWAIT, uintptr(_p0), uintptr(len(s)), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func open(path string, mode int) (fd int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_OPEN, uintptr(unsafe.Pointer(_p0)), uintptr(mode), 0)
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func create(path string, mode int, perm uint32) (fd int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_CREATE, uintptr(unsafe.Pointer(_p0)), uintptr(mode), uintptr(perm))
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func remove(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_REMOVE, uintptr(unsafe.Pointer(_p0)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func stat(path string, edir []byte) (n int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(edir) > 0 {
_p1 = unsafe.Pointer(&edir[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_STAT, uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(edir)))
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func bind(name string, old string, flag int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(name)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(old)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_BIND, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), uintptr(flag))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func mount(fd int, afd int, old string, flag int, aname string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(old)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(aname)
if err != nil {
return
}
r0, _, e1 := Syscall6(SYS_MOUNT, uintptr(fd), uintptr(afd), uintptr(unsafe.Pointer(_p0)), uintptr(flag), uintptr(unsafe.Pointer(_p1)), 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func wstat(path string, edir []byte) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(edir) > 0 {
_p1 = unsafe.Pointer(&edir[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_WSTAT, uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(edir)))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func chdir(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_CHDIR, uintptr(unsafe.Pointer(_p0)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Dup(oldfd int, newfd int) (fd int, err error) {
r0, _, e1 := Syscall(SYS_DUP, uintptr(oldfd), uintptr(newfd), 0)
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Pread(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PREAD, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Pwrite(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PWRITE, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Close(fd int) (err error) {
r0, _, e1 := Syscall(SYS_CLOSE, uintptr(fd), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fstat(fd int, edir []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(edir) > 0 {
_p0 = unsafe.Pointer(&edir[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FSTAT, uintptr(fd), uintptr(_p0), uintptr(len(edir)))
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fwstat(fd int, edir []byte) (err error) {
var _p0 unsafe.Pointer
if len(edir) > 0 {
_p0 = unsafe.Pointer(&edir[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FWSTAT, uintptr(fd), uintptr(_p0), uintptr(len(edir)))
if int32(r0) == -1 {
err = e1
}
return
}
================================================
FILE: gossh/sys/plan9/zsyscall_plan9_amd64.go
================================================
// go run mksyscall.go -l32 -plan9 -tags plan9,amd64 syscall_plan9.go
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build plan9 && amd64
package plan9
import "unsafe"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fd2path(fd int, buf []byte) (err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FD2PATH, uintptr(fd), uintptr(_p0), uintptr(len(buf)))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe(p *[2]int32) (err error) {
r0, _, e1 := Syscall(SYS_PIPE, uintptr(unsafe.Pointer(p)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func await(s []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(s) > 0 {
_p0 = unsafe.Pointer(&s[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_AWAIT, uintptr(_p0), uintptr(len(s)), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func open(path string, mode int) (fd int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_OPEN, uintptr(unsafe.Pointer(_p0)), uintptr(mode), 0)
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func create(path string, mode int, perm uint32) (fd int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_CREATE, uintptr(unsafe.Pointer(_p0)), uintptr(mode), uintptr(perm))
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func remove(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_REMOVE, uintptr(unsafe.Pointer(_p0)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func stat(path string, edir []byte) (n int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(edir) > 0 {
_p1 = unsafe.Pointer(&edir[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_STAT, uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(edir)))
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func bind(name string, old string, flag int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(name)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(old)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_BIND, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), uintptr(flag))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func mount(fd int, afd int, old string, flag int, aname string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(old)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(aname)
if err != nil {
return
}
r0, _, e1 := Syscall6(SYS_MOUNT, uintptr(fd), uintptr(afd), uintptr(unsafe.Pointer(_p0)), uintptr(flag), uintptr(unsafe.Pointer(_p1)), 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func wstat(path string, edir []byte) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(edir) > 0 {
_p1 = unsafe.Pointer(&edir[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_WSTAT, uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(edir)))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func chdir(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_CHDIR, uintptr(unsafe.Pointer(_p0)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Dup(oldfd int, newfd int) (fd int, err error) {
r0, _, e1 := Syscall(SYS_DUP, uintptr(oldfd), uintptr(newfd), 0)
fd = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Pread(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PREAD, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Pwrite(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PWRITE, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Close(fd int) (err error) {
r0, _, e1 := Syscall(SYS_CLOSE, uintptr(fd), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fstat(fd int, edir []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(edir) > 0 {
_p0 = unsafe.Pointer(&edir[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FSTAT, uintptr(fd), uintptr(_p0), uintptr(len(edir)))
n = int(r0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fwstat(fd int, edir []byte) (err error) {
var _p0 unsafe.Pointer
if len(edir) > 0 {
_p0 = unsafe.Pointer(&edir[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FWSTAT, uintptr(fd), uintptr(_p0), uintptr(len(edir)))
if int32(r0) == -1 {
err = e1
}
return
}
================================================
FILE: gossh/sys/plan9/zsyscall_plan9_arm.go
================================================
// go run mksyscall.go -l32 -plan9 -tags plan9,arm syscall_plan9.go
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build plan9 && arm
package plan9
import "unsafe"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fd2path(fd int, buf []byte) (err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_FD2PATH, uintptr(fd), uintptr(_p0), uintptr(len(buf)))
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe(p *[2]int32) (err error) {
r0, _, e1 := Syscall(SYS_PIPE, uintptr(unsafe.Pointer(p)), 0, 0)
if int32(r0) == -1 {
err = e1
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// await invokes the SYS_AWAIT system call with the buffer s.
// n is the call's result converted to int; err is the error reported by
// Syscall when that result is -1.
func await(s []byte) (n int, err error) {
	// Fall back to the address of _zero so that &s[0] is never
	// evaluated on an empty slice.
	msgPtr := unsafe.Pointer(&_zero)
	if len(s) != 0 {
		msgPtr = unsafe.Pointer(&s[0])
	}
	res, _, sysErr := Syscall(SYS_AWAIT, uintptr(msgPtr), uintptr(len(s)), 0)
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// open converts path with BytePtrFromString and invokes the SYS_OPEN system
// call with that pointer and mode. A conversion failure is returned without
// making the call. Otherwise fd is the call's result converted to int, and
// err is the error reported by Syscall when that result is -1.
func open(path string, mode int) (fd int, err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return 0, err
	}
	res, _, sysErr := Syscall(SYS_OPEN, uintptr(unsafe.Pointer(pathPtr)), uintptr(mode), 0)
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// create converts path with BytePtrFromString and invokes the SYS_CREATE
// system call with that pointer, mode and perm. A conversion failure is
// returned without making the call. Otherwise fd is the call's result
// converted to int, and err is the error reported by Syscall when that
// result is -1.
func create(path string, mode int, perm uint32) (fd int, err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return 0, err
	}
	res, _, sysErr := Syscall(SYS_CREATE, uintptr(unsafe.Pointer(pathPtr)), uintptr(mode), uintptr(perm))
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// remove converts path with BytePtrFromString and invokes the SYS_REMOVE
// system call with that pointer. It returns the conversion error, or the
// error reported by Syscall when the call's result is -1.
func remove(path string) (err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return err
	}
	res, _, sysErr := Syscall(SYS_REMOVE, uintptr(unsafe.Pointer(pathPtr)), 0, 0)
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// stat converts path with BytePtrFromString and invokes the SYS_STAT system
// call with that pointer and the edir buffer. A conversion failure is
// returned without making the call. Otherwise n is the call's result
// converted to int, and err is the error reported by Syscall when that
// result is -1.
func stat(path string, edir []byte) (n int, err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return 0, err
	}
	// Fall back to the address of _zero so that &edir[0] is never
	// evaluated on an empty slice.
	dirPtr := unsafe.Pointer(&_zero)
	if len(edir) != 0 {
		dirPtr = unsafe.Pointer(&edir[0])
	}
	res, _, sysErr := Syscall(SYS_STAT, uintptr(unsafe.Pointer(pathPtr)), uintptr(dirPtr), uintptr(len(edir)))
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// bind converts name and old with BytePtrFromString and invokes the SYS_BIND
// system call with those pointers and flag. It returns the first conversion
// error, or the error reported by Syscall when the call's result is -1.
func bind(name string, old string, flag int) (err error) {
	namePtr, err := BytePtrFromString(name)
	if err != nil {
		return err
	}
	oldPtr, err := BytePtrFromString(old)
	if err != nil {
		return err
	}
	res, _, sysErr := Syscall(SYS_BIND, uintptr(unsafe.Pointer(namePtr)), uintptr(unsafe.Pointer(oldPtr)), uintptr(flag))
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// mount converts old and aname with BytePtrFromString and invokes the
// SYS_MOUNT system call with fd, afd, the old pointer, flag and the aname
// pointer, in that order. It returns the first conversion error, or the
// error reported by Syscall6 when the call's result is -1.
func mount(fd int, afd int, old string, flag int, aname string) (err error) {
	oldPtr, err := BytePtrFromString(old)
	if err != nil {
		return err
	}
	anamePtr, err := BytePtrFromString(aname)
	if err != nil {
		return err
	}
	res, _, sysErr := Syscall6(SYS_MOUNT, uintptr(fd), uintptr(afd), uintptr(unsafe.Pointer(oldPtr)), uintptr(flag), uintptr(unsafe.Pointer(anamePtr)), 0)
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// wstat converts path with BytePtrFromString and invokes the SYS_WSTAT
// system call with that pointer and the edir buffer. It returns the
// conversion error, or the error reported by Syscall when the call's result
// is -1.
func wstat(path string, edir []byte) (err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return err
	}
	// Fall back to the address of _zero so that &edir[0] is never
	// evaluated on an empty slice.
	dirPtr := unsafe.Pointer(&_zero)
	if len(edir) != 0 {
		dirPtr = unsafe.Pointer(&edir[0])
	}
	res, _, sysErr := Syscall(SYS_WSTAT, uintptr(unsafe.Pointer(pathPtr)), uintptr(dirPtr), uintptr(len(edir)))
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// chdir converts path with BytePtrFromString and invokes the SYS_CHDIR
// system call with that pointer. It returns the conversion error, or the
// error reported by Syscall when the call's result is -1.
func chdir(path string) (err error) {
	pathPtr, err := BytePtrFromString(path)
	if err != nil {
		return err
	}
	res, _, sysErr := Syscall(SYS_CHDIR, uintptr(unsafe.Pointer(pathPtr)), 0, 0)
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Dup invokes the SYS_DUP system call with oldfd and newfd. fd is the call's
// result converted to int; err is the error reported by Syscall when that
// result is -1.
func Dup(oldfd int, newfd int) (fd int, err error) {
	res, _, sysErr := Syscall(SYS_DUP, uintptr(oldfd), uintptr(newfd), 0)
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Pread invokes the SYS_PREAD system call with fd, the buffer p and offset.
// The 64-bit offset is passed as two words: its low 32 bits followed by its
// high 32 bits. n is the call's result converted to int; err is the error
// reported by Syscall6 when that result is -1.
func Pread(fd int, p []byte, offset int64) (n int, err error) {
	// Fall back to the address of _zero so that &p[0] is never
	// evaluated on an empty slice.
	dataPtr := unsafe.Pointer(&_zero)
	if len(p) != 0 {
		dataPtr = unsafe.Pointer(&p[0])
	}
	res, _, sysErr := Syscall6(SYS_PREAD, uintptr(fd), uintptr(dataPtr), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Pwrite invokes the SYS_PWRITE system call with fd, the buffer p and
// offset. The 64-bit offset is passed as two words: its low 32 bits followed
// by its high 32 bits. n is the call's result converted to int; err is the
// error reported by Syscall6 when that result is -1.
func Pwrite(fd int, p []byte, offset int64) (n int, err error) {
	// Fall back to the address of _zero so that &p[0] is never
	// evaluated on an empty slice.
	dataPtr := unsafe.Pointer(&_zero)
	if len(p) != 0 {
		dataPtr = unsafe.Pointer(&p[0])
	}
	res, _, sysErr := Syscall6(SYS_PWRITE, uintptr(fd), uintptr(dataPtr), uintptr(len(p)), uintptr(offset), uintptr(offset>>32), 0)
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Close invokes the SYS_CLOSE system call on fd. It returns the error
// reported by Syscall when the call's result is -1, and nil otherwise.
func Close(fd int) (err error) {
	res, _, sysErr := Syscall(SYS_CLOSE, uintptr(fd), 0, 0)
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Fstat invokes the SYS_FSTAT system call with fd and the edir buffer.
// n is the call's result converted to int; err is the error reported by
// Syscall when that result is -1.
func Fstat(fd int, edir []byte) (n int, err error) {
	// Fall back to the address of _zero so that &edir[0] is never
	// evaluated on an empty slice.
	dirPtr := unsafe.Pointer(&_zero)
	if len(edir) != 0 {
		dirPtr = unsafe.Pointer(&edir[0])
	}
	res, _, sysErr := Syscall(SYS_FSTAT, uintptr(fd), uintptr(dirPtr), uintptr(len(edir)))
	if int32(res) == -1 {
		return int(res), sysErr
	}
	return int(res), nil
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT

// Fwstat invokes the SYS_FWSTAT system call with fd and the edir buffer.
// It returns the error reported by Syscall when the call's result is -1.
func Fwstat(fd int, edir []byte) (err error) {
	// Fall back to the address of _zero so that &edir[0] is never
	// evaluated on an empty slice.
	dirPtr := unsafe.Pointer(&_zero)
	if len(edir) != 0 {
		dirPtr = unsafe.Pointer(&edir[0])
	}
	res, _, sysErr := Syscall(SYS_FWSTAT, uintptr(fd), uintptr(dirPtr), uintptr(len(edir)))
	if int32(res) != -1 {
		return nil
	}
	return sysErr
}
================================================
FILE: gossh/sys/plan9/zsysnum_plan9.go
================================================
// mksysnum_plan9.sh /opt/plan9/sys/src/libc/9syscall/sys.h
// MACHINE GENERATED BY THE ABOVE COMMAND; DO NOT EDIT
package plan9
// System call numbers for the plan9 package. The generated wrappers pass
// these as the first (trap) argument of Syscall and Syscall6. The numbering
// is not contiguous: several values (for example 1, 9, 11 and 13) have no
// constant in this table.
const (
SYS_SYSR1 = 0
SYS_BIND = 2
SYS_CHDIR = 3
SYS_CLOSE = 4
SYS_DUP = 5
SYS_ALARM = 6
SYS_EXEC = 7
SYS_EXITS = 8
SYS_FAUTH = 10
SYS_SEGBRK = 12
SYS_OPEN = 14
SYS_OSEEK = 16
SYS_SLEEP = 17
SYS_RFORK = 19
SYS_PIPE = 21
SYS_CREATE = 22
SYS_FD2PATH = 23
SYS_BRK_ = 24
SYS_REMOVE = 25
SYS_NOTIFY = 28
SYS_NOTED = 29
SYS_SEGATTACH = 30
SYS_SEGDETACH = 31
SYS_SEGFREE = 32
SYS_SEGFLUSH = 33
SYS_RENDEZVOUS = 34
SYS_UNMOUNT = 35
SYS_SEMACQUIRE = 37
SYS_SEMRELEASE = 38
SYS_SEEK = 39
SYS_FVERSION = 40
SYS_ERRSTR = 41
SYS_STAT = 42
SYS_FSTAT = 43
SYS_WSTAT = 44
SYS_FWSTAT = 45
SYS_MOUNT = 46
SYS_AWAIT = 47
SYS_PREAD = 50
SYS_PWRITE = 51
SYS_TSEMACQUIRE = 52
SYS_NSEC = 53
)
================================================
FILE: gossh/sys/unix/README.md
================================================
# Building `sys/unix`
The sys/unix package provides access to the raw system call interface of the
underlying operating system. See: https://pkg.go.dev/golang.org/x/sys/unix
Porting Go to a new architecture/OS combination or adding syscalls, types, or
constants to an existing architecture/OS pair requires some manual effort;
however, there are tools that automate much of the process.
## Build Systems
There are currently two ways we generate the necessary files. We are currently
migrating the build system to use containers so the builds are reproducible.
This is being done on an OS-by-OS basis. Please update this documentation as
components of the build system change.
### Old Build System (currently for `GOOS != "linux"`)
The old build system generates the Go files based on the C header files
present on your system. This means that files
for a given GOOS/GOARCH pair must be generated on a system with that OS and
architecture. This also means that the generated code can differ from system
to system, based on differences in the header files.
To avoid this, if you are using the old build system, only generate the Go
files on an installation with unmodified header files. It is also important to
keep track of which version of the OS the files were generated from (ex.
Darwin 14 vs Darwin 15). This makes it easier to track the progress of changes
and have each OS upgrade correspond to a single change.
To build the files for your current OS and architecture, make sure GOOS and
GOARCH are set correctly and run `mkall.sh`. This will generate the files for
your specific system. Running `mkall.sh -n` shows the commands that will be run.
Requirements: bash, go
### New Build System (currently for `GOOS == "linux"`)
The new build system uses a Docker container to generate the go files directly
from source checkouts of the kernel and various system libraries. This means
that on any platform that supports Docker, all the files using the new build
system can be generated at once, and generated files will not change based on
what the person running the scripts has installed on their computer.
The OS specific files for the new build system are located in the `${GOOS}`
directory, and the build is coordinated by the `${GOOS}/mkall.go` program. When
the kernel or system library updates, modify the Dockerfile at
`${GOOS}/Dockerfile` to checkout the new release of the source.
To build all the files under the new build system, you must be on an amd64/Linux
system and have your GOOS and GOARCH set accordingly. Running `mkall.sh` will
then generate all of the files for all of the GOOS/GOARCH pairs in the new build
system. Running `mkall.sh -n` shows the commands that will be run.
Requirements: bash, go, docker
## Component files
This section describes the various files used in the code generation process.
It also contains instructions on how to modify these files to add a new
architecture/OS or to add additional syscalls, types, or constants. Note that
if you are using the new build system, the scripts/programs cannot be called normally.
They must be called from within the docker container.
### asm files
The hand-written assembly file at `asm_${GOOS}_${GOARCH}.s` implements system
call dispatch. There are three entry points:
```
func Syscall(trap, a1, a2, a3 uintptr) (r1, r2, err uintptr)
func Syscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr)
func RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2, err uintptr)
```
The first and second are the standard ones; they differ only in how many
arguments can be passed to the kernel. The third is for low-level use by the
ForkExec wrapper. Unlike the first two, it does not call into the scheduler to
let it know that a system call is running.
When porting Go to a new architecture/OS, this file must be implemented for
each GOOS/GOARCH pair.
### mksysnum
Mksysnum is a Go program located at `${GOOS}/mksysnum.go` (or `mksysnum_${GOOS}.go`
for the old system). This program takes in a list of header files containing the
syscall number declarations and parses them to produce the corresponding list of
Go numeric constants. See `zsysnum_${GOOS}_${GOARCH}.go` for the generated
constants.
Adding new syscall numbers is mostly done by running the build on a sufficiently
new installation of the target OS (or updating the source checkouts for the
new build system). However, depending on the OS, you may need to update the
parsing in mksysnum.
### mksyscall.go
The `syscall.go`, `syscall_${GOOS}.go`, `syscall_${GOOS}_${GOARCH}.go` are
hand-written Go files which implement system calls (for unix, the specific OS,
or the specific OS/Architecture pair respectively) that need special handling
and list `//sys` comments giving prototypes for ones that can be generated.
The mksyscall.go program takes the `//sys` and `//sysnb` comments and converts
them into syscalls. This requires the name of the prototype in the comment to
match a syscall number in the `zsysnum_${GOOS}_${GOARCH}.go` file. The function
prototype can be exported (capitalized) or not.
Adding a new syscall often just requires adding a new `//sys` function prototype
with the desired arguments and a capitalized name so it is exported. However, if
you want the interface to the syscall to be different, often one will make an
unexported `//sys` prototype, and then write a custom wrapper in
`syscall_${GOOS}.go`.
### types files
For each OS, there is a hand-written Go file at `${GOOS}/types.go` (or
`types_${GOOS}.go` on the old system). This file includes standard C headers and
creates Go type aliases to the corresponding C types. The file is then fed
through godefs to get the Go compatible definitions. Finally, the generated code
is fed through mkpost.go to format the code correctly and remove any hidden or
private identifiers. This cleaned-up code is written to
`ztypes_${GOOS}_${GOARCH}.go`.
The hardest part about preparing this file is figuring out which headers to
include and which symbols need to be `#define`d to get the actual data
structures that pass through to the kernel system calls. Some C libraries
present alternate versions for binary compatibility and translate them on the
way in and out of system calls, but there is almost always a `#define` that can
get the real ones.
See `types_darwin.go` and `linux/types.go` for examples.
To add a new type, add in the necessary include statement at the top of the
file (if it is not already there) and add in a type alias line. Note that if
your type is significantly different on different architectures, you may need
some `#if/#elif` macros in your include statements.
### mkerrors.sh
This script is used to generate the system's various constants. This doesn't
just include the error numbers and error strings, but also the signal numbers
and a wide variety of miscellaneous constants. The constants come from the list
of include files in the `includes_${uname}` variable. A regex then picks out
the desired `#define` statements, and generates the corresponding Go constants.
The error numbers and strings are generated from `#include <errno.h>`, and the
signal numbers and strings are generated from `#include <signal.h>`. All of
these constants are written to `zerrors_${GOOS}_${GOARCH}.go` via a C program,
`_errors.c`, which prints out all the constants.
To add a constant, add the header that includes it to the appropriate variable.
Then, edit the regex (if necessary) to match the desired constant. Avoid making
the regex too broad to avoid matching unintended constants.
### internal/mkmerge
This program is used to extract duplicate const, func, and type declarations
from the generated architecture-specific files listed below, and merge these
into a common file for each OS.
The merge is performed in the following steps:
1. Construct the set of common code that is identical in all architecture-specific files.
2. Write this common code to the merged file.
3. Remove the common code from all architecture-specific files.
## Generated files
### `zerrors_${GOOS}_${GOARCH}.go`
A file containing all of the system's generated error numbers, error strings,
signal numbers, and constants. Generated by `mkerrors.sh` (see above).
### `zsyscall_${GOOS}_${GOARCH}.go`
A file containing all the generated syscalls for a specific GOOS and GOARCH.
Generated by `mksyscall.go` (see above).
### `zsysnum_${GOOS}_${GOARCH}.go`
A list of numeric constants for all the syscall numbers of the specific GOOS
and GOARCH. Generated by mksysnum (see above).
### `ztypes_${GOOS}_${GOARCH}.go`
A file containing Go types for passing into (or returning from) syscalls.
Generated by godefs and the types file (see above).
================================================
FILE: gossh/sys/unix/affinity_linux.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// CPU affinity functions
package unix
import (
"math/bits"
"unsafe"
)
// cpuSetSize is the number of cpuMask words in a CPUSet, computed as
// _CPU_SETSIZE / _NCPUBITS. Each word holds _NCPUBITS CPU bits (see
// cpuBitsIndex and cpuBitsMask).
const cpuSetSize = _CPU_SETSIZE / _NCPUBITS
// CPUSet represents a CPU affinity mask. It is a fixed-size array of
// cpuSetSize cpuMask words; the methods on *CPUSet address individual CPUs
// as single bits within those words.
type CPUSet [cpuSetSize]cpuMask
// schedAffinity performs the raw system call identified by trap, passing
// pid, the size of *set in bytes, and the address of set. It returns nil
// when the call reports errno 0 and errnoErr(errno) otherwise.
func schedAffinity(trap uintptr, pid int, set *CPUSet) error {
	// The unsafe.Pointer-to-uintptr conversion must stay inside the call
	// expression itself.
	if _, _, errno := RawSyscall(trap, uintptr(pid), uintptr(unsafe.Sizeof(*set)), uintptr(unsafe.Pointer(set))); errno != 0 {
		return errnoErr(errno)
	}
	return nil
}
// SchedGetaffinity gets the CPU affinity mask of the thread specified by pid.
// If pid is 0 the calling thread is used.
//
// The mask is written into set. The call is made through schedAffinity with
// SYS_SCHED_GETAFFINITY, and its error (nil on success) is returned
// unchanged.
func SchedGetaffinity(pid int, set *CPUSet) error {
return schedAffinity(SYS_SCHED_GETAFFINITY, pid, set)
}
// SchedSetaffinity sets the CPU affinity mask of the thread specified by pid.
// If pid is 0 the calling thread is used.
//
// The mask is read from set. The call is made through schedAffinity with
// SYS_SCHED_SETAFFINITY, and its error (nil on success) is returned
// unchanged.
func SchedSetaffinity(pid int, set *CPUSet) error {
return schedAffinity(SYS_SCHED_SETAFFINITY, pid, set)
}
// Zero clears the set s, so that it contains no CPUs.
func (s *CPUSet) Zero() {
	// Assigning the zero value resets every mask word to 0.
	*s = CPUSet{}
}
// cpuBitsIndex returns the index of the CPUSet word that holds the bit for
// cpu, i.e. cpu / _NCPUBITS. Go integer division truncates toward zero, so
// the result is 0 or negative for a negative cpu.
func cpuBitsIndex(cpu int) int {
return cpu / _NCPUBITS
}
// cpuBitsMask returns a mask with a single bit set: the position of cpu
// within its word, i.e. bit (cpu mod _NCPUBITS) computed on the unsigned
// value of cpu.
func cpuBitsMask(cpu int) cpuMask {
	bit := uint(cpu) % _NCPUBITS
	return cpuMask(1) << bit
}
// Set adds cpu to the set s. It is a no-op when cpu is negative or lies
// beyond the capacity of the set.
func (s *CPUSet) Set(cpu int) {
	// Reject negative values explicitly. Go's integer division truncates
	// toward zero, so cpuBitsIndex would map -1..-(_NCPUBITS-1) to word 0
	// (setting an unrelated bit) and anything lower to a negative index
	// (an out-of-range panic).
	if cpu < 0 {
		return
	}
	if i := cpuBitsIndex(cpu); i < len(s) {
		s[i] |= cpuBitsMask(cpu)
	}
}
// Clear removes cpu from the set s. It is a no-op when cpu is negative or
// lies beyond the capacity of the set.
func (s *CPUSet) Clear(cpu int) {
	// Reject negative values explicitly. Go's integer division truncates
	// toward zero, so cpuBitsIndex would map -1..-(_NCPUBITS-1) to word 0
	// (clearing an unrelated bit) and anything lower to a negative index
	// (an out-of-range panic).
	if cpu < 0 {
		return
	}
	if i := cpuBitsIndex(cpu); i < len(s) {
		s[i] &^= cpuBitsMask(cpu)
	}
}
// IsSet reports whether cpu is in the set s. It returns false when cpu is
// negative or lies beyond the capacity of the set.
func (s *CPUSet) IsSet(cpu int) bool {
	// Reject negative values explicitly. Go's integer division truncates
	// toward zero, so cpuBitsIndex would map -1..-(_NCPUBITS-1) to word 0
	// (testing an unrelated bit) and anything lower to a negative index
	// (an out-of-range panic).
	if cpu < 0 {
		return false
	}
	i := cpuBitsIndex(cpu)
	return i < len(s) && s[i]&cpuBitsMask(cpu) != 0
}
// Count returns the number of CPUs in the set s, i.e. the total number of
// one bits across all of its mask words.
func (s *CPUSet) Count() int {
	total := 0
	for i := range s {
		total += bits.OnesCount64(uint64(s[i]))
	}
	return total
}
================================================
FILE: gossh/sys/unix/aliases.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix
import "syscall"
// Aliases for types declared in package syscall, so that values of these
// types can be used interchangeably with their syscall counterparts.
type (
	// Signal is an alias for syscall.Signal.
	Signal = syscall.Signal
	// Errno is an alias for syscall.Errno.
	Errno = syscall.Errno
	// SysProcAttr is an alias for syscall.SysProcAttr.
	SysProcAttr = syscall.SysProcAttr
)
================================================
FILE: gossh/sys/unix/asm_aix_ppc64.s
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for ppc64, AIX are implemented in runtime/syscall_aix.go
//
// syscall6 and rawSyscall6 are NOSPLIT trampolines with no local frame and
// 88 bytes of arguments and results. Each jumps directly to the function of
// the same name in package syscall, which reads the caller's argument frame.
TEXT ·syscall6(SB),NOSPLIT,$0-88
JMP syscall·syscall6(SB)
TEXT ·rawSyscall6(SB),NOSPLIT,$0-88
JMP syscall·rawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_386.s
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (freebsd || netbsd || openbsd) && gc
#include "textflag.h"
// System call support for 386 BSD
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that jumps
// to the function of the same name in package syscall. The argument sizes
// are in 4-byte words: 28 bytes for the 4-argument/3-result forms and 40
// bytes for the 7-argument/3-result forms; Syscall9 declares 52 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-28
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-40
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-52
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_amd64.s
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && gc
#include "textflag.h"
// System call support for AMD64 BSD
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that jumps
// to the function of the same name in package syscall. The argument sizes
// are in 8-byte words: 56 bytes for the 4-argument/3-result forms and 80
// bytes for the 7-argument/3-result forms; Syscall9 declares 104 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_arm.s
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (freebsd || netbsd || openbsd) && gc
#include "textflag.h"
// System call support for ARM BSD
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that
// branches (B, no link) to the function of the same name in package syscall.
// The argument sizes are in 4-byte words: 28 bytes for the
// 4-argument/3-result forms and 40 bytes for the 7-argument/3-result forms;
// Syscall9 declares 52 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-28
B syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-40
B syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-52
B syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
B syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
B syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_arm64.s
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
#include "textflag.h"
// System call support for ARM64 BSD
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that jumps
// to the function of the same name in package syscall. The argument sizes
// are in 8-byte words: 56 bytes for the 4-argument/3-result forms and 80
// bytes for the 7-argument/3-result forms; Syscall9 declares 104 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_ppc64.s
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that jumps
// to the function of the same name in package syscall. The argument sizes
// are in 8-byte words: 56 bytes for the 4-argument/3-result forms and 80
// bytes for the 7-argument/3-result forms; Syscall9 declares 104 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_bsd_riscv64.s
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
#include "textflag.h"
// System call support for RISCV64 BSD
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Each symbol below is a NOSPLIT trampoline with no local frame that jumps
// to the function of the same name in package syscall. The argument sizes
// are in 8-byte words: 56 bytes for the 4-argument/3-result forms and 80
// bytes for the 7-argument/3-result forms; Syscall9 declares 104 bytes.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_linux_386.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for 386, Linux
//
// See ../runtime/sys_linux_386.s for the reason why we always use int 0x80
// instead of the glibc-specific "CALL 0x10(GS)".
#define INVOKE_SYSCALL INT $0x80
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 jump straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-28
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-40
JMP syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call via INVOKE_SYSCALL and
// returns only r1 and r2; there is no error result. The trap number goes in
// AX and a1..a3 in BX, CX, DX; SI and DI are zeroed. AX and DX are stored
// to r1 and r2. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-24
CALL runtime·entersyscall(SB)
MOVL trap+0(FP), AX // syscall entry
MOVL a1+4(FP), BX
MOVL a2+8(FP), CX
MOVL a3+12(FP), DX
MOVL $0, SI
MOVL $0, DI
INVOKE_SYSCALL
MOVL AX, r1+16(FP)
MOVL DX, r2+20(FP)
CALL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 jump straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError uses the same register setup and result handling as
// SyscallNoError but does not call runtime·entersyscall or
// runtime·exitsyscall around the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-24
MOVL trap+0(FP), AX // syscall entry
MOVL a1+4(FP), BX
MOVL a2+8(FP), CX
MOVL a3+12(FP), DX
MOVL $0, SI
MOVL $0, DI
INVOKE_SYSCALL
MOVL AX, r1+16(FP)
MOVL DX, r2+20(FP)
RET
// socketcall, rawsocketcall and seek jump straight to the unexported
// functions of the same name in package syscall.
TEXT ·socketcall(SB),NOSPLIT,$0-36
JMP syscall·socketcall(SB)
TEXT ·rawsocketcall(SB),NOSPLIT,$0-36
JMP syscall·rawsocketcall(SB)
TEXT ·seek(SB),NOSPLIT,$0-28
JMP syscall·seek(SB)
================================================
FILE: gossh/sys/unix/asm_linux_amd64.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for AMD64, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 jump straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call and returns only r1
// and r2; there is no error result. a1..a3 go in DI, SI, DX; R10, R8 and R9
// are zeroed; the trap number goes in AX. AX and DX are stored to r1 and
// r2. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
CALL runtime·entersyscall(SB)
MOVQ a1+8(FP), DI
MOVQ a2+16(FP), SI
MOVQ a3+24(FP), DX
MOVQ $0, R10
MOVQ $0, R8
MOVQ $0, R9
MOVQ trap+0(FP), AX // syscall entry
SYSCALL
MOVQ AX, r1+32(FP)
MOVQ DX, r2+40(FP)
CALL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 jump straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError uses the same register setup and result handling as
// SyscallNoError but does not call runtime·entersyscall or
// runtime·exitsyscall around the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
MOVQ a1+8(FP), DI
MOVQ a2+16(FP), SI
MOVQ a3+24(FP), DX
MOVQ $0, R10
MOVQ $0, R8
MOVQ $0, R9
MOVQ trap+0(FP), AX // syscall entry
SYSCALL
MOVQ AX, r1+32(FP)
MOVQ DX, r2+40(FP)
RET
// gettimeofday jumps straight to the unexported function of the same name
// in package syscall.
TEXT ·gettimeofday(SB),NOSPLIT,$0-16
JMP syscall·gettimeofday(SB)
================================================
FILE: gossh/sys/unix/asm_linux_arm.s
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for arm, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 branch straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-28
B syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-40
B syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call via SWI $0 and returns
// only r1 and r2; there is no error result. The trap number goes in R7 and
// a1..a3 in R0..R2; R3..R5 are zeroed. R0 is stored to r1 and r2 is always
// set to 0. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-24
BL runtime·entersyscall(SB)
MOVW trap+0(FP), R7
MOVW a1+4(FP), R0
MOVW a2+8(FP), R1
MOVW a3+12(FP), R2
MOVW $0, R3
MOVW $0, R4
MOVW $0, R5
SWI $0
MOVW R0, r1+16(FP)
MOVW $0, R0
MOVW R0, r2+20(FP)
BL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 branch straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
B syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
B syscall·RawSyscall6(SB)
// RawSyscallNoError issues the same trap as SyscallNoError without calling
// runtime·entersyscall or runtime·exitsyscall. Unlike SyscallNoError it
// does not clear R3..R5 before the trap. R0 is stored to r1 and r2 is
// always set to 0.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-24
MOVW trap+0(FP), R7 // syscall entry
MOVW a1+4(FP), R0
MOVW a2+8(FP), R1
MOVW a3+12(FP), R2
SWI $0
MOVW R0, r1+16(FP)
MOVW $0, R0
MOVW R0, r2+20(FP)
RET
// seek branches straight to the unexported function of the same name in
// package syscall.
TEXT ·seek(SB),NOSPLIT,$0-28
B syscall·seek(SB)
================================================
FILE: gossh/sys/unix/asm_linux_arm64.s
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && arm64 && gc
#include "textflag.h"
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 branch straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-56
B syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
B syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call via SVC and returns
// only r1 and r2; there is no error result. a1..a3 go in R0..R2; R3..R5 are
// zeroed; the trap number goes in R8. R0 and R1 are stored to r1 and r2.
// The call is bracketed by runtime·entersyscall and runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
BL runtime·entersyscall(SB)
MOVD a1+8(FP), R0
MOVD a2+16(FP), R1
MOVD a3+24(FP), R2
MOVD $0, R3
MOVD $0, R4
MOVD $0, R5
MOVD trap+0(FP), R8 // syscall entry
SVC
MOVD R0, r1+32(FP) // r1
MOVD R1, r2+40(FP) // r2
BL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 branch straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
B syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
B syscall·RawSyscall6(SB)
// RawSyscallNoError uses the same register setup and result handling as
// SyscallNoError but does not call runtime·entersyscall or
// runtime·exitsyscall around the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
MOVD a1+8(FP), R0
MOVD a2+16(FP), R1
MOVD a3+24(FP), R2
MOVD $0, R3
MOVD $0, R4
MOVD $0, R5
MOVD trap+0(FP), R8 // syscall entry
SVC
MOVD R0, r1+32(FP)
MOVD R1, r2+40(FP)
RET
================================================
FILE: gossh/sys/unix/asm_linux_loong64.s
================================================
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && loong64 && gc
#include "textflag.h"
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 jump straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call and returns only r1
// and r2; there is no error result. a1..a3 go in R4..R6; R7..R9 are loaded
// from R0; the trap number goes in R11. R4 is stored to r1 and R0 to r2.
// The call is bracketed by runtime·entersyscall and runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
JAL runtime·entersyscall(SB)
MOVV a1+8(FP), R4
MOVV a2+16(FP), R5
MOVV a3+24(FP), R6
MOVV R0, R7
MOVV R0, R8
MOVV R0, R9
MOVV trap+0(FP), R11 // syscall entry
SYSCALL
MOVV R4, r1+32(FP)
MOVV R0, r2+40(FP) // r2 is not used. Always set to 0
JAL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 jump straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError uses the same register setup and result handling as
// SyscallNoError but does not call runtime·entersyscall or
// runtime·exitsyscall around the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
MOVV a1+8(FP), R4
MOVV a2+16(FP), R5
MOVV a3+24(FP), R6
MOVV R0, R7
MOVV R0, R8
MOVV R0, R9
MOVV trap+0(FP), R11 // syscall entry
SYSCALL
MOVV R4, r1+32(FP)
MOVV R0, r2+40(FP) // r2 is not used. Always set to 0
RET
================================================
FILE: gossh/sys/unix/asm_linux_mips64x.s
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (mips64 || mips64le) && gc
#include "textflag.h"
//
// System calls for mips64, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall and Syscall6 jump straight to package syscall's implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
// SyscallNoError issues a three-argument system call and returns only r1
// and r2; there is no error result. a1..a3 go in R4..R6; R7..R9 are loaded
// from R0; the trap number goes in R2. R2 and R3 are stored to r1 and r2.
// The call is bracketed by runtime·entersyscall and runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
JAL runtime·entersyscall(SB)
MOVV a1+8(FP), R4
MOVV a2+16(FP), R5
MOVV a3+24(FP), R6
MOVV R0, R7
MOVV R0, R8
MOVV R0, R9
MOVV trap+0(FP), R2 // syscall entry
SYSCALL
MOVV R2, r1+32(FP)
MOVV R3, r2+40(FP)
JAL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 jump straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError uses the same register setup and result handling as
// SyscallNoError but does not call runtime·entersyscall or
// runtime·exitsyscall around the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
MOVV a1+8(FP), R4
MOVV a2+16(FP), R5
MOVV a3+24(FP), R6
MOVV R0, R7
MOVV R0, R8
MOVV R0, R9
MOVV trap+0(FP), R2 // syscall entry
SYSCALL
MOVV R2, r1+32(FP)
MOVV R3, r2+40(FP)
RET
================================================
FILE: gossh/sys/unix/asm_linux_mipsx.s
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (mips || mipsle) && gc
#include "textflag.h"
//
// System calls for mips, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall, Syscall6 and Syscall9 jump straight to package syscall's
// implementations.
TEXT ·Syscall(SB),NOSPLIT,$0-28
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-40
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-52
JMP syscall·Syscall9(SB)
// SyscallNoError issues a three-argument system call and returns only r1
// and r2; there is no error result. a1..a3 go in R4..R6; R7 is loaded from
// R0; the trap number goes in R2. R2 and R3 are stored to r1 and r2. The
// call is bracketed by runtime·entersyscall and runtime·exitsyscall.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-24
JAL runtime·entersyscall(SB)
MOVW a1+4(FP), R4
MOVW a2+8(FP), R5
MOVW a3+12(FP), R6
MOVW R0, R7
MOVW trap+0(FP), R2 // syscall entry
SYSCALL
MOVW R2, r1+16(FP) // r1
MOVW R3, r2+20(FP) // r2
JAL runtime·exitsyscall(SB)
RET
// RawSyscall and RawSyscall6 jump straight to package syscall's
// implementations.
TEXT ·RawSyscall(SB),NOSPLIT,$0-28
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-40
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError issues the same trap as SyscallNoError without calling
// runtime·entersyscall or runtime·exitsyscall. Unlike SyscallNoError it
// does not write R7 before the trap.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-24
MOVW a1+4(FP), R4
MOVW a2+8(FP), R5
MOVW a3+12(FP), R6
MOVW trap+0(FP), R2 // syscall entry
SYSCALL
MOVW R2, r1+16(FP)
MOVW R3, r2+20(FP)
RET
================================================
FILE: gossh/sys/unix/asm_linux_ppc64x.s
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (ppc64 || ppc64le) && gc
#include "textflag.h"
//
// System calls for ppc64, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// SyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of R3 and R4 as r1 and r2. No error value is
// produced. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall. Frame: no locals, 48 bytes of arguments and results.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
BL runtime·entersyscall(SB)
// Load the three caller-supplied arguments.
MOVD a1+8(FP), R3
MOVD a2+16(FP), R4
MOVD a3+24(FP), R5
// Fill the remaining argument registers from R0 (expected to hold zero here).
MOVD R0, R6
MOVD R0, R7
MOVD R0, R8
MOVD trap+0(FP), R9 // syscall entry
SYSCALL R9
MOVD R3, r1+32(FP)
MOVD R4, r2+40(FP)
BL runtime·exitsyscall(SB)
RET
// RawSyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of R3 and R4 as r1 and r2. Unlike SyscallNoError
// it does not call runtime·entersyscall / runtime·exitsyscall.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
// Load the three caller-supplied arguments.
MOVD a1+8(FP), R3
MOVD a2+16(FP), R4
MOVD a3+24(FP), R5
// Fill the remaining argument registers from R0 (expected to hold zero here).
MOVD R0, R6
MOVD R0, R7
MOVD R0, R8
MOVD trap+0(FP), R9 // syscall entry
SYSCALL R9
MOVD R3, r1+32(FP)
MOVD R4, r2+40(FP)
RET
================================================
FILE: gossh/sys/unix/asm_linux_riscv64.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build riscv64 && gc
#include "textflag.h"
//
// System calls for linux/riscv64.
//
// Where available, just jump to package syscall's implementation of
// these functions.
// Syscall jumps straight to syscall·Syscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
// Syscall6 jumps straight to syscall·Syscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
// SyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of A0 and A1 as r1 and r2. No error value is
// produced. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall. Frame: no locals, 48 bytes of arguments and results.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
CALL runtime·entersyscall(SB)
// Arguments go in A0-A2, the call number in A7.
MOV a1+8(FP), A0
MOV a2+16(FP), A1
MOV a3+24(FP), A2
MOV trap+0(FP), A7 // syscall entry
ECALL
MOV A0, r1+32(FP) // r1
MOV A1, r2+40(FP) // r2
CALL runtime·exitsyscall(SB)
RET
// RawSyscall jumps straight to syscall·RawSyscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
// RawSyscall6 jumps straight to syscall·RawSyscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
// RawSyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of A0 and A1 as r1 and r2. Unlike SyscallNoError
// it does not call runtime·entersyscall / runtime·exitsyscall.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
// Arguments go in A0-A2, the call number in A7.
MOV a1+8(FP), A0
MOV a2+16(FP), A1
MOV a3+24(FP), A2
MOV trap+0(FP), A7 // syscall entry
ECALL
MOV A0, r1+32(FP)
MOV A1, r2+40(FP)
RET
================================================
FILE: gossh/sys/unix/asm_linux_s390x.s
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && s390x && gc
#include "textflag.h"
//
// System calls for s390x, Linux
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall branches straight to syscall·Syscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·Syscall(SB),NOSPLIT,$0-56
BR syscall·Syscall(SB)
// Syscall6 branches straight to syscall·Syscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·Syscall6(SB),NOSPLIT,$0-80
BR syscall·Syscall6(SB)
// SyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of R2 and R3 as r1 and r2. No error value is
// produced. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall. Frame: no locals, 48 bytes of arguments and results.
TEXT ·SyscallNoError(SB),NOSPLIT,$0-48
BL runtime·entersyscall(SB)
// Load the three caller-supplied arguments.
MOVD a1+8(FP), R2
MOVD a2+16(FP), R3
MOVD a3+24(FP), R4
// Zero the remaining argument registers.
MOVD $0, R5
MOVD $0, R6
MOVD $0, R7
MOVD trap+0(FP), R1 // syscall entry
SYSCALL
MOVD R2, r1+32(FP)
MOVD R3, r2+40(FP)
BL runtime·exitsyscall(SB)
RET
// RawSyscall branches straight to syscall·RawSyscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
BR syscall·RawSyscall(SB)
// RawSyscall6 branches straight to syscall·RawSyscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
BR syscall·RawSyscall6(SB)
// RawSyscallNoError issues the system call numbered trap with arguments a1-a3
// and returns the contents of R2 and R3 as r1 and r2. Unlike SyscallNoError
// it does not call runtime·entersyscall / runtime·exitsyscall.
TEXT ·RawSyscallNoError(SB),NOSPLIT,$0-48
// Load the three caller-supplied arguments.
MOVD a1+8(FP), R2
MOVD a2+16(FP), R3
MOVD a3+24(FP), R4
// Zero the remaining argument registers.
MOVD $0, R5
MOVD $0, R6
MOVD $0, R7
MOVD trap+0(FP), R1 // syscall entry
SYSCALL
MOVD R2, r1+32(FP)
MOVD R3, r2+40(FP)
RET
================================================
FILE: gossh/sys/unix/asm_openbsd_mips64.s
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System call support for mips64, OpenBSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
// Syscall jumps straight to syscall·Syscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
// Syscall6 jumps straight to syscall·Syscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
// Syscall9 jumps straight to syscall·Syscall9
// (no local frame, 104 bytes of arguments and results).
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
// RawSyscall jumps straight to syscall·RawSyscall
// (no local frame, 56 bytes of arguments and results).
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
// RawSyscall6 jumps straight to syscall·RawSyscall6
// (no local frame, 80 bytes of arguments and results).
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)
================================================
FILE: gossh/sys/unix/asm_solaris_amd64.s
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
#include "textflag.h"
//
// System calls for amd64, Solaris are implemented in runtime/syscall_solaris.go
//
// sysvicall6 jumps straight to syscall·sysvicall6
// (no local frame, 88 bytes of arguments and results).
TEXT ·sysvicall6(SB),NOSPLIT,$0-88
JMP syscall·sysvicall6(SB)
// rawSysvicall6 jumps straight to syscall·rawSysvicall6
// (no local frame, 88 bytes of arguments and results).
TEXT ·rawSysvicall6(SB),NOSPLIT,$0-88
JMP syscall·rawSysvicall6(SB)
================================================
FILE: gossh/sys/unix/asm_zos_s390x.s
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build zos && s390x && gc
#include "textflag.h"
// Address-expression macros used by the routines in this file. Each one
// expands to a fixed displacement from a base register: PSALAA is relative
// to R0, the others are relative to the register passed as x.
#define PSALAA 1208(R0)
#define GTAB64(x) 80(x)
#define LCA64(x) 88(x)
#define CAA(x) 8(x)
#define EDCHPXV(x) 1016(x) // in the CAA
#define SAVSTACK_ASYNC(x) 336(x) // in the LCA
// SS_*, where x=SAVSTACK_ASYNC
#define SS_LE(x) 0(x)
#define SS_GO(x) 8(x)
#define SS_ERRNO(x) 16(x)
#define SS_ERRNOJR(x) 20(x)
// LE_CALL emits the two raw instruction bytes 0x0D 0x76.
#define LE_CALL BYTE $0x0D; BYTE $0x76; // BL R7, R6
// clearErrno stores zero (as a 64-bit MOVD) at the address that
// addrerrno<> leaves in R3. It takes no arguments and returns nothing.
TEXT ·clearErrno(SB),NOSPLIT,$0-0
BL addrerrno<>(SB)
MOVD $0, 0(R3)
RET
// Returns the address of errno in R3.
//
// addrerrno<> is file-local. It loads the function descriptor found at offset
// 0x156*16 in the table reached through LCA -> CAA -> EDCHPXV into R5/R6,
// switches to the saved LE stack and calls it with LE_CALL. R4, R5, R6, R8 and
// R9 are overwritten here (in addition to whatever the callee changes), and R0
// is reset to zero before returning.
TEXT addrerrno<>(SB),NOSPLIT|NOFRAME,$0-0
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get __errno FuncDesc.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
ADD $(0x156*16), R9
LMG 0(R9), R5, R6
// Switch to saved LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Call __errno function.
LE_CALL
NOPH
// Switch back to Go stack.
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
RET
// syscall_syscall calls the library function selected by trap with arguments
// a1-a3 (passed in R1-R3) and returns the value left in R3 as r1. r2 is always
// zero. If the low 32 bits of the result equal -1, err is set to the 32-bit
// value read through addrerrno<>; otherwise err is zero. The call is bracketed
// by runtime·entersyscall and runtime·exitsyscall.
TEXT ·syscall_syscall(SB),NOSPLIT,$0-56
BL runtime·entersyscall(SB)
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+32(FP)
MOVD R0, r2+40(FP)
MOVD R0, err+48(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+48(FP)
done:
BL runtime·exitsyscall(SB)
RET
// syscall_rawsyscall calls the library function selected by trap with
// arguments a1-a3 (passed in R1-R3) and returns the value left in R3 as r1.
// r2 is always zero. If the low 32 bits of the result equal -1, err is set to
// the 32-bit value read through addrerrno<>; otherwise err is zero. Unlike
// syscall_syscall it does not call runtime·entersyscall / exitsyscall.
TEXT ·syscall_rawsyscall(SB),NOSPLIT,$0-56
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+32(FP)
MOVD R0, r2+40(FP)
MOVD R0, err+48(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+48(FP)
done:
RET
// syscall_syscall6 calls the library function selected by trap with six
// arguments: a1-a3 go in R1-R3, a4-a6 are stored into the parameter list at
// offsets 2176+24, +32 and +40 from the LE stack pointer in R4. The value left
// in R3 is returned as r1 and r2 is always zero. If the low 32 bits of the
// result equal -1, err is set to the 32-bit value read through addrerrno<>;
// otherwise err is zero. The call is bracketed by runtime·entersyscall and
// runtime·exitsyscall.
TEXT ·syscall_syscall6(SB),NOSPLIT,$0-80
BL runtime·entersyscall(SB)
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Fill in parameter list.
MOVD a4+32(FP), R12
MOVD R12, (2176+24)(R4)
MOVD a5+40(FP), R12
MOVD R12, (2176+32)(R4)
MOVD a6+48(FP), R12
MOVD R12, (2176+40)(R4)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+56(FP)
MOVD R0, r2+64(FP)
MOVD R0, err+72(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+72(FP)
done:
BL runtime·exitsyscall(SB)
RET
// syscall_rawsyscall6 calls the library function selected by trap with six
// arguments: a1-a3 go in R1-R3, a4-a6 are stored into the parameter list at
// offsets 2176+24, +32 and +40 from the LE stack pointer in R4. The value left
// in R3 is returned as r1 and r2 is always zero. If the low 32 bits of the
// result equal -1, err is set to the 32-bit value read through addrerrno<>;
// otherwise err is zero. Unlike syscall_syscall6 it does not call
// runtime·entersyscall / runtime·exitsyscall.
TEXT ·syscall_rawsyscall6(SB),NOSPLIT,$0-80
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Fill in parameter list.
MOVD a4+32(FP), R12
MOVD R12, (2176+24)(R4)
MOVD a5+40(FP), R12
MOVD R12, (2176+32)(R4)
MOVD a6+48(FP), R12
MOVD R12, (2176+40)(R4)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+56(FP)
MOVD R0, r2+64(FP)
MOVD R0, err+72(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
// Fetch errno through the file-local helper, as the sibling routines do.
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+72(FP)
done:
RET
// syscall_syscall9 calls the library function selected by trap with nine
// arguments: a1-a3 go in R1-R3, a4-a9 are stored into the parameter list at
// offsets 2176+24 through 2176+64 from the LE stack pointer in R4. The value
// left in R3 is returned as r1 and r2 is always zero. If the low 32 bits of
// the result equal -1, err is set to the 32-bit value read through
// addrerrno<>; otherwise err is zero. The call is bracketed by
// runtime·entersyscall and runtime·exitsyscall.
TEXT ·syscall_syscall9(SB),NOSPLIT,$0
BL runtime·entersyscall(SB)
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Fill in parameter list.
MOVD a4+32(FP), R12
MOVD R12, (2176+24)(R4)
MOVD a5+40(FP), R12
MOVD R12, (2176+32)(R4)
MOVD a6+48(FP), R12
MOVD R12, (2176+40)(R4)
MOVD a7+56(FP), R12
MOVD R12, (2176+48)(R4)
MOVD a8+64(FP), R12
MOVD R12, (2176+56)(R4)
MOVD a9+72(FP), R12
MOVD R12, (2176+64)(R4)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+80(FP)
MOVD R0, r2+88(FP)
MOVD R0, err+96(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+96(FP)
done:
BL runtime·exitsyscall(SB)
RET
// syscall_rawsyscall9 calls the library function selected by trap with nine
// arguments: a1-a3 go in R1-R3, a4-a9 are stored into the parameter list at
// offsets 2176+24 through 2176+64 from the LE stack pointer in R4. The value
// left in R3 is returned as r1 and r2 is always zero. If the low 32 bits of
// the result equal -1, err is set to the 32-bit value read through
// addrerrno<>; otherwise err is zero. Unlike syscall_syscall9 it does not call
// runtime·entersyscall / runtime·exitsyscall.
TEXT ·syscall_rawsyscall9(SB),NOSPLIT,$0
MOVD a1+8(FP), R1
MOVD a2+16(FP), R2
MOVD a3+24(FP), R3
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get function.
MOVD CAA(R8), R9
MOVD EDCHPXV(R9), R9
MOVD trap+0(FP), R5
// Each table entry is 16 bytes wide, so scale trap by 16.
SLD $4, R5
ADD R5, R9
LMG 0(R9), R5, R6
// Restore LE stack.
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R4
MOVD $0, 0(R9)
// Fill in parameter list.
MOVD a4+32(FP), R12
MOVD R12, (2176+24)(R4)
MOVD a5+40(FP), R12
MOVD R12, (2176+32)(R4)
MOVD a6+48(FP), R12
MOVD R12, (2176+40)(R4)
MOVD a7+56(FP), R12
MOVD R12, (2176+48)(R4)
MOVD a8+64(FP), R12
MOVD R12, (2176+56)(R4)
MOVD a9+72(FP), R12
MOVD R12, (2176+64)(R4)
// Call function.
LE_CALL
NOPH
XOR R0, R0 // Restore R0 to $0.
MOVD R4, 0(R9) // Save stack pointer.
MOVD R3, r1+80(FP)
MOVD R0, r2+88(FP)
MOVD R0, err+96(FP)
// Only consult errno when the 32-bit result is -1.
MOVW R3, R4
CMP R4, $-1
BNE done
BL addrerrno<>(SB)
MOVWZ 0(R3), R3
MOVD R3, err+96(FP)
done:
RET
// func svcCall(fnptr unsafe.Pointer, argv *unsafe.Pointer, dsa *uint64)
//
// svcCall saves g (via runtime·save_g) and stores the Go stack pointer (R15)
// in the slot addressed through SAVSTACK_ASYNC, then loads argv into R1, dsa
// into g and fnptr into R15 and branches to it using two raw instruction
// bytes. On return it reloads g (via runtime·load_g) and restores R15 from
// the same slot. It has no result.
TEXT ·svcCall(SB),NOSPLIT,$0
BL runtime·save_g(SB) // Save g and stack pointer
MOVW PSALAA, R8
MOVD LCA64(R8), R8
MOVD SAVSTACK_ASYNC(R8), R9
MOVD R15, 0(R9)
MOVD argv+8(FP), R1 // Move function arguments into registers
MOVD dsa+16(FP), g
MOVD fnptr+0(FP), R15
BYTE $0x0D // Branch to function
BYTE $0xEF
BL runtime·load_g(SB) // Restore g and stack pointer
MOVW PSALAA, R8
MOVD LCA64(R8), R8
MOVD SAVSTACK_ASYNC(R8), R9
MOVD 0(R9), R15
RET
// func svcLoad(name *byte) unsafe.Pointer
//
// svcLoad issues SVC 08 (LOAD) with name in R0, 0x80000000 in R1 and R15
// zeroed. If the SVC leaves a non-zero 32-bit value in R15 the result is 0.
// Otherwise the result is the value returned in R0 with its lowest bit
// cleared, but only when that bit was actually set. When the bit was already
// clear, SVC 09 (DELETE) is issued for the entry point and 0 is returned.
// R15 (the Go stack pointer) is saved in R2 around each SVC, and R0 is reset
// to zero before returning.
TEXT ·svcLoad(SB),NOSPLIT,$0
MOVD R15, R2 // Save go stack pointer
MOVD name+0(FP), R0 // Move SVC args into registers
MOVD $0x80000000, R1
MOVD $0, R15
BYTE $0x0A // SVC 08 LOAD
BYTE $0x08
MOVW R15, R3 // Save return code from SVC
MOVD R2, R15 // Restore go stack pointer
CMP R3, $0 // Check SVC return code
BNE error
MOVD $-2, R3 // Reset last bit of entry point to zero
AND R0, R3
MOVD R3, addr+8(FP) // Return entry point returned by SVC
CMP R0, R3 // Check if last bit of entry point was set
BNE done
MOVD R15, R2 // Save go stack pointer
MOVD $0, R15 // Move SVC args into registers (entry point still in r0 from SVC 08)
BYTE $0x0A // SVC 09 DELETE
BYTE $0x09
MOVD R2, R15 // Restore go stack pointer
error:
MOVD $0, addr+8(FP) // Return 0 on failure
done:
XOR R0, R0 // Reset r0 to 0
RET
// func svcUnload(name *byte, fnptr unsafe.Pointer) int64
//
// svcUnload issues SVC 09 with name in R0 and fnptr in R15 and returns the
// value the SVC leaves in R15. R15 (the Go stack pointer) is saved in R2
// around the SVC, and R0 is reset to zero afterwards.
//
// Frame layout per the signature above: name at 0(FP), fnptr at 8(FP) and the
// int64 result at 16(FP). The result must be stored at offset 16; offset 0 is
// the name argument slot, so a store there never reaches the caller.
TEXT ·svcUnload(SB),NOSPLIT,$0
MOVD R15, R2 // Save go stack pointer
MOVD name+0(FP), R0 // Move SVC args into registers
MOVD fnptr+8(FP), R15
BYTE $0x0A // SVC 09
BYTE $0x09
XOR R0, R0 // Reset r0 to 0
MOVD R15, R1 // Save SVC return code
MOVD R2, R15 // Restore go stack pointer
MOVD R1, ret+16(FP) // Return SVC return code
RET
// func gettid() uint64
//
// gettid returns the 64-bit value stored at offset 0x3D0 of the CAA, which is
// reached through PSALAA -> LCA64 -> CAA. R8 and R9 are overwritten.
TEXT ·gettid(SB), NOSPLIT, $0
// Get library control area (LCA).
MOVW PSALAA, R8
MOVD LCA64(R8), R8
// Get CEECAATHDID
MOVD CAA(R8), R9
MOVD 0x3D0(R9), R9
MOVD R9, ret+0(FP)
RET
================================================
FILE: gossh/sys/unix/bluetooth_linux.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Bluetooth sockets and messages
package unix
// Bluetooth protocol identifiers (BTPROTO_*), numbered consecutively from 0.
const (
BTPROTO_L2CAP = 0
BTPROTO_HCI = 1
BTPROTO_SCO = 2
BTPROTO_RFCOMM = 3
BTPROTO_BNEP = 4
BTPROTO_CMTP = 5
BTPROTO_HIDP = 6
BTPROTO_AVDTP = 7
)
// HCI channel identifiers (HCI_CHANNEL_*), numbered consecutively from 0.
const (
HCI_CHANNEL_RAW = 0
HCI_CHANNEL_USER = 1
HCI_CHANNEL_MONITOR = 2
HCI_CHANNEL_CONTROL = 3
HCI_CHANNEL_LOGGING = 4
)
// Socket option levels (SOL_*) for Bluetooth sockets.
const (
SOL_BLUETOOTH = 0x112
SOL_HCI = 0x0
SOL_L2CAP = 0x6
SOL_RFCOMM = 0x12
SOL_SCO = 0x11
)
================================================
FILE: gossh/sys/unix/cap_freebsd.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build freebsd
package unix
import (
"errors"
"fmt"
)
// Go implementation of C mostly found in /usr/src/sys/kern/subr_capability.c
const (
// This is the version of CapRights this package understands. See C implementation for parallels.
capRightsGoVersion = CAP_RIGHTS_VERSION_00
// capArSizeMin is the smallest rights-array size accepted by the
// CapRights* functions: the size for version CAP_RIGHTS_VERSION_00.
capArSizeMin = CAP_RIGHTS_VERSION_00 + 2
// capArSizeMax is the largest accepted size: the size for the newest
// version this package understands (capRightsGoVersion).
capArSizeMax = capRightsGoVersion + 2
)
var (
	// bit2idx maps the 5-bit index field of a capability right (see
	// capidxbit) to its slot in CapRights.Rights. Only the one-hot field
	// values 1, 2, 4, 8 and 16 name a slot; every other value maps to -1.
	bit2idx = []int{
		-1, 0, 1, -1, 2, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1,
		4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	}
)

// capidxbit extracts the 5-bit index field stored in bits 57-61 of right.
func capidxbit(right uint64) int {
	return int((right >> 57) & 0x1f)
}

// rightToIndex returns the CapRights.Rights slot that right belongs to.
//
// It returns an error (and the index -2) when the index field of right does
// not name a slot, i.e. when it is zero or has more than one bit set. Callers
// use the result directly as an array index, so the -1 sentinel stored in
// bit2idx must never be handed back with a nil error.
func rightToIndex(right uint64) (int, error) {
	idx := capidxbit(right)
	if idx < 0 || idx >= len(bit2idx) {
		return -2, fmt.Errorf("index for right 0x%x out of range", right)
	}
	slot := bit2idx[idx]
	if slot < 0 {
		return -2, fmt.Errorf("index for right 0x%x out of range", right)
	}
	return slot, nil
}
// caprver extracts the version field held in the two topmost bits (62-63)
// of a capability right word.
func caprver(right uint64) int {
	version := right >> 62
	return int(version)
}
// capver returns the version encoded in the first word of rights
// (see caprver).
func capver(rights *CapRights) int {
return caprver(rights.Rights[0])
}
// caparsize returns the rights-array size implied by the version of rights:
// the version number plus two.
func caparsize(rights *CapRights) int {
return capver(rights) + 2
}
// CapRightsSet adds every right listed in setrights to rights.
//
// It follows cap_rights_vset() from the C implementation. rights must carry
// version CAP_RIGHTS_VERSION_00 and an array size within
// [capArSizeMin, capArSizeMax]. Each requested right must carry the same
// version and the same index field as the word it is merged into. The first
// violation aborts the call with an error; rights merged before that point
// stay set.
func CapRightsSet(rights *CapRights, setrights []uint64) error {
	if version := capver(rights); version != CAP_RIGHTS_VERSION_00 {
		return fmt.Errorf("bad rights version %d", version)
	}
	size := caparsize(rights)
	if size < capArSizeMin || size > capArSizeMax {
		return errors.New("bad rights size")
	}
	for k := range setrights {
		want := setrights[k]
		if caprver(want) != CAP_RIGHTS_VERSION_00 {
			return errors.New("bad right version")
		}
		slot, err := rightToIndex(want)
		switch {
		case err != nil:
			return err
		case slot >= size:
			return errors.New("index overflow")
		}
		word := &rights.Rights[slot]
		wantIdx := capidxbit(want)
		if capidxbit(*word) != wantIdx {
			return errors.New("index mismatch")
		}
		*word |= want
		// Merging must not disturb the index field of the stored word.
		if capidxbit(*word) != wantIdx {
			return errors.New("index mismatch (after assign)")
		}
	}
	return nil
}
// CapRightsClear removes every right listed in clearrights from rights.
//
// It follows cap_rights_vclear() from the C implementation and performs the
// same version, size and index validation as CapRightsSet. Only the low 57
// bits of each right are cleared, so the index and version fields of the
// stored word are left untouched. The first violation aborts the call with an
// error; rights cleared before that point stay cleared.
func CapRightsClear(rights *CapRights, clearrights []uint64) error {
	if version := capver(rights); version != CAP_RIGHTS_VERSION_00 {
		return fmt.Errorf("bad rights version %d", version)
	}
	size := caparsize(rights)
	if size < capArSizeMin || size > capArSizeMax {
		return errors.New("bad rights size")
	}
	for k := range clearrights {
		drop := clearrights[k]
		if caprver(drop) != CAP_RIGHTS_VERSION_00 {
			return errors.New("bad right version")
		}
		slot, err := rightToIndex(drop)
		switch {
		case err != nil:
			return err
		case slot >= size:
			return errors.New("index overflow")
		}
		word := &rights.Rights[slot]
		dropIdx := capidxbit(drop)
		if capidxbit(*word) != dropIdx {
			return errors.New("index mismatch")
		}
		*word &^= drop & 0x01FFFFFFFFFFFFFF
		if capidxbit(*word) != dropIdx {
			return errors.New("index mismatch (after assign)")
		}
	}
	return nil
}
// CapRightsIsSet reports whether every right listed in setrights is present
// in rights.
//
// It follows cap_rights_is_vset() from the C implementation and performs the
// same version, size and index validation as CapRightsSet. A validation
// failure yields (false, error); a missing right yields (false, nil).
func CapRightsIsSet(rights *CapRights, setrights []uint64) (bool, error) {
	if version := capver(rights); version != CAP_RIGHTS_VERSION_00 {
		return false, fmt.Errorf("bad rights version %d", version)
	}
	size := caparsize(rights)
	if size < capArSizeMin || size > capArSizeMax {
		return false, errors.New("bad rights size")
	}
	for k := range setrights {
		want := setrights[k]
		if caprver(want) != CAP_RIGHTS_VERSION_00 {
			return false, errors.New("bad right version")
		}
		slot, err := rightToIndex(want)
		switch {
		case err != nil:
			return false, err
		case slot >= size:
			return false, errors.New("index overflow")
		}
		have := rights.Rights[slot]
		if capidxbit(have) != capidxbit(want) {
			return false, errors.New("index mismatch")
		}
		if have&want != want {
			return false, nil
		}
	}
	return true, nil
}
// capright builds a capability right word: it sets the index marker bit
// 57+idx and ORs in the permission bits given by bit.
func capright(idx uint64, bit uint64) uint64 {
	var marker uint64 = 1 << (57 + idx)
	return marker | bit
}
// CapRightsInit allocates a CapRights value, stamps both words with their
// index markers (and the first word with capRightsGoVersion), and then adds
// the given rights with CapRightsSet. It returns nil and the error from
// CapRightsSet if any right is rejected.
// See man cap_rights_init(3) and rights(4).
func CapRightsInit(rights []uint64) (*CapRights, error) {
	r := new(CapRights)
	r.Rights[0] = capright(0, 0) | (capRightsGoVersion << 62)
	r.Rights[1] = capright(1, 0)
	if err := CapRightsSet(r, rights); err != nil {
		return nil, err
	}
	return r, nil
}
// CapRightsLimit reduces the operations permitted on fd to at most those contained in rights.
// The capability rights on fd can never be increased by CapRightsLimit.
// It converts fd to int and delegates to capRightsLimit, returning its error unchanged.
// See man cap_rights_limit(2) and rights(4).
func CapRightsLimit(fd uintptr, rights *CapRights) error {
return capRightsLimit(int(fd), rights)
}
// CapRightsGet returns the capability rights currently permitted on fd.
// It starts from an empty CapRights (CapRightsInit(nil)) and fills it in via
// capRightsGet using capRightsGoVersion; on any failure it returns nil and
// the error.
// See man cap_rights_get(3) and rights(4).
func CapRightsGet(fd uintptr) (*CapRights, error) {
	r, err := CapRightsInit(nil)
	if err == nil {
		err = capRightsGet(capRightsGoVersion, int(fd), r)
	}
	if err != nil {
		return nil, err
	}
	return r, nil
}
================================================
FILE: gossh/sys/unix/constants.go
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix
// Permission-check mode bits. The names suggest read (R_OK), write (W_OK) and
// execute (X_OK) checks as used with access(2) — confirm against callers.
const (
R_OK = 0x4
W_OK = 0x2
X_OK = 0x1
)
================================================
FILE: gossh/sys/unix/dev_aix_ppc.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix && ppc
// Functions to access/create device major and minor numbers matching the
// encoding used by AIX.
package unix
// Major returns the major component of an AIX (32-bit) device number:
// bits 16-31 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint16(dev >> 16))
}
// Minor returns the minor component of an AIX (32-bit) device number:
// the low 16 bits of dev.
func Minor(dev uint64) uint32 {
	return uint32(uint16(dev))
}
// Mkdev returns an AIX (32-bit) device number generated from the given major
// and minor components. The combination is computed in 32 bits, so major bits
// shifted past bit 31 are discarded before widening.
func Mkdev(major, minor uint32) uint64 {
	combined := major<<16 | minor
	return uint64(combined)
}
================================================
FILE: gossh/sys/unix/dev_aix_ppc64.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix && ppc64
// Functions to access/create device major and minor numbers matching the
// encoding used by AIX.
package unix
// Major returns the major component of an AIX (64-bit) device number:
// bits 32-61 of dev.
func Major(dev uint64) uint32 {
	upper := uint32(dev >> 32)
	return upper & 0x3fffffff
}
// Minor returns the minor component of an AIX (64-bit) device number:
// the low 32 bits of dev.
func Minor(dev uint64) uint32 {
	return uint32(dev)
}
// Mkdev returns an AIX (64-bit) device number generated from the given major
// and minor components: major in the upper word, minor in the lower word, and
// the topmost bit (0x8000000000000000) always set.
func Mkdev(major, minor uint32) uint64 {
	const devno64 uint64 = 1 << 63
	return devno64 | uint64(major)<<32 | uint64(minor)
}
================================================
FILE: gossh/sys/unix/dev_darwin.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used in Darwin's sys/types.h header.
package unix
// Major returns the major component of a Darwin device number:
// bits 24-31 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint8(dev >> 24))
}
// Minor returns the minor component of a Darwin device number:
// the low 24 bits of dev.
func Minor(dev uint64) uint32 {
	low := uint32(dev)
	return low & 0xffffff
}
// Mkdev returns a Darwin device number generated from the given major and
// minor components: major shifted left by 24 bits, ORed with minor. Neither
// component is masked.
func Mkdev(major, minor uint32) uint64 {
	dev := uint64(minor)
	dev |= uint64(major) << 24
	return dev
}
================================================
FILE: gossh/sys/unix/dev_dragonfly.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used in Dragonfly's sys/types.h header.
//
// The information below is extracted and adapted from sys/types.h:
//
// Minor gives a cookie instead of an index, in order to avoid changing the
// meanings of bits 0-15 or wasting time and space shifting bits 16-31 for
// devices that don't use them.
package unix
// Major returns the major component of a DragonFlyBSD device number:
// bits 8-15 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint8(dev >> 8))
}
// Minor returns the minor component of a DragonFlyBSD device number: the low
// 32 bits of dev with the major field (bits 8-15) cleared.
func Minor(dev uint64) uint32 {
	return uint32(dev) &^ 0xff00
}
// Mkdev returns a DragonFlyBSD device number generated from the given major
// and minor components: major shifted left by 8 bits, ORed with minor.
// Neither component is masked.
func Mkdev(major, minor uint32) uint64 {
	dev := uint64(minor)
	dev |= uint64(major) << 8
	return dev
}
================================================
FILE: gossh/sys/unix/dev_freebsd.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used in FreeBSD's sys/types.h header.
//
// The information below is extracted and adapted from sys/types.h:
//
// Minor gives a cookie instead of an index, in order to avoid changing the
// meanings of bits 0-15 or wasting time and space shifting bits 16-31 for
// devices that don't use them.
package unix
// Major returns the major component of a FreeBSD device number:
// bits 8-15 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint8(dev >> 8))
}
// Minor returns the minor component of a FreeBSD device number: the low
// 32 bits of dev with the major field (bits 8-15) cleared.
func Minor(dev uint64) uint32 {
	return uint32(dev) &^ 0xff00
}
// Mkdev returns a FreeBSD device number generated from the given major and
// minor components: major shifted left by 8 bits, ORed with minor. Neither
// component is masked.
func Mkdev(major, minor uint32) uint64 {
	dev := uint64(minor)
	dev |= uint64(major) << 8
	return dev
}
================================================
FILE: gossh/sys/unix/dev_linux.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used by the Linux kernel and glibc.
//
// The information below is extracted and adapted from bits/sysmacros.h in the
// glibc sources:
//
// dev_t in glibc is 64-bit, with 32-bit major and minor numbers. glibc's
// default encoding is MMMM Mmmm mmmM MMmm, where M is a hex digit of the major
// number and m is a hex digit of the minor number. This is backward compatible
// with legacy systems where dev_t is 16 bits wide, encoded as MMmm. It is also
// backward compatible with the Linux kernel, which for some architectures uses
// 32-bit dev_t, encoded as mmmM MMmm.
package unix
// Major returns the major component of a Linux device number. The low 12
// bits of the result come from bits 8-19 of dev, the high 20 bits from bits
// 44-63.
func Major(dev uint64) uint32 {
	low := uint32(dev>>8) & 0x00000fff
	high := uint32(dev>>32) & 0xfffff000
	return high | low
}
// Minor returns the minor component of a Linux device number. The low 8 bits
// of the result come from bits 0-7 of dev, the high 24 bits from bits 20-43.
func Minor(dev uint64) uint32 {
	low := uint32(dev) & 0x000000ff
	high := uint32(dev>>12) & 0xffffff00
	return high | low
}
// Mkdev returns a Linux device number generated from the given major and
// minor components, using the layout decoded by Major and Minor: major low
// 12 bits at bits 8-19, major high 20 bits at bits 44-63, minor low 8 bits at
// bits 0-7 and minor high 24 bits at bits 20-43.
func Mkdev(major, minor uint32) uint64 {
	maj, min := uint64(major), uint64(minor)
	return (maj&^0xfff)<<32 | (min&^0xff)<<12 | (maj&0xfff)<<8 | min&0xff
}
================================================
FILE: gossh/sys/unix/dev_netbsd.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used in NetBSD's sys/types.h header.
package unix
// Major returns the major component of a NetBSD device number:
// bits 8-19 of dev.
func Major(dev uint64) uint32 {
	return uint32(dev>>8) & 0xfff
}
// Minor returns the minor component of a NetBSD device number. The low 8
// bits of the result come from bits 0-7 of dev, the next 12 bits from bits
// 20-31.
func Minor(dev uint64) uint32 {
	low := uint32(dev) & 0xff
	high := uint32(dev>>12) & 0xfff00
	return high | low
}
// Mkdev returns a NetBSD device number generated from the given major and
// minor components. Only the low 12 bits of major and the low 20 bits of
// minor are representable; higher bits are dropped.
func Mkdev(major, minor uint32) uint64 {
	maj, min := uint64(major), uint64(minor)
	return (maj&0xfff)<<8 | (min&0xfff00)<<12 | min&0xff
}
================================================
FILE: gossh/sys/unix/dev_openbsd.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Functions to access/create device major and minor numbers matching the
// encoding used in OpenBSD's sys/types.h header.
package unix
// Major returns the major component of an OpenBSD device number:
// bits 8-15 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint8(dev >> 8))
}
// Minor returns the minor component of an OpenBSD device number. The low 8
// bits of the result come from bits 0-7 of dev, the next 16 bits from bits
// 16-31.
func Minor(dev uint64) uint32 {
	low := uint32(dev) & 0xff
	high := uint32(dev>>8) & 0xffff00
	return high | low
}
// Mkdev returns an OpenBSD device number generated from the given major and
// minor components. Only the low 8 bits of major and the low 24 bits of
// minor are representable; higher bits are dropped.
func Mkdev(major, minor uint32) uint64 {
	maj, min := uint64(major), uint64(minor)
	return (maj&0xff)<<8 | (min&0xffff00)<<8 | min&0xff
}
================================================
FILE: gossh/sys/unix/dev_zos.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build zos && s390x
// Functions to access/create device major and minor numbers matching the
// encoding used by z/OS.
//
// The information below is extracted and adapted from macros.
package unix
// Major returns the major component of a z/OS device number:
// bits 16-31 of dev.
func Major(dev uint64) uint32 {
	return uint32(uint16(dev >> 16))
}
// Minor returns the minor component of a z/OS device number:
// the low 16 bits of dev.
func Minor(dev uint64) uint32 {
	return uint32(uint16(dev))
}
// Mkdev returns a z/OS device number generated from the given major and
// minor components: major shifted left by 16 bits, ORed with minor. Neither
// component is masked.
func Mkdev(major, minor uint32) uint64 {
	dev := uint64(minor)
	dev |= uint64(major) << 16
	return dev
}
================================================
FILE: gossh/sys/unix/dirent.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix
import "unsafe"
// readInt decodes the size-byte unsigned integer stored in native byte order
// at offset off of b. It returns ok == false when b is too short to hold
// off+size bytes; the byte order is chosen by isBigEndian.
func readInt(b []byte, off, size uintptr) (u uint64, ok bool) {
	if end := int(off + size); end > len(b) {
		return 0, false
	}
	tail := b[off:]
	if !isBigEndian {
		return readIntLE(tail, size), true
	}
	return readIntBE(tail, size), true
}
// readIntBE decodes the first size bytes of b as a big-endian unsigned
// integer. size must be 1, 2, 4 or 8; any other value panics. b must hold at
// least size bytes, otherwise the index expression below panics.
func readIntBE(b []byte, size uintptr) uint64 {
	switch size {
	case 1, 2, 4, 8:
	default:
		panic("syscall: readInt with unsupported size")
	}
	_ = b[size-1] // fail up front if b is shorter than size
	var v uint64
	for _, c := range b[:size] {
		v = v<<8 | uint64(c)
	}
	return v
}
// readIntLE decodes the first size bytes of b as a little-endian unsigned
// integer. size must be 1, 2, 4 or 8; any other value panics. b must hold at
// least size bytes, otherwise the index expression below panics.
func readIntLE(b []byte, size uintptr) uint64 {
	switch size {
	case 1, 2, 4, 8:
	default:
		panic("syscall: readInt with unsupported size")
	}
	_ = b[size-1] // fail up front if b is shorter than size
	var v uint64
	for i := int(size) - 1; i >= 0; i-- {
		v = v<<8 | uint64(b[i])
	}
	return v
}
// ParseDirent parses up to max directory entries in buf,
// appending the names to names. It returns the number of
// bytes consumed from buf, the number of entries added
// to names, and the new names slice.
func ParseDirent(buf []byte, max int, names []string) (consumed int, count int, newnames []string) {
origlen := len(buf)
count = 0
for max != 0 && len(buf) > 0 {
reclen, ok := direntReclen(buf)
if !ok || reclen > uint64(len(buf)) {
return origlen, count, names
}
rec := buf[:reclen]
buf = buf[reclen:]
ino, ok := direntIno(rec)
if !ok {
break
}
if ino == 0 { // File absent in directory.
continue
}
const namoff = uint64(unsafe.Offsetof(Dirent{}.Name))
namlen, ok := direntNamlen(rec)
if !ok || namoff+namlen > uint64(len(rec)) {
break
}
name := rec[namoff : namoff+namlen]
for i, c := range name {
if c == 0 {
name = name[:i]
break
}
}
// Check for useless names before allocating a string.
if string(name) == "." || string(name) == ".." {
continue
}
max--
count++
names = append(names, string(name))
}
return origlen - len(buf), count, names
}
================================================
FILE: gossh/sys/unix/endian_big.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
//go:build armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64
package unix
// isBigEndian is true in this file, which is only compiled for the
// architectures named in the build constraint above.
const isBigEndian = true
================================================
FILE: gossh/sys/unix/endian_little.go
================================================
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
//go:build 386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh
package unix
// isBigEndian is false in this file, which is only compiled for the
// architectures named in the build constraint above.
const isBigEndian = false
================================================
FILE: gossh/sys/unix/env_unix.go
================================================
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// Unix environment variables.
package unix
import "syscall"
// Getenv returns the value of the environment variable named by key and
// whether it was found. It forwards directly to syscall.Getenv.
func Getenv(key string) (value string, found bool) {
return syscall.Getenv(key)
}
// Setenv sets the environment variable named by key to value.
// It forwards directly to syscall.Setenv.
func Setenv(key, value string) error {
return syscall.Setenv(key, value)
}
// Clearenv forwards directly to syscall.Clearenv.
func Clearenv() {
syscall.Clearenv()
}
// Environ returns the result of syscall.Environ.
func Environ() []string {
return syscall.Environ()
}
// Unsetenv removes the environment variable named by key.
// It forwards directly to syscall.Unsetenv.
func Unsetenv(key string) error {
return syscall.Unsetenv(key)
}
================================================
FILE: gossh/sys/unix/epoll_zos.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build zos && s390x
package unix
import (
"sync"
)
// This file simulates epoll on z/OS using poll.
// EpollEvent describes one watched descriptor and its event mask.
// Analogous to epoll_event on Linux.
// TODO(neeilan): Pad is because the Linux kernel expects a 96-bit struct. We never pass this to the kernel; remove?
type EpollEvent struct {
// Events is a bit mask of EPOLL* values.
Events uint32
// Fd is the file descriptor the event refers to.
Fd int32
// Pad only brings the struct up to 96 bits (see the TODO above).
Pad int32
}
// Event mask bits (EPOLL*) and control operations (EPOLL_CTL_*) exposed by
// the emulated epoll API in this file.
const (
EPOLLERR = 0x8
EPOLLHUP = 0x10
EPOLLIN = 0x1
EPOLLMSG = 0x400
EPOLLOUT = 0x4
EPOLLPRI = 0x2
EPOLLRDBAND = 0x80
EPOLLRDNORM = 0x40
EPOLLWRBAND = 0x200
EPOLLWRNORM = 0x100
EPOLL_CTL_ADD = 0x1
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
// The following constants are part of the epoll API, but represent
// currently unsupported functionality on z/OS.
// EPOLL_CLOEXEC = 0x80000
// EPOLLET = 0x80000000
// EPOLLONESHOT = 0x40000000
// EPOLLRDHUP = 0x2000 // Typically used with edge-triggered notis
// EPOLLEXCLUSIVE = 0x10000000 // Exclusive wake-up mode
// EPOLLWAKEUP = 0x20000000 // Relies on Linux's BLOCK_SUSPEND capability
)
// TODO(neeilan): We can eliminate these epToPoll / pToEpoll calls by using identical mask values for POLL/EPOLL
// constants where possible. The lower 16 bits of epoll events (uint32) can fit any system poll event (int16).
// epToPollEvt translates an epoll event mask into the matching poll mask.
// Epoll masks are 32 bits wide while poll masks are 16 bits wide. Only
// EPOLLIN, EPOLLOUT, EPOLLHUP, EPOLLPRI and EPOLLERR are translated; every
// other bit in events is dropped.
func epToPollEvt(events uint32) int16 {
	var mask int16
	if events&EPOLLIN != 0 {
		mask |= POLLIN
	}
	if events&EPOLLOUT != 0 {
		mask |= POLLOUT
	}
	if events&EPOLLHUP != 0 {
		mask |= POLLHUP
	}
	if events&EPOLLPRI != 0 {
		mask |= POLLPRI
	}
	if events&EPOLLERR != 0 {
		mask |= POLLERR
	}
	return mask
}
// pToEpollEvt translates a 16-bit poll result mask into the matching 32-bit
// epoll mask. Only POLLIN, POLLOUT, POLLHUP, POLLPRI and POLLERR are
// translated; every other bit in revents is dropped.
func pToEpollEvt(revents int16) uint32 {
	var mask uint32
	if revents&POLLIN != 0 {
		mask |= EPOLLIN
	}
	if revents&POLLOUT != 0 {
		mask |= EPOLLOUT
	}
	if revents&POLLHUP != 0 {
		mask |= EPOLLHUP
	}
	if revents&POLLPRI != 0 {
		mask |= EPOLLPRI
	}
	if revents&POLLERR != 0 {
		mask |= EPOLLERR
	}
	return mask
}
// Per-process epoll implementation.
type epollImpl struct {
// mu guards epfd2ep and nextEpfd.
mu sync.Mutex
// epfd2ep maps an emulated epoll descriptor to its instance.
epfd2ep map[int]*eventPoll
// nextEpfd is the descriptor handed out by the next epollcreate call.
nextEpfd int
}
// eventPoll holds a set of file descriptors being watched by the process. A process can have multiple epoll instances.
// On Linux, this is an in-kernel data structure accessed through a fd.
type eventPoll struct {
// NOTE(review): none of the methods in this file lock mu; fds is accessed
// while holding epollImpl.mu instead — confirm whether mu is still needed.
mu sync.Mutex
// fds maps a watched file descriptor to the event registered for it.
fds map[int]*EpollEvent
}
// epoll impl for this process.
var impl epollImpl = epollImpl{
epfd2ep: make(map[int]*eventPoll),
nextEpfd: 0,
}
// epollcreate registers a new, empty epoll instance and returns its
// descriptor. Descriptors are handed out sequentially from nextEpfd. The size
// argument is ignored and the returned error is always nil.
func (e *epollImpl) epollcreate(size int) (epfd int, err error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	id := e.nextEpfd
	e.epfd2ep[id] = &eventPoll{fds: map[int]*EpollEvent{}}
	e.nextEpfd = id + 1
	return id, nil
}
// epollcreate1 creates a new epoll instance by delegating to epollcreate.
// The flag argument is ignored.
func (e *epollImpl) epollcreate1(flag int) (fd int, err error) {
return e.epollcreate(4)
}
// epollctl adds, modifies or removes the registration of fd on the epoll
// instance epfd, according to op (EPOLL_CTL_ADD, EPOLL_CTL_MOD or
// EPOLL_CTL_DEL).
//
// It returns EBADF when epfd is unknown, EEXIST when adding an fd that is
// already registered, and ENOENT when modifying or deleting an fd that is not
// registered.
//
// NOTE(review): any other op value is accepted silently and returns nil, and
// the event pointer is stored as-is (not copied, not nil-checked).
func (e *epollImpl) epollctl(epfd int, op int, fd int, event *EpollEvent) (err error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	ep, found := e.epfd2ep[epfd]
	if !found {
		return EBADF
	}

	_, watched := ep.fds[fd]
	switch op {
	case EPOLL_CTL_ADD:
		// TODO(neeilan): When we make epfds and fds disjoint, detect epoll
		// loops here (instances watching each other) and return ELOOP.
		if watched {
			return EEXIST
		}
		ep.fds[fd] = event
	case EPOLL_CTL_MOD:
		if !watched {
			return ENOENT
		}
		ep.fds[fd] = event
	case EPOLL_CTL_DEL:
		if !watched {
			return ENOENT
		}
		delete(ep.fds, fd)
	}
	return nil
}
// getFds returns the file descriptors currently registered on ep, in
// unspecified (map iteration) order.
//
// Must be called while holding ep.mu
func (ep *eventPoll) getFds() []int {
	// Length 0 with capacity len(ep.fds): allocating with a non-zero length
	// and then appending would leave len(ep.fds) spurious zero descriptors at
	// the front of the result.
	fds := make([]int, 0, len(ep.fds))
	for fd := range ep.fds {
		fds = append(fds, fd)
	}
	return fds
}
// epollwait waits up to msec milliseconds for activity on the descriptors
// registered with epfd by calling Poll, and stores one EpollEvent per ready
// descriptor into events. It returns the number of entries written to events,
// which never exceeds len(events), or EBADF if epfd is unknown.
//
// The registered set is snapshotted under e.mu (the lock epollctl holds while
// mutating it); the lock is released before Poll is called.
func (e *epollImpl) epollwait(epfd int, events []EpollEvent, msec int) (n int, err error) {
	e.mu.Lock() // in [rare] case of concurrent epollcreate + epollwait
	ep, ok := e.epfd2ep[epfd]
	if !ok {
		e.mu.Unlock()
		return 0, EBADF
	}
	// Start with length 0: a non-zero initial length followed by append would
	// leave zero-valued PollFd entries in front, i.e. poll fd 0 (stdin) and
	// possibly report events for a descriptor that was never registered.
	// NOTE(review): assumes Poll accepts an empty slice — confirm.
	pollfds := make([]PollFd, 0, len(ep.fds))
	for fd, epollevt := range ep.fds {
		pollfds = append(pollfds, PollFd{Fd: int32(fd), Events: epToPollEvt(epollevt.Events)})
	}
	e.mu.Unlock()

	n, err = Poll(pollfds, msec)
	if err != nil {
		return n, err
	}

	i := 0
	for _, pFd := range pollfds {
		// Stop once every ready descriptor has been reported, or when the
		// caller's buffer is full (writing past it would panic).
		if i == n || i == len(events) {
			break
		}
		if pFd.Revents != 0 {
			events[i] = EpollEvent{Fd: pFd.Fd, Events: pToEpollEvt(pFd.Revents)}
			i++
		}
	}
	return i, nil
}
// EpollCreate creates a new emulated epoll instance on the process-wide
// implementation and returns its descriptor.
func EpollCreate(size int) (fd int, err error) {
return impl.epollcreate(size)
}
// EpollCreate1 creates a new emulated epoll instance on the process-wide
// implementation and returns its descriptor.
func EpollCreate1(flag int) (fd int, err error) {
return impl.epollcreate1(flag)
}
// EpollCtl applies the control operation op for fd to the emulated epoll
// instance epfd.
func EpollCtl(epfd int, op int, fd int, event *EpollEvent) (err error) {
return impl.epollctl(epfd, op, fd, event)
}
// EpollWait waits up to msec milliseconds for events on the emulated epoll
// instance epfd and stores them into events.
//
// Because EpollWait mutates events, the caller is expected to coordinate
// concurrent access if calling with the same epfd from multiple goroutines.
func EpollWait(epfd int, events []EpollEvent, msec int) (n int, err error) {
return impl.epollwait(epfd, events, msec)
}
================================================
FILE: gossh/sys/unix/fcntl.go
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd
package unix
import "unsafe"
// fcntl64Syscall is usually SYS_FCNTL, but is overridden on 32-bit Linux
// systems by fcntl_linux_32bit.go to be SYS_FCNTL64.
var fcntl64Syscall uintptr = SYS_FCNTL

// fcntl issues the fcntl system call selected by fcntl64Syscall on fd with
// the given command and integer argument. It returns the raw result value
// together with the errno as an error, or a nil error when errno is zero.
func fcntl(fd int, cmd, arg int) (int, error) {
	ret, _, errno := Syscall(fcntl64Syscall, uintptr(fd), uintptr(cmd), uintptr(arg))
	if errno != 0 {
		return int(ret), errno
	}
	return int(ret), nil
}
// FcntlInt performs a fcntl syscall on fd with the provided command and argument.
// It returns the value and error produced by the package-internal fcntl helper.
func FcntlInt(fd uintptr, cmd, arg int) (int, error) {
return fcntl(int(fd), cmd, arg)
}
// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW
// command, passing lk as the argument. It returns the errno as an error, or
// nil when the call reports no error.
func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error {
	if _, _, errno := Syscall(fcntl64Syscall, fd, uintptr(cmd), uintptr(unsafe.Pointer(lk))); errno != 0 {
		return errno
	}
	return nil
}
================================================
FILE: gossh/sys/unix/fcntl_darwin.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unix
import "unsafe"
// FcntlInt performs a fcntl syscall on fd with the provided command and argument.
func FcntlInt(fd uintptr, cmd, arg int) (int, error) {
return fcntl(int(fd), cmd, arg)
}
// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW command.
// The address of lk is passed as the integer argument; only the error is returned.
func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error {
_, err := fcntl(int(fd), cmd, int(uintptr(unsafe.Pointer(lk))))
return err
}
// FcntlFstore performs a fcntl syscall for the F_PREALLOCATE command.
// The address of fstore is passed as the integer argument; only the error is returned.
func FcntlFstore(fd uintptr, cmd int, fstore *Fstore_t) error {
_, err := fcntl(int(fd), cmd, int(uintptr(unsafe.Pointer(fstore))))
return err
}
================================================
FILE: gossh/sys/unix/fcntl_linux_32bit.go
================================================
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (linux && 386) || (linux && arm) || (linux && mips) || (linux && mipsle) || (linux && ppc)
package unix
// init switches the syscall number used by the fcntl helpers to SYS_FCNTL64
// on the 32-bit Linux targets selected by this file's build constraint.
func init() {
// On 32-bit Linux systems, the fcntl syscall that matches Go's
// Flock_t type is SYS_FCNTL64, not SYS_FCNTL.
fcntl64Syscall = SYS_FCNTL64
}
================================================
FILE: gossh/sys/unix/fdset.go
================================================
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix
// Set adds fd to the set fds.
// The word index is fd/NFDBITS and the bit within it is fd%NFDBITS; no range
// check on fd is performed here.
func (fds *FdSet) Set(fd int) {
fds.Bits[fd/NFDBITS] |= (1 << (uintptr(fd) % NFDBITS))
}
// Clear removes fd from the set fds.
// No range check on fd is performed here.
func (fds *FdSet) Clear(fd int) {
fds.Bits[fd/NFDBITS] &^= (1 << (uintptr(fd) % NFDBITS))
}
// IsSet returns whether fd is in the set fds.
// No range check on fd is performed here.
func (fds *FdSet) IsSet(fd int) bool {
return fds.Bits[fd/NFDBITS]&(1<<(uintptr(fd)%NFDBITS)) != 0
}
// Zero clears the set fds.
// Every word of Bits is reset to zero.
func (fds *FdSet) Zero() {
for i := range fds.Bits {
fds.Bits[i] = 0
}
}
================================================
FILE: gossh/sys/unix/fstatfs_zos.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build zos && s390x
package unix
import (
"unsafe"
)
// This file simulates fstatfs on z/OS using fstatvfs and w_getmntent.
// Fstatfs fills stat for the file system containing fd. The numeric fields
// are copied from the result of Fstatvfs; stat.Type is then looked up through
// the tryGetmntent helpers, which are tried in order of growing buffer size
// for as long as the previous one reports ERANGE (no match in its batch).
//
// It returns the Fstatvfs error if that call fails, otherwise the result of
// the last mount-table lookup that was attempted.
func Fstatfs(fd int, stat *Statfs_t) (err error) {
	var vfs Statvfs_t
	if err = Fstatvfs(fd, &vfs); err != nil {
		return err
	}

	// populate stat
	stat.Type = 0
	stat.Bsize = vfs.Bsize
	stat.Blocks = vfs.Blocks
	stat.Bfree = vfs.Bfree
	stat.Bavail = vfs.Bavail
	stat.Files = vfs.Files
	stat.Ffree = vfs.Ffree
	stat.Fsid = vfs.Fsid
	stat.Namelen = vfs.Namemax
	stat.Frsize = vfs.Frsize
	stat.Flags = vfs.Flag

	lookups := []func(*Statfs_t) error{
		tryGetmntent64,
		tryGetmntent128,
		tryGetmntent256,
		tryGetmntent512,
		tryGetmntent1024,
	}
	for _, lookup := range lookups {
		err = lookup(stat)
		// Stop on success (nil) or on any error other than ERANGE.
		if err != ERANGE {
			break
		}
	}
	return err
}
// setTypeFromMntents scans the first count entries for one whose Dev equals
// stat.Fsid and, on a match, stores the first byte of that entry's Fstname in
// stat.Type. It returns nil on a match and ERANGE if no entry matched, which
// tells the caller to retry with a larger batch.
func setTypeFromMntents(stat *Statfs_t, entries []W_Mntent, count int) error {
	for i := 0; i < count; i++ {
		if stat.Fsid == uint64(entries[i].Dev) {
			stat.Type = uint32(entries[i].Fstname[0])
			return nil
		}
	}
	return ERANGE
}

// tryGetmntent64 calls W_Getmntent with room for 64 mount entries and looks
// for the entry matching stat.Fsid. It returns the W_Getmntent error if that
// call fails, nil when stat.Type was set, and ERANGE when no entry matched.
func tryGetmntent64(stat *Statfs_t) (err error) {
	// header must precede the entries: the whole struct is handed to
	// W_Getmntent as one buffer.
	var buf struct {
		header       W_Mnth
		filesys_info [64]W_Mntent
	}
	n, err := W_Getmntent((*byte)(unsafe.Pointer(&buf)), int(unsafe.Sizeof(buf)))
	if err != nil {
		return err
	}
	return setTypeFromMntents(stat, buf.filesys_info[:], n)
}

// tryGetmntent128 is tryGetmntent64 with room for 128 mount entries.
func tryGetmntent128(stat *Statfs_t) (err error) {
	var buf struct {
		header       W_Mnth
		filesys_info [128]W_Mntent
	}
	n, err := W_Getmntent((*byte)(unsafe.Pointer(&buf)), int(unsafe.Sizeof(buf)))
	if err != nil {
		return err
	}
	return setTypeFromMntents(stat, buf.filesys_info[:], n)
}

// tryGetmntent256 is tryGetmntent64 with room for 256 mount entries.
func tryGetmntent256(stat *Statfs_t) (err error) {
	var buf struct {
		header       W_Mnth
		filesys_info [256]W_Mntent
	}
	n, err := W_Getmntent((*byte)(unsafe.Pointer(&buf)), int(unsafe.Sizeof(buf)))
	if err != nil {
		return err
	}
	return setTypeFromMntents(stat, buf.filesys_info[:], n)
}

// tryGetmntent512 is tryGetmntent64 with room for 512 mount entries.
func tryGetmntent512(stat *Statfs_t) (err error) {
	var buf struct {
		header       W_Mnth
		filesys_info [512]W_Mntent
	}
	n, err := W_Getmntent((*byte)(unsafe.Pointer(&buf)), int(unsafe.Sizeof(buf)))
	if err != nil {
		return err
	}
	return setTypeFromMntents(stat, buf.filesys_info[:], n)
}

// tryGetmntent1024 is tryGetmntent64 with room for 1024 mount entries.
func tryGetmntent1024(stat *Statfs_t) (err error) {
	var buf struct {
		header       W_Mnth
		filesys_info [1024]W_Mntent
	}
	n, err := W_Getmntent((*byte)(unsafe.Pointer(&buf)), int(unsafe.Sizeof(buf)))
	if err != nil {
		return err
	}
	return setTypeFromMntents(stat, buf.filesys_info[:], n)
}
================================================
FILE: gossh/sys/unix/gccgo.go
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && !aix && !hurd
package unix
import "syscall"
// We can't use the gc-syntax .s files for gccgo. On the plus side
// much of the functionality can be written directly in Go.

// realSyscallNoError is declared without a body; its implementation is
// provided outside this file. It takes the trap number and up to nine
// arguments and returns only the result value.
func realSyscallNoError(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r uintptr)
// realSyscall is declared without a body; its implementation is provided
// outside this file. It returns the result value and the errno.
func realSyscall(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r, errno uintptr)
// SyscallNoError invokes trap with three arguments, bracketed by
// syscall.Entersyscall/Exitsyscall. r2 is always 0.
func SyscallNoError(trap, a1, a2, a3 uintptr) (r1, r2 uintptr) {
syscall.Entersyscall()
r := realSyscallNoError(trap, a1, a2, a3, 0, 0, 0, 0, 0, 0)
syscall.Exitsyscall()
return r, 0
}
// Syscall invokes trap with three arguments, bracketed by
// syscall.Entersyscall/Exitsyscall. r2 is always 0; err carries the errno.
func Syscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err syscall.Errno) {
syscall.Entersyscall()
r, errno := realSyscall(trap, a1, a2, a3, 0, 0, 0, 0, 0, 0)
syscall.Exitsyscall()
return r, 0, syscall.Errno(errno)
}
// Syscall6 is Syscall with six arguments.
func Syscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno) {
syscall.Entersyscall()
r, errno := realSyscall(trap, a1, a2, a3, a4, a5, a6, 0, 0, 0)
syscall.Exitsyscall()
return r, 0, syscall.Errno(errno)
}
// Syscall9 is Syscall with nine arguments.
func Syscall9(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno) {
syscall.Entersyscall()
r, errno := realSyscall(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9)
syscall.Exitsyscall()
return r, 0, syscall.Errno(errno)
}
// RawSyscallNoError is SyscallNoError without the
// syscall.Entersyscall/Exitsyscall bracketing.
func RawSyscallNoError(trap, a1, a2, a3 uintptr) (r1, r2 uintptr) {
r := realSyscallNoError(trap, a1, a2, a3, 0, 0, 0, 0, 0, 0)
return r, 0
}
// RawSyscall is Syscall without the syscall.Entersyscall/Exitsyscall
// bracketing.
func RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err syscall.Errno) {
r, errno := realSyscall(trap, a1, a2, a3, 0, 0, 0, 0, 0, 0)
return r, 0, syscall.Errno(errno)
}
// RawSyscall6 is Syscall6 without the syscall.Entersyscall/Exitsyscall
// bracketing.
func RawSyscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno) {
r, errno := realSyscall(trap, a1, a2, a3, a4, a5, a6, 0, 0, 0)
return r, 0, syscall.Errno(errno)
}
================================================
FILE: gossh/sys/unix/gccgo_c.c
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && !aix && !hurd
#include
#include
#include
#define _STRINGIFY2_(x) #x
#define _STRINGIFY_(x) _STRINGIFY2_(x)
#define GOSYM_PREFIX _STRINGIFY_(__USER_LABEL_PREFIX__)
// Call syscall from C code because the gccgo support for calling from
// Go to C does not support varargs functions.
// struct ret carries the two values returned by gccgoRealSyscall: the raw
// result of syscall() and the errno observed afterwards.
struct ret {
uintptr_t r;
uintptr_t err;
};
// The __asm__ label gives this function the symbol name
// GOSYM_PREFIX GOPKGPATH ".realSyscall".
struct ret gccgoRealSyscall(uintptr_t trap, uintptr_t a1, uintptr_t a2, uintptr_t a3, uintptr_t a4, uintptr_t a5, uintptr_t a6, uintptr_t a7, uintptr_t a8, uintptr_t a9)
__asm__(GOSYM_PREFIX GOPKGPATH ".realSyscall");
struct ret
gccgoRealSyscall(uintptr_t trap, uintptr_t a1, uintptr_t a2, uintptr_t a3, uintptr_t a4, uintptr_t a5, uintptr_t a6, uintptr_t a7, uintptr_t a8, uintptr_t a9)
{
struct ret r;
// Clear errno first so that r.err reflects only this call.
errno = 0;
r.r = syscall(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9);
r.err = errno;
return r;
}
// The __asm__ label gives this function the symbol name
// GOSYM_PREFIX GOPKGPATH ".realSyscallNoError". It returns the raw result of
// syscall() and does not report errno.
uintptr_t gccgoRealSyscallNoError(uintptr_t trap, uintptr_t a1, uintptr_t a2, uintptr_t a3, uintptr_t a4, uintptr_t a5, uintptr_t a6, uintptr_t a7, uintptr_t a8, uintptr_t a9)
__asm__(GOSYM_PREFIX GOPKGPATH ".realSyscallNoError");
uintptr_t
gccgoRealSyscallNoError(uintptr_t trap, uintptr_t a1, uintptr_t a2, uintptr_t a3, uintptr_t a4, uintptr_t a5, uintptr_t a6, uintptr_t a7, uintptr_t a8, uintptr_t a9)
{
return syscall(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9);
}
================================================
FILE: gossh/sys/unix/gccgo_linux_amd64.go
================================================
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && linux && amd64
package unix
import "syscall"
// realGettimeofday is bound to the external symbol named by the extern
// directive below; the directive must stay directly above the declaration.

//extern gettimeofday
func realGettimeofday(*Timeval, *byte) int32
// gettimeofday fills tv through realGettimeofday, passing nil as the second
// argument. It returns syscall.GetErrno() when the call reports a negative
// result and 0 otherwise.
func gettimeofday(tv *Timeval) (err syscall.Errno) {
r := realGettimeofday(tv, nil)
if r < 0 {
return syscall.GetErrno()
}
return 0
}
================================================
FILE: gossh/sys/unix/ifreq_linux.go
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux
package unix
import (
"unsafe"
)
// Helpers for dealing with ifreq since it contains a union and thus requires a
// lot of unsafe.Pointer casts to use properly.
// An Ifreq is a type-safe wrapper around the raw ifreq struct. An Ifreq
// contains an interface name and a union of arbitrary data which can be
// accessed using the Ifreq's methods. To create an Ifreq, use the NewIfreq
// function.
//
// Use the Name method to access the stored interface name. The union data
// fields can be get and set using the following methods:
// - Uint16/SetUint16: flags
// - Uint32/SetUint32: ifindex, metric, mtu
//
// The wrapped value is kept in the unexported raw field, so callers outside
// the package can only reach it through these methods.
type Ifreq struct{ raw ifreq }
// NewIfreq returns an Ifreq whose name field holds name. Names longer than
// IFNAMSIZ-1 bytes are rejected with EINVAL, because the last byte of the
// name field must remain a terminating NULL.
func NewIfreq(name string) (*Ifreq, error) {
	// Reject names that would leave no room for the terminating NULL byte.
	if len(name) > IFNAMSIZ-1 {
		return nil, EINVAL
	}

	req := new(Ifreq)
	copy(req.raw.Ifrn[:], name)
	return req, nil
}
// TODO(mdlayher): get/set methods for hardware address sockaddr, char array, etc.
// Name returns the interface name associated with the Ifreq.
// The name bytes are converted with ByteSliceToString.
func (ifr *Ifreq) Name() string {
return ByteSliceToString(ifr.raw.Ifrn[:])
}
// According to netdevice(7), only AF_INET addresses are returned for numerous
// sockaddr ioctls. For convenience, we expose these as Inet4Addr since the Port
// field and other data is always empty.
// Inet4Addr returns the Ifreq union data from an embedded sockaddr as a C
// in_addr/Go []byte (4-byte IPv4 address) value. If the sockaddr family is not
// AF_INET, an error is returned.
//
// The returned slice refers to a copy of the address, so modifying it does
// not change the Ifreq.
func (ifr *Ifreq) Inet4Addr() ([]byte, error) {
// Reinterpret the first SizeofSockaddrInet4 bytes of the union as a
// RawSockaddrInet4 and copy it out by value.
raw := *(*RawSockaddrInet4)(unsafe.Pointer(&ifr.raw.Ifru[:SizeofSockaddrInet4][0]))
if raw.Family != AF_INET {
// Cannot safely interpret raw.Addr bytes as an IPv4 address.
return nil, EINVAL
}
return raw.Addr[:], nil
}
// SetInet4Addr sets a C in_addr/Go []byte (4-byte IPv4 address) value in an
// embedded sockaddr within the Ifreq's union data. v must be 4 bytes in length
// or an error will be returned.
//
// The union is zeroed first; the sockaddr written over it has Family AF_INET
// and Addr v, with all other fields left at their zero value.
func (ifr *Ifreq) SetInet4Addr(v []byte) error {
if len(v) != 4 {
return EINVAL
}
var addr [4]byte
copy(addr[:], v)
ifr.clear()
// Overwrite the first SizeofSockaddrInet4 bytes of the union in place.
*(*RawSockaddrInet4)(
unsafe.Pointer(&ifr.raw.Ifru[:SizeofSockaddrInet4][0]),
) = RawSockaddrInet4{
// Always set IP family as ioctls would require it anyway.
Family: AF_INET,
Addr: addr,
}
return nil
}
// Uint16 returns the Ifreq union data as a C short/Go uint16 value.
// The first two bytes of the union are reinterpreted in place.
func (ifr *Ifreq) Uint16() uint16 {
return *(*uint16)(unsafe.Pointer(&ifr.raw.Ifru[:2][0]))
}
// SetUint16 sets a C short/Go uint16 value as the Ifreq's union data.
// The union is zeroed before v is stored in its first two bytes.
func (ifr *Ifreq) SetUint16(v uint16) {
ifr.clear()
*(*uint16)(unsafe.Pointer(&ifr.raw.Ifru[:2][0])) = v
}
// Uint32 returns the Ifreq union data as a C int/Go uint32 value.
// The first four bytes of the union are reinterpreted in place.
func (ifr *Ifreq) Uint32() uint32 {
return *(*uint32)(unsafe.Pointer(&ifr.raw.Ifru[:4][0]))
}
// SetUint32 sets a C int/Go uint32 value as the Ifreq's union data.
// The union is zeroed before v is stored in its first four bytes.
func (ifr *Ifreq) SetUint32(v uint32) {
ifr.clear()
*(*uint32)(unsafe.Pointer(&ifr.raw.Ifru[:4][0])) = v
}
// clear resets every element of the ifreq's union field to zero, so that no
// stale data from a previous use is sent to the kernel when an Ifreq is
// reused.
func (ifr *Ifreq) clear() {
	for i := 0; i < len(ifr.raw.Ifru); i++ {
		ifr.raw.Ifru[i] = 0
	}
}
// TODO(mdlayher): export as IfreqData? For now we can provide helpers such as
// IoctlGetEthtoolDrvinfo which use these APIs under the hood.
// An ifreqData is an Ifreq which carries pointer data. To produce an ifreqData,
// use the Ifreq.withData method.
type ifreqData struct {
// name holds the interface name, copied from the originating Ifreq.
name [IFNAMSIZ]byte
// A type separate from ifreq is required in order to comply with the
// unsafe.Pointer rules since the "pointer-ness" of data would not be
// preserved if it were cast into the byte array of a raw ifreq.
data unsafe.Pointer
// Pad to the same size as ifreq.
_ [len(ifreq{}.Ifru) - SizeofPtr]byte
}
// withData returns an ifreqData carrying the receiver's interface name and
// the pointer p, for ioctls that take arbitrary pointer data. The receiver is
// passed by value and is not modified.
func (ifr Ifreq) withData(p unsafe.Pointer) ifreqData {
	var d ifreqData
	d.name = ifr.raw.Ifrn
	d.data = p
	return d
}
================================================
FILE: gossh/sys/unix/internal/mkmerge/mkmerge.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The mkmerge command parses generated source files and merges common
// consts, funcs, and types into a common source file, per GOOS.
//
// Usage:
//
// $ mkmerge -out MERGED FILE [FILE ...]
//
// Example:
//
// # Remove all common consts, funcs, and types from zerrors_linux_*.go
// # and write the common code into zerrors_linux.go
// $ mkmerge -out zerrors_linux.go zerrors_linux_*.go
//
// mkmerge performs the merge in the following steps:
// 1. Construct the set of common code that is identical in all
// architecture-specific files.
// 2. Write this common code to the merged file.
// 3. Remove the common code from all architecture-specific files.
package main
import (
"bufio"
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"log"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
)
// validGOOS is a regexp alternation of the GOOS values mkmerge accepts.
const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"

// validGOOSFileRe matches a file name ending in "_GOOS.go" for a valid GOOS,
// capturing the GOOS. It is compiled once at package initialization rather
// than on every getValidGOOS call.
var validGOOSFileRe = regexp.MustCompile(`_(` + validGOOS + `)\.go$`)

// getValidGOOS returns GOOS, true if filename ends with a valid "_GOOS.go".
// Otherwise it returns "", false.
func getValidGOOS(filename string) (string, bool) {
	matches := validGOOSFileRe.FindStringSubmatch(filename)
	if len(matches) != 2 {
		return "", false
	}
	return matches[1], true
}
// codeElem is a comparable stand-in for a declaration: the token kind plus
// the declaration rendered as source text. Two codeElems are equal when both
// fields are equal, which lets them be used as map keys.
type codeElem struct {
	tok token.Token // e.g. token.CONST, token.TYPE, or token.FUNC
	src string      // the declaration formatted as source code
}

// newCodeElem formats node as Go source and pairs the result with tok.
// If formatting fails, the zero codeElem and the formatting error are
// returned.
func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
	var src strings.Builder
	if err := format.Node(&src, token.NewFileSet(), node); err != nil {
		return codeElem{}, err
	}
	return codeElem{tok: tok, src: src.String()}, nil
}
// codeSet is a set of codeElems, backed by a map whose value is true for
// every member.
type codeSet struct {
	set map[codeElem]bool // true for all codeElems in the set
}

// newCodeSet returns an empty codeSet.
func newCodeSet() *codeSet {
	return &codeSet{set: map[codeElem]bool{}}
}

// add inserts elem into c.
func (c *codeSet) add(elem codeElem) {
	c.set[elem] = true
}

// has reports whether elem is a member of c.
func (c *codeSet) has(elem codeElem) bool {
	member := c.set[elem]
	return member
}

// isEmpty reports whether c has no members.
func (c *codeSet) isEmpty() bool {
	return len(c.set) == 0
}

// intersection returns a new set holding the elements present in both c and a.
func (c *codeSet) intersection(a *codeSet) *codeSet {
	common := newCodeSet()
	for elem := range c.set {
		if !a.has(elem) {
			continue
		}
		common.add(elem)
	}
	return common
}

// keepCommon is a filterFn for producing the merged file: it decides whether
// elem belongs in the file of common declarations.
func (c *codeSet) keepCommon(elem codeElem) bool {
	switch elem.tok {
	case token.IMPORT:
		// Keep imports, they are handled by filterImports
		return true
	case token.VAR:
		// Remove all vars from the merged file
		return false
	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
		// Keep consts, types, functions, and file-level comments only when
		// they are in the common set; arch-specific ones are removed.
		return c.has(elem)
	default:
		log.Fatalf("keepCommon: invalid elem %v", elem)
		return true
	}
}

// keepArchSpecific is a filterFn for rewriting the GOARCH-specific files: it
// drops consts, types, and functions that are in the common set and keeps
// everything else.
func (c *codeSet) keepArchSpecific(elem codeElem) bool {
	if elem.tok == token.CONST || elem.tok == token.TYPE || elem.tok == token.FUNC {
		return !c.has(elem)
	}
	return true
}
// srcFile represents a source file
type srcFile struct {
// name is the file's path as given to the tool.
name string
// src is the file's content.
src []byte
}
// filterFn is a helper for filter: it reports whether the given codeElem
// should be kept.
type filterFn func(codeElem) bool
// filter parses and filters Go source code from src, removing top
// level declarations using keep as predicate.
// For src parameter, please see docs for parser.ParseFile.
//
// Imports, consts, types and vars are filtered spec by spec, functions as a
// whole, and file-level comment groups by their text. The result is then
// passed through specGroups.filterEmptyLines and filterImports. Any parse or
// formatting error is returned with a nil slice.
func filter(src any, keep filterFn) ([]byte, error) {
// Parse the src into an ast
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if err != nil {
return nil, err
}
cmap := ast.NewCommentMap(fset, f, f.Comments)
// Group const/type specs on adjacent lines
var groups specGroups = make(map[string]int)
var groupID int
// Filter in place: f.Decls is truncated to length zero and the kept
// declarations are appended back into the same backing array.
decls := f.Decls
f.Decls = f.Decls[:0]
for _, decl := range decls {
switch decl := decl.(type) {
case *ast.GenDecl:
// Filter imports, consts, types, vars
specs := decl.Specs
decl.Specs = decl.Specs[:0]
for i, spec := range specs {
elem, err := newCodeElem(decl.Tok, spec)
if err != nil {
return nil, err
}
// Create new group if there are empty lines between this and the previous spec
if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
groupID++
}
// Check if we should keep this spec
if keep(elem) {
decl.Specs = append(decl.Specs, spec)
// NOTE(review): the error returned by groups.add is discarded here.
groups.add(elem.src, groupID)
}
}
// Check if we should keep this decl
if len(decl.Specs) > 0 {
f.Decls = append(f.Decls, decl)
}
case *ast.FuncDecl:
// Filter funcs
elem, err := newCodeElem(token.FUNC, decl)
if err != nil {
return nil, err
}
if keep(elem) {
f.Decls = append(f.Decls, decl)
}
}
}
// Filter file level comments
if cmap[f] != nil {
commentGroups := cmap[f]
cmap[f] = cmap[f][:0]
for _, cGrp := range commentGroups {
if keep(codeElem{token.COMMENT, cGrp.Text()}) {
cmap[f] = append(cmap[f], cGrp)
}
}
}
// Keep only the comments still attached to nodes that survived filtering.
f.Comments = cmap.Filter(f).Comments()
// Generate code for the filtered ast
var buf bytes.Buffer
if err = format.Node(&buf, fset, f); err != nil {
return nil, err
}
groupedSrc, err := groups.filterEmptyLines(&buf)
if err != nil {
return nil, err
}
return filterImports(groupedSrc)
}
// getCommonSet returns the set of code elements that appear in every one of
// files. It returns an error when files is empty or when any file fails to
// parse.
func getCommonSet(files []srcFile) (*codeSet, error) {
	if len(files) == 0 {
		return nil, fmt.Errorf("no files provided")
	}

	var common *codeSet
	for i, f := range files {
		set, err := getCodeSet(f.src)
		if err != nil {
			return nil, err
		}
		if i == 0 {
			// The first architecture file is the baseline.
			common = set
			continue
		}
		// Discard any element that does not also exist in this file.
		common = common.intersection(set)
	}
	return common, nil
}
// getCodeSet parses src and returns the set of its top-level const specs,
// type specs, and func declarations, together with its file-level comment
// groups. src must be string, []byte, or io.Reader (see go/parser.ParseFile
// docs). A parse or formatting error is returned with a nil set.
func getCodeSet(src any) (*codeSet, error) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	set := newCodeSet()
	for _, decl := range f.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			// Add func declarations
			elem, err := newCodeElem(token.FUNC, d)
			if err != nil {
				return nil, err
			}
			set.add(elem)
		case *ast.GenDecl:
			// Only const and type declarations are collected.
			if d.Tok != token.CONST && d.Tok != token.TYPE {
				continue
			}
			for _, spec := range d.Specs {
				elem, err := newCodeElem(d.Tok, spec)
				if err != nil {
					return nil, err
				}
				set.add(elem)
			}
		}
	}

	// Add file level comments
	for _, cGrp := range ast.NewCommentMap(fset, f, f.Comments)[f] {
		text := cGrp.Text()
		// ast.CommentGroup.Text doesn't include comment directives like
		// "//go:build" in the text. So if a comment group has empty text and a
		// single //go:build constraint line, use that line (minus the leading
		// slashes) as the text. This is enough for mkmerge needs.
		loneBuildLine := text == "" && len(cGrp.List) == 1 &&
			strings.HasPrefix(cGrp.List[0].Text, "//go:build ")
		if loneBuildLine {
			text = cGrp.List[0].Text[len("//"):] + "\n"
		}
		set.add(codeElem{token.COMMENT, text})
	}
	return set, nil
}
// importName returns the identifier (PackageName) under which an import is
// referenced: the explicit name when the spec has one, otherwise the last
// element of the unquoted import path. An error is returned when the path
// literal cannot be unquoted.
func importName(iSpec *ast.ImportSpec) (string, error) {
	if iSpec.Name != nil {
		return iSpec.Name.Name, nil
	}
	pkgPath, err := strconv.Unquote(iSpec.Path.Value)
	if err != nil {
		return "", err
	}
	return path.Base(pkgPath), nil
}
// specGroups records which group each kept const/type spec belongs to. The
// key is the spec's source as normalized by format.Source; the value is its
// group ID.
type specGroups map[string]int

// add records that the spec with source src belongs to group groupID. If src
// cannot be formatted, nothing is recorded and the formatting error is
// returned.
func (s specGroups) add(src string, groupID int) error {
	normalized, err := format.Source(bytes.TrimSpace([]byte(src)))
	if err == nil {
		s[string(normalized)] = groupID
	}
	return err
}

// filterEmptyLines copies src to the result line by line, dropping the empty
// lines that separate two specs of the same group. It returns the filtered
// source, or the scanner's error if reading src fails.
func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
	var (
		out       bytes.Buffer
		pending   bytes.Buffer // empty lines seen since the last non-empty line
		lastGroup = -1         // invalid group until a formattable line is seen
	)

	scanner := bufio.NewScanner(src)
	for scanner.Scan() {
		raw := scanner.Bytes()
		trimmed := bytes.TrimSpace(raw)
		if len(trimmed) == 0 {
			pending.Write(raw)
			pending.WriteByte('\n')
			continue
		}

		// Lines that do not format on their own leave lastGroup untouched.
		if normalized, err := format.Source(trimmed); err == nil {
			groupID, known := s[string(normalized)]
			// Discard the pending empty lines if the previous formattable
			// line belonged to the same group as this line.
			if known && groupID == lastGroup {
				pending.Reset()
			}
			lastGroup = groupID
		}

		pending.WriteTo(&out)
		out.Write(raw)
		out.WriteByte('\n')
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// filterImports removes unused imports from fileSrc, and returns a formatted src.
//
// An import is kept when its name (see importName) matches one of the
// identifiers the parser reports as unresolved for the file. A parse,
// import-name, or formatting error is returned with a nil slice.
func filterImports(fileSrc []byte) ([]byte, error) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
if err != nil {
return nil, err
}
cmap := ast.NewCommentMap(fset, file, file.Comments)
// create set of references to imported identifiers
keepImport := make(map[string]bool)
for _, u := range file.Unresolved {
keepImport[u.Name] = true
}
// filter import declarations
// The declaration list is filtered in place: it is truncated to length zero
// and the kept declarations are appended back into the same backing array.
decls := file.Decls
file.Decls = file.Decls[:0]
for _, decl := range decls {
importDecl, ok := decl.(*ast.GenDecl)
// Keep non-import declarations
if !ok || importDecl.Tok != token.IMPORT {
file.Decls = append(file.Decls, decl)
continue
}
// Filter the import specs
specs := importDecl.Specs
importDecl.Specs = importDecl.Specs[:0]
for _, spec := range specs {
iSpec := spec.(*ast.ImportSpec)
name, err := importName(iSpec)
if err != nil {
return nil, err
}
if keepImport[name] {
importDecl.Specs = append(importDecl.Specs, iSpec)
}
}
// Drop import declarations that ended up with no specs.
if len(importDecl.Specs) > 0 {
file.Decls = append(file.Decls, importDecl)
}
}
// filter file.Imports
imports := file.Imports
file.Imports = file.Imports[:0]
for _, spec := range imports {
name, err := importName(spec)
if err != nil {
return nil, err
}
if keepImport[name] {
file.Imports = append(file.Imports, spec)
}
}
// Keep only the comments still attached to nodes that survived filtering.
file.Comments = cmap.Filter(file).Comments()
var buf bytes.Buffer
err = format.Node(&buf, fset, file)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// merge extracts duplicate code from archFiles and merges it to mergedFile.
//  1. Construct commonSet: the set of code that is identical in all archFiles.
//  2. Write the code in commonSet to mergedFile.
//  3. Remove the commonSet code from all archFiles.
//
// The GOOS encoded in the name of mergedFile must be valid; it becomes the
// build constraint of the merged file. When the architecture files have no
// code in common, no file is created or modified and nil is returned.
func merge(mergedFile string, archFiles ...string) error {
	// extract and validate the GOOS part of the merged filename
	goos, ok := getValidGOOS(mergedFile)
	if !ok {
		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
	}

	// Read architecture files
	var inSrc []srcFile
	for _, name := range archFiles {
		src, err := os.ReadFile(name)
		if err != nil {
			return fmt.Errorf("cannot read archfile %s: %w", name, err)
		}
		inSrc = append(inSrc, srcFile{name, src})
	}

	// 1. Construct the set of top-level declarations common for all files
	commonSet, err := getCommonSet(inSrc)
	if err != nil {
		return err
	}
	if commonSet.isEmpty() {
		// No common code => do not modify any files
		return nil
	}

	// 2. Write the merged file
	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
	if err != nil {
		return err
	}

	f, err := os.Create(mergedFile)
	if err != nil {
		return err
	}

	// Write errors are sticky in bufio.Writer and surface from Flush.
	buf := bufio.NewWriter(f)
	fmt.Fprintln(buf, "// Code generated by mkmerge; DO NOT EDIT.")
	fmt.Fprintln(buf)
	fmt.Fprintf(buf, "//go:build %s\n", goos)
	fmt.Fprintln(buf)
	buf.Write(mergedSrc)

	if err := buf.Flush(); err != nil {
		// Release the handle; the flush error is the one worth reporting.
		f.Close()
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}

	// 3. Remove duplicate declarations from the architecture files
	for _, inFile := range inSrc {
		src, err := filter(inFile.src, commonSet.keepArchSpecific)
		if err != nil {
			return err
		}
		if err := os.WriteFile(inFile.name, src, 0644); err != nil {
			return err
		}
	}
	return nil
}
// main parses the -out flag, expands every remaining argument as a glob
// pattern and merges the matching files. With fewer than two matching files
// there is nothing to merge and the program exits successfully; any failure
// is reported on stderr with exit status 1.
func main() {
	mergedFile := flag.String("out", "", "Write merged code to `FILE`")
	flag.Parse()

	// Expand wildcards
	var filenames []string
	for _, pattern := range flag.Args() {
		matches, err := filepath.Glob(pattern)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", pattern, err)
			os.Exit(1)
		}
		filenames = append(filenames, matches...)
	}

	// No need to merge
	if len(filenames) < 2 {
		return
	}

	if err := merge(*mergedFile, filenames...); err != nil {
		fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
		os.Exit(1)
	}
}
================================================
FILE: gossh/sys/unix/ioctl_linux.go
================================================
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unix
import "unsafe"
// IoctlRetInt performs the ioctl operation req on the open file descriptor fd
// and returns the non-negative integer that the ioctl syscall itself returns.
// The ioctl argument is always zero.
func IoctlRetInt(fd int, req uint) (int, error) {
	r0, _, errno := Syscall(SYS_IOCTL, uintptr(fd), uintptr(req), 0)
	if errno == 0 {
		return int(r0), nil
	}
	return 0, errno
}
// IoctlGetUint32 performs the ioctl operation req on fd with a pointer to a
// uint32 as its argument and returns that uint32 together with the ioctl
// error.
func IoctlGetUint32(fd int, req uint) (uint32, error) {
	var out uint32
	arg := unsafe.Pointer(&out)
	err := ioctlPtr(fd, req, arg)
	return out, err
}
// IoctlGetRTCTime issues the RTC_RD_TIME ioctl on fd and returns the RTCTime
// that was passed to it. The returned pointer is non-nil even when the ioctl
// fails.
func IoctlGetRTCTime(fd int) (*RTCTime, error) {
	rtc := new(RTCTime)
	err := ioctlPtr(fd, RTC_RD_TIME, unsafe.Pointer(rtc))
	return rtc, err
}
// IoctlSetRTCTime issues the RTC_SET_TIME ioctl on fd with value as its
// argument.
func IoctlSetRTCTime(fd int, value *RTCTime) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, RTC_SET_TIME, arg)
}
// IoctlGetRTCWkAlrm issues the RTC_WKALM_RD ioctl on fd and returns the
// RTCWkAlrm that was passed to it. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetRTCWkAlrm(fd int) (*RTCWkAlrm, error) {
	alarm := new(RTCWkAlrm)
	err := ioctlPtr(fd, RTC_WKALM_RD, unsafe.Pointer(alarm))
	return alarm, err
}
// IoctlSetRTCWkAlrm issues the RTC_WKALM_SET ioctl on fd with value as its
// argument.
func IoctlSetRTCWkAlrm(fd int, value *RTCWkAlrm) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, RTC_WKALM_SET, arg)
}
// IoctlGetEthtoolDrvinfo fetches ethtool driver information for the network
// device specified by ifname, using the SIOCETHTOOL ioctl with the
// ETHTOOL_GDRVINFO command on fd. It returns nil and an error when no Ifreq
// can be built for ifname; if the ioctl itself fails, the returned pointer is
// still non-nil.
func IoctlGetEthtoolDrvinfo(fd int, ifname string) (*EthtoolDrvinfo, error) {
	ifr, err := NewIfreq(ifname)
	if err != nil {
		return nil, err
	}

	info := &EthtoolDrvinfo{Cmd: ETHTOOL_GDRVINFO}
	ifrd := ifr.withData(unsafe.Pointer(info))
	err = ioctlIfreqData(fd, SIOCETHTOOL, &ifrd)
	return info, err
}
// IoctlGetWatchdogInfo fetches information about a watchdog device from the
// Linux watchdog API using the WDIOC_GETSUPPORT ioctl. The returned pointer is
// non-nil even when the ioctl fails. For more information, see:
// https://www.kernel.org/doc/html/latest/watchdog/watchdog-api.html.
func IoctlGetWatchdogInfo(fd int) (*WatchdogInfo, error) {
	info := new(WatchdogInfo)
	err := ioctlPtr(fd, WDIOC_GETSUPPORT, unsafe.Pointer(info))
	return info, err
}
// IoctlWatchdogKeepalive issues a keepalive ioctl (WDIOC_KEEPALIVE) to a
// watchdog device. For more information, see:
// https://www.kernel.org/doc/html/latest/watchdog/watchdog-api.html.
func IoctlWatchdogKeepalive(fd int) error {
	// arg is ignored and not a pointer, so ioctl is fine instead of ioctlPtr.
	const ignoredArg uintptr = 0
	return ioctl(fd, WDIOC_KEEPALIVE, ignoredArg)
}
// IoctlFileCloneRange performs an FICLONERANGE ioctl operation to clone the
// range of data conveyed in value to the file associated with the file
// descriptor destFd. See the ioctl_ficlonerange(2) man page for details.
func IoctlFileCloneRange(destFd int, value *FileCloneRange) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(destFd, FICLONERANGE, arg)
}
// IoctlFileClone performs an FICLONE ioctl operation to clone the entire file
// associated with the file descriptor srcFd to the file associated with the
// file descriptor destFd. See the ioctl_ficlone(2) man page for details.
func IoctlFileClone(destFd, srcFd int) error {
	// The source descriptor is passed by value, not through a pointer.
	src := uintptr(srcFd)
	return ioctl(destFd, FICLONE, src)
}
// FileDedupeRange is the argument of IoctlFileDedupeRange. Its scalar fields
// are copied into a RawFileDedupeRange header, and each element of Info is
// copied into one RawFileDedupeRangeInfo record following that header; the
// length of Info is sent as the header's Dest_count.
type FileDedupeRange struct {
Src_offset uint64
Src_length uint64
Reserved1 uint16
Reserved2 uint32
Info []FileDedupeRangeInfo
}
// FileDedupeRangeInfo describes one destination of an IoctlFileDedupeRange
// call. Every field is copied into the raw request before the ioctl and copied
// back from it afterwards, so the values reflect whatever the ioctl left in
// the raw record.
type FileDedupeRangeInfo struct {
Dest_fd int64
Dest_offset uint64
Bytes_deduped uint64
Status int32
Reserved uint32
}
// IoctlFileDedupeRange performs an FIDEDUPERANGE ioctl operation to share the
// range of data conveyed in value from the file associated with the file
// descriptor srcFd to the value.Info destinations. See the
// ioctl_fideduperange(2) man page for details.
//
// The request is marshalled into one contiguous buffer: a RawFileDedupeRange
// header followed by one RawFileDedupeRangeInfo per element of value.Info.
// After the ioctl, every info record is copied back into value.Info whether or
// not the call failed.
func IoctlFileDedupeRange(srcFd int, value *FileDedupeRange) error {
	buf := make([]byte, SizeofRawFileDedupeRange+
		len(value.Info)*SizeofRawFileDedupeRangeInfo)
	base := unsafe.Pointer(&buf[0])

	// rawInfoAt returns the i-th info record located after the header in buf.
	rawInfoAt := func(i int) *RawFileDedupeRangeInfo {
		return (*RawFileDedupeRangeInfo)(unsafe.Pointer(
			uintptr(base) + uintptr(SizeofRawFileDedupeRange) +
				uintptr(i*SizeofRawFileDedupeRangeInfo)))
	}

	// Input: header first, then one record per destination.
	header := (*RawFileDedupeRange)(base)
	header.Src_offset = value.Src_offset
	header.Src_length = value.Src_length
	header.Dest_count = uint16(len(value.Info))
	header.Reserved1 = value.Reserved1
	header.Reserved2 = value.Reserved2
	for i := range value.Info {
		in, raw := &value.Info[i], rawInfoAt(i)
		raw.Dest_fd = in.Dest_fd
		raw.Dest_offset = in.Dest_offset
		raw.Bytes_deduped = in.Bytes_deduped
		raw.Status = in.Status
		raw.Reserved = in.Reserved
	}

	err := ioctlPtr(srcFd, FIDEDUPERANGE, base)

	// Output
	for i := range value.Info {
		out, raw := &value.Info[i], rawInfoAt(i)
		out.Dest_fd = raw.Dest_fd
		out.Dest_offset = raw.Dest_offset
		out.Bytes_deduped = raw.Bytes_deduped
		out.Status = raw.Status
		out.Reserved = raw.Reserved
	}
	return err
}
// IoctlHIDGetDesc issues the HIDIOCGRDESC ioctl on fd with value as its
// argument.
func IoctlHIDGetDesc(fd int, value *HIDRawReportDescriptor) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, HIDIOCGRDESC, arg)
}
// IoctlHIDGetRawInfo issues the HIDIOCGRAWINFO ioctl on fd and returns the
// HIDRawDevInfo that was passed to it. The returned pointer is non-nil even
// when the ioctl fails.
func IoctlHIDGetRawInfo(fd int) (*HIDRawDevInfo, error) {
	info := new(HIDRawDevInfo)
	err := ioctlPtr(fd, HIDIOCGRAWINFO, unsafe.Pointer(info))
	return info, err
}
// IoctlHIDGetRawName issues the _HIDIOCGRAWNAME ioctl on fd with a buffer of
// _HIDIOCGRAWNAME_LEN bytes and returns the buffer contents converted with
// ByteSliceToString, together with the ioctl error.
func IoctlHIDGetRawName(fd int) (string, error) {
	var raw [_HIDIOCGRAWNAME_LEN]byte
	err := ioctlPtr(fd, _HIDIOCGRAWNAME, unsafe.Pointer(&raw[0]))
	name := ByteSliceToString(raw[:])
	return name, err
}
// IoctlHIDGetRawPhys issues the _HIDIOCGRAWPHYS ioctl on fd with a buffer of
// _HIDIOCGRAWPHYS_LEN bytes and returns the buffer contents converted with
// ByteSliceToString, together with the ioctl error.
func IoctlHIDGetRawPhys(fd int) (string, error) {
	var raw [_HIDIOCGRAWPHYS_LEN]byte
	err := ioctlPtr(fd, _HIDIOCGRAWPHYS, unsafe.Pointer(&raw[0]))
	phys := ByteSliceToString(raw[:])
	return phys, err
}
// IoctlHIDGetRawUniq issues the _HIDIOCGRAWUNIQ ioctl on fd with a buffer of
// _HIDIOCGRAWUNIQ_LEN bytes and returns the buffer contents converted with
// ByteSliceToString, together with the ioctl error.
func IoctlHIDGetRawUniq(fd int) (string, error) {
	var raw [_HIDIOCGRAWUNIQ_LEN]byte
	err := ioctlPtr(fd, _HIDIOCGRAWUNIQ, unsafe.Pointer(&raw[0]))
	uniq := ByteSliceToString(raw[:])
	return uniq, err
}
// IoctlIfreq performs an ioctl using an Ifreq structure for input and/or
// output. See the netdevice(7) man page for details.
func IoctlIfreq(fd int, req uint, value *Ifreq) error {
	// It is possible we will add more fields to *Ifreq itself later to prevent
	// misuse, so pass the raw *ifreq directly.
	raw := unsafe.Pointer(&value.raw)
	return ioctlPtr(fd, req, raw)
}
// TODO(mdlayher): export if and when IfreqData is exported.
// ioctlIfreqData performs an ioctl using an ifreqData structure for input
// and/or output. See the netdevice(7) man page for details.
func ioctlIfreqData(fd int, req uint, value *ifreqData) error {
	// The memory layout of IfreqData (type-safe) and ifreq (not type-safe) are
	// identical so pass *IfreqData directly.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlKCMClone attaches a new file descriptor to a multiplexor by cloning an
// existing KCM socket, returning a structure containing the file descriptor of
// the new socket. It returns nil and the error when the SIOCKCMCLONE ioctl
// fails.
func IoctlKCMClone(fd int) (*KCMClone, error) {
	info := new(KCMClone)
	err := ioctlPtr(fd, SIOCKCMCLONE, unsafe.Pointer(info))
	if err != nil {
		return nil, err
	}
	return info, nil
}
// IoctlKCMAttach attaches a TCP socket and associated BPF program file
// descriptor to a multiplexor. info is taken by value; a pointer to that copy
// is handed to the SIOCKCMATTACH ioctl.
func IoctlKCMAttach(fd int, info KCMAttach) error {
	arg := unsafe.Pointer(&info)
	return ioctlPtr(fd, SIOCKCMATTACH, arg)
}
// IoctlKCMUnattach unattaches a TCP socket file descriptor from a multiplexor.
// info is taken by value; a pointer to that copy is handed to the
// SIOCKCMUNATTACH ioctl.
func IoctlKCMUnattach(fd int, info KCMUnattach) error {
	arg := unsafe.Pointer(&info)
	return ioctlPtr(fd, SIOCKCMUNATTACH, arg)
}
// IoctlLoopGetStatus64 gets the status of the loop device associated with the
// file descriptor fd using the LOOP_GET_STATUS64 operation. It returns nil and
// the error when the ioctl fails.
func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
	status := new(LoopInfo64)
	err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(status))
	if err != nil {
		return nil, err
	}
	return status, nil
}
// IoctlLoopSetStatus64 sets the status of the loop device associated with the
// file descriptor fd using the LOOP_SET_STATUS64 operation.
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, LOOP_SET_STATUS64, arg)
}
// IoctlLoopConfigure configures all loop device parameters in a single step
// using the LOOP_CONFIGURE operation on fd.
func IoctlLoopConfigure(fd int, value *LoopConfig) error {
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, LOOP_CONFIGURE, arg)
}
================================================
FILE: gossh/sys/unix/ioctl_signed.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || solaris
package unix
import (
"unsafe"
)
// ioctl itself should not be exposed directly, but additional get/set
// functions for specific types are permissible.
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number. The integer is passed directly
// as the ioctl argument, not through a pointer.
func IoctlSetInt(fd int, req int, value int) error {
	arg := uintptr(value)
	return ioctl(fd, req, arg)
}
// IoctlSetPointerInt performs an ioctl operation which sets an
// integer value on fd, using the specified request number. The ioctl
// argument is called with a pointer to the integer value, rather than
// passing the integer value directly. The value is converted to int32
// before its address is taken.
func IoctlSetPointerInt(fd int, req int, value int) error {
	narrowed := int32(value)
	arg := unsafe.Pointer(&narrowed)
	return ioctlPtr(fd, req, arg)
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req int, value *Winsize) error {
	// TODO: if we get the chance, remove the req parameter and
	// hardcode TIOCSWINSZ.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
//
// The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req int, value *Termios) error {
	// TODO: if we get the chance, remove the req parameter.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlGetInt performs an ioctl operation which gets an integer value
// from fd, using the specified request number. The ioctl receives a pointer
// to an int, which is returned together with the ioctl error.
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req int) (int, error) {
	var out int
	arg := unsafe.Pointer(&out)
	err := ioctlPtr(fd, req, arg)
	return out, err
}
// IoctlGetWinsize performs the ioctl operation req on fd with a *Winsize
// argument and returns that Winsize. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
	ws := new(Winsize)
	err := ioctlPtr(fd, req, unsafe.Pointer(ws))
	return ws, err
}
// IoctlGetTermios performs the ioctl operation req on fd with a *Termios
// argument and returns that Termios. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetTermios(fd int, req int) (*Termios, error) {
	termios := new(Termios)
	err := ioctlPtr(fd, req, unsafe.Pointer(termios))
	return termios, err
}
================================================
FILE: gossh/sys/unix/ioctl_unsigned.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd
package unix
import (
"unsafe"
)
// ioctl itself should not be exposed directly, but additional get/set
// functions for specific types are permissible.
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number. The integer is passed directly
// as the ioctl argument, not through a pointer.
func IoctlSetInt(fd int, req uint, value int) error {
	arg := uintptr(value)
	return ioctl(fd, req, arg)
}
// IoctlSetPointerInt performs an ioctl operation which sets an
// integer value on fd, using the specified request number. The ioctl
// argument is called with a pointer to the integer value, rather than
// passing the integer value directly. The value is converted to int32
// before its address is taken.
func IoctlSetPointerInt(fd int, req uint, value int) error {
	narrowed := int32(value)
	arg := unsafe.Pointer(&narrowed)
	return ioctlPtr(fd, req, arg)
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req uint, value *Winsize) error {
	// TODO: if we get the chance, remove the req parameter and
	// hardcode TIOCSWINSZ.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlSetTermios performs an ioctl on fd with a *Termios.
//
// The req value will usually be TCSETA or TIOCSETA.
func IoctlSetTermios(fd int, req uint, value *Termios) error {
	// TODO: if we get the chance, remove the req parameter.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlGetInt performs an ioctl operation which gets an integer value
// from fd, using the specified request number. The ioctl receives a pointer
// to an int, which is returned together with the ioctl error.
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req uint) (int, error) {
	var out int
	arg := unsafe.Pointer(&out)
	err := ioctlPtr(fd, req, arg)
	return out, err
}
// IoctlGetWinsize performs the ioctl operation req on fd with a *Winsize
// argument and returns that Winsize. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetWinsize(fd int, req uint) (*Winsize, error) {
	ws := new(Winsize)
	err := ioctlPtr(fd, req, unsafe.Pointer(ws))
	return ws, err
}
// IoctlGetTermios performs the ioctl operation req on fd with a *Termios
// argument and returns that Termios. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetTermios(fd int, req uint) (*Termios, error) {
	termios := new(Termios)
	err := ioctlPtr(fd, req, unsafe.Pointer(termios))
	return termios, err
}
================================================
FILE: gossh/sys/unix/ioctl_zos.go
================================================
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build zos && s390x
package unix
import (
"runtime"
"unsafe"
)
// ioctl itself should not be exposed directly, but additional get/set
// functions for specific types are permissible.
// IoctlSetInt performs an ioctl operation which sets an integer value
// on fd, using the specified request number. The integer is passed directly
// as the ioctl argument, not through a pointer.
func IoctlSetInt(fd int, req int, value int) error {
	arg := uintptr(value)
	return ioctl(fd, req, arg)
}
// IoctlSetWinsize performs an ioctl on fd with a *Winsize argument.
//
// To change fd's window size, the req argument should be TIOCSWINSZ.
func IoctlSetWinsize(fd int, req int, value *Winsize) error {
	// TODO: if we get the chance, remove the req parameter and
	// hardcode TIOCSWINSZ.
	arg := unsafe.Pointer(value)
	return ioctlPtr(fd, req, arg)
}
// IoctlSetTermios applies value to fd through Tcsetattr rather than a raw
// ioctl.
//
// The req value is expected to be TCSETS, TCSETSW, or TCSETSF; any other
// request is rejected with ENOSYS before Tcsetattr is called.
func IoctlSetTermios(fd int, req int, value *Termios) error {
	supported := req == TCSETS || req == TCSETSW || req == TCSETSF
	if !supported {
		return ENOSYS
	}
	err := Tcsetattr(fd, req, value)
	// Keep value reachable until Tcsetattr has returned.
	runtime.KeepAlive(value)
	return err
}
// IoctlGetInt performs an ioctl operation which gets an integer value
// from fd, using the specified request number. The ioctl receives a pointer
// to an int, which is returned together with the ioctl error.
//
// A few ioctl requests use the return value as an output parameter;
// for those, IoctlRetInt should be used instead of this function.
func IoctlGetInt(fd int, req int) (int, error) {
	var out int
	arg := unsafe.Pointer(&out)
	err := ioctlPtr(fd, req, arg)
	return out, err
}
// IoctlGetWinsize performs the ioctl operation req on fd with a *Winsize
// argument and returns that Winsize. The returned pointer is non-nil even when
// the ioctl fails.
func IoctlGetWinsize(fd int, req int) (*Winsize, error) {
	ws := new(Winsize)
	err := ioctlPtr(fd, req, unsafe.Pointer(ws))
	return ws, err
}
// IoctlGetTermios reads the terminal attributes of fd through Tcgetattr rather
// than a raw ioctl.
//
// The req value is expected to be TCGETS; any other request yields ENOSYS
// together with a pointer to a zero Termios.
func IoctlGetTermios(fd int, req int) (*Termios, error) {
	termios := new(Termios)
	if req == TCGETS {
		err := Tcgetattr(fd, termios)
		return termios, err
	}
	return termios, ENOSYS
}
================================================
FILE: gossh/sys/unix/linux/Dockerfile
================================================
FROM ubuntu:20.04

# Disable interactive prompts on package installation
ENV DEBIAN_FRONTEND=noninteractive

# Dependencies to get the git sources and go binaries
RUN apt-get update && apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        git \
        rsync \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Get the git sources. If not cached, this takes O(5 minutes).
WORKDIR /git
RUN git config --global advice.detachedHead false
# Linux Kernel: Released 07 January 2024
RUN git clone --branch v6.7 --depth 1 https://kernel.googlesource.com/pub/scm/linux/kernel/git/torvalds/linux
# GNU C library: Released 1 Feb 2023
RUN git clone --branch release/2.37/master --depth 1 https://sourceware.org/git/glibc.git

# Get Go. The download is verified against the pinned SHA-256 before unpacking.
# Each variable has its own ENV instruction because later ones expand earlier ones.
ENV GOLANG_VERSION=1.21.0
ENV GOLANG_DOWNLOAD_URL=https://golang.org/dl/go$GOLANG_VERSION.linux-amd64.tar.gz
ENV GOLANG_DOWNLOAD_SHA256=d0398903a16ba2232b389fb31032ddf57cac34efda306a0eebac34f0965a0742

RUN curl -fsSL "$GOLANG_DOWNLOAD_URL" -o golang.tar.gz \
    && echo "$GOLANG_DOWNLOAD_SHA256 golang.tar.gz" | sha256sum -c - \
    && tar -C /usr/local -xzf golang.tar.gz \
    && rm golang.tar.gz

ENV PATH=/usr/local/go/bin:$PATH

# Linux and Glibc build dependencies and emulator
RUN apt-get update && apt-get install -y --no-install-recommends \
        bison gawk make python3 \
        gcc gcc-multilib \
        gettext texinfo \
        qemu-user \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
# Cross compilers (install recommended packages to get cross libc-dev)
RUN apt-get update && apt-get install -y \
        gcc-aarch64-linux-gnu gcc-arm-linux-gnueabi \
        gcc-mips-linux-gnu gcc-mips64-linux-gnuabi64 \
        gcc-mips64el-linux-gnuabi64 gcc-mipsel-linux-gnu \
        gcc-powerpc-linux-gnu gcc-powerpc64-linux-gnu \
        gcc-powerpc64le-linux-gnu gcc-riscv64-linux-gnu \
        gcc-s390x-linux-gnu gcc-sparc64-linux-gnu \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Only for loong64, getting tools of qemu-user and gcc-cross-compiler
ENV LOONG64_BASE_URL=https://github.com/loongson/build-tools/releases/download/2023.08.08
ENV LOONG64_GCC=CLFS-loongarch64-8.1-x86_64-cross-tools-gcc-glibc.tar.xz
ENV LOONG64_QEMU=qemu-loongarch64
ENV LOONG64_GCC_DOWNLOAD_URL=$LOONG64_BASE_URL/$LOONG64_GCC
ENV LOONG64_QEMU_DOWNLOAD_URL=$LOONG64_BASE_URL/$LOONG64_QEMU

# xz-utils is needed to unpack the cross toolchain; the apt lists are removed
# again in the same layer, as in the other install steps above.
RUN apt-get update && apt-get install -y --no-install-recommends xz-utils \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    && mkdir /loong64 && cd /loong64 \
    && curl -fsSL "$LOONG64_QEMU_DOWNLOAD_URL" -o /usr/bin/"$LOONG64_QEMU" \
    && chmod +x /usr/bin/"$LOONG64_QEMU" \
    && curl -fsSL "$LOONG64_GCC_DOWNLOAD_URL" -o "$LOONG64_GCC" \
    && tar xf "$LOONG64_GCC" -C /usr/local/ \
    && ln -s /usr/local/cross-tools/bin/loongarch64-unknown-linux-gnu-gcc /usr/bin/loongarch64-linux-gnu-gcc \
    && rm -rf /loong64

# Let the scripts know they are in the docker environment
ENV GOLANG_SYS_BUILD=docker
WORKDIR /build/unix
ENTRYPOINT ["go", "run", "linux/mkall.go", "/git/linux", "/git/glibc"]
================================================
FILE: gossh/sys/unix/linux/mkall.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// linux/mkall.go - Generates all Linux zsysnum, zsyscall, zerror, and ztype
// files for all Linux architectures supported by the go compiler. See
// README.md for more information about the build system.
// To run it you must have a git checkout of the Linux kernel and glibc. Once
// the appropriate sources are ready, the program is run as:
//	go run linux/mkall.go <path to Linux kernel source> <path to glibc source>
//go:build ignore
package main
import (
"bufio"
"bytes"
"debug/elf"
"encoding/binary"
"errors"
"fmt"
"go/build/constraint"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"unicode"
)
// Paths to the Linux kernel and glibc source directories. Both are empty until
// they are assigned at run time.
var (
	LinuxDir string
	GlibcDir string
)

// Fixed parameters of the generator.
const (
	TempDir   = "/tmp"   // Root under which per-architecture work directories are created
	GOOS      = "linux"  // Only for Linux targets
	BuildArch = "amd64"  // Must be built on this architecture
	MinKernel = "2.6.23" // https://golang.org/doc/install#requirements
)
// target describes one Linux architecture for which files are generated,
// together with the per-architecture state used while running external
// commands for it.
type target struct {
GoArch string // Architecture name according to Go
LinuxArch string // Architecture name according to the Linux Kernel
GNUArch string // Architecture name according to GNU tools (https://wiki.debian.org/Multiarch/Tuples)
BigEndian bool // Default Little Endian
SignedChar bool // Is -fsigned-char needed (default no)
// Bits is the word size of the architecture; every entry in targets uses 32 or 64.
Bits int
// env is the environment handed to commands created by makeCommand; it is
// filled in by setupEnvironment.
env []string
// stderrBuf collects the stderr of commands created by makeCommand until
// printAndResetBuilder prints and clears it.
stderrBuf bytes.Buffer
// compiler is the C compiler name chosen by setupEnvironment.
compiler string
}
// targets is the list of all Linux targets supported by the go compiler.
// Currently, sparc64 is not fully supported, but there is enough support
// already to generate Go type and error definitions.
//
// Fields omitted from an entry keep their zero value, i.e. little-endian and
// no -fsigned-char.
var targets = []target{
{
GoArch: "386",
LinuxArch: "x86",
GNUArch: "i686-linux-gnu", // Note "i686" not "i386"
Bits: 32,
},
{
GoArch: "amd64",
LinuxArch: "x86",
GNUArch: "x86_64-linux-gnu",
Bits: 64,
},
{
GoArch: "arm64",
LinuxArch: "arm64",
GNUArch: "aarch64-linux-gnu",
SignedChar: true,
Bits: 64,
},
{
GoArch: "arm",
LinuxArch: "arm",
GNUArch: "arm-linux-gnueabi",
Bits: 32,
},
{
GoArch: "loong64",
LinuxArch: "loongarch",
GNUArch: "loongarch64-linux-gnu",
Bits: 64,
},
{
GoArch: "mips",
LinuxArch: "mips",
GNUArch: "mips-linux-gnu",
BigEndian: true,
Bits: 32,
},
{
GoArch: "mipsle",
LinuxArch: "mips",
GNUArch: "mipsel-linux-gnu",
Bits: 32,
},
{
GoArch: "mips64",
LinuxArch: "mips",
GNUArch: "mips64-linux-gnuabi64",
BigEndian: true,
Bits: 64,
},
{
GoArch: "mips64le",
LinuxArch: "mips",
GNUArch: "mips64el-linux-gnuabi64",
Bits: 64,
},
{
GoArch: "ppc",
LinuxArch: "powerpc",
GNUArch: "powerpc-linux-gnu",
BigEndian: true,
Bits: 32,
},
{
GoArch: "ppc64",
LinuxArch: "powerpc",
GNUArch: "powerpc64-linux-gnu",
BigEndian: true,
Bits: 64,
},
{
GoArch: "ppc64le",
LinuxArch: "powerpc",
GNUArch: "powerpc64le-linux-gnu",
Bits: 64,
},
{
GoArch: "riscv64",
LinuxArch: "riscv",
GNUArch: "riscv64-linux-gnu",
Bits: 64,
},
{
GoArch: "s390x",
LinuxArch: "s390",
GNUArch: "s390x-linux-gnu",
BigEndian: true,
SignedChar: true,
Bits: 64,
},
{
GoArch: "sparc64",
LinuxArch: "sparc",
GNUArch: "sparc64-linux-gnu",
BigEndian: true,
Bits: 64,
},
}
// ptracePairs lists pairs of targets that can, in some cases, run each
// other's binaries. archName is the combined name of a1 and a2 and is used in
// the generated file name. An 'x' suffix in a file name usually means the file
// works for both big-endian and little-endian; the 'nn' used here means the
// file is suitable for both 32-bit and 64-bit.
var ptracePairs = []struct{ a1, a2, archName string }{
	{a1: "386", a2: "amd64", archName: "x86"},
	{a1: "arm", a2: "arm64", archName: "armnn"},
	{a1: "mips", a2: "mips64", archName: "mipsnn"},
	{a1: "mipsle", a2: "mips64le", archName: "mipsnnle"},
}
// main drives the whole generation run:
//
//  1. It refuses to run unless the host is GOOS/BuildArch, the
//     GOLANG_SYS_BUILD environment variable is "docker", and exactly two
//     arguments (Linux and glibc source directories) are given.
//  2. For every target it sets up the environment and builds the headers
//     serially, then generates the z* files in a goroutine per target.
//  3. After all goroutines finish it merges the generated files and generates
//     the ptrace helpers.
//
// Failures are reported on stdout; the first target whose setup or header
// generation fails stops the target loop, but the later phases still run.
func main() {
	if runtime.GOOS != GOOS || runtime.GOARCH != BuildArch {
		fmt.Printf("Build system has GOOS_GOARCH = %s_%s, need %s_%s\n",
			runtime.GOOS, runtime.GOARCH, GOOS, BuildArch)
		return
	}

	// Check that we are using the new build system if we should
	if os.Getenv("GOLANG_SYS_BUILD") != "docker" {
		fmt.Println("In the new build system, mkall.go should not be called directly.")
		fmt.Println("See README.md")
		return
	}

	// Parse the command line options
	if len(os.Args) != 3 {
		// Name both required arguments so the message is actionable.
		fmt.Println("USAGE: go run linux/mkall.go <path to Linux> <path to glibc>")
		return
	}
	LinuxDir = os.Args[1]
	GlibcDir = os.Args[2]

	var wg sync.WaitGroup
	for _, t := range targets {
		fmt.Printf("arch %s: GENERATING\n", t.GoArch)
		if err := t.setupEnvironment(); err != nil {
			fmt.Printf("arch %s: could not setup environment: %v\n", t.GoArch, err)
			break
		}
		includeDir := filepath.Join(TempDir, t.GoArch, "include")
		// Make the include directory and fill it with headers
		if err := os.MkdirAll(includeDir, os.ModePerm); err != nil {
			fmt.Printf("arch %s: could not make directory: %v\n", t.GoArch, err)
			break
		}
		// During header generation "/git/linux/scripts/basic/fixdep" is created by "basic/Makefile" for each
		// instance of "make headers_install". This leads to a "text file is busy" error from any running
		// "make headers_install" after the first one's target. Workaround is to serialize header generation
		if err := t.makeHeaders(); err != nil {
			fmt.Printf("arch %s: could not make header files: %v\n", t.GoArch, err)
			break
		}
		wg.Add(1)
		// The target is passed by value so every goroutine works on its own copy.
		go func(t target) {
			defer wg.Done()
			fmt.Printf("arch %s: header files generated\n", t.GoArch)
			if err := t.generateFiles(); err != nil {
				fmt.Printf("%v\n***** FAILURE: %s *****\n\n", err, t.GoArch)
			} else {
				fmt.Printf("arch %s: SUCCESS\n", t.GoArch)
			}
		}(t)
	}
	wg.Wait()

	fmt.Printf("----- GENERATING: merging generated files -----\n")
	if err := mergeFiles(); err != nil {
		fmt.Printf("%v\n***** FAILURE: merging generated files *****\n\n", err)
	} else {
		fmt.Printf("----- SUCCESS: merging generated files -----\n\n")
	}

	fmt.Printf("----- GENERATING ptrace pairs -----\n")
	ok := true
	for _, p := range ptracePairs {
		if err := generatePtracePair(p.a1, p.a2, p.archName); err != nil {
			fmt.Printf("%v\n***** FAILURE: %s/%s *****\n\n", err, p.a1, p.a2)
			ok = false
		}
	}
	// generate functions PtraceGetRegSetArm64 and PtraceSetRegSetArm64.
	if err := generatePtraceRegSet("arm64"); err != nil {
		fmt.Printf("%v\n***** FAILURE: generatePtraceRegSet(%q) *****\n\n", err, "arm64")
		ok = false
	}
	if ok {
		fmt.Printf("----- SUCCESS ptrace pairs -----\n\n")
	}
}
// printAndResetBuilder prints everything collected in t.stderrBuf, one line at
// a time and prefixed with the architecture name, then empties the buffer. It
// does nothing when the buffer is empty.
func (t *target) printAndResetBuilder() {
	if t.stderrBuf.Len() == 0 {
		return
	}
	lines := bytes.Split(t.stderrBuf.Bytes(), []byte{'\n'})
	for _, line := range lines {
		fmt.Printf("arch %s: stderr: %s\n", t.GoArch, line)
	}
	t.stderrBuf.Reset()
}
// makeCommand builds an exec.Cmd for name and args that runs with the target
// environment t.env and writes its stderr into t.stderrBuf.
func (t *target) makeCommand(name string, args ...string) *exec.Cmd {
	command := exec.Command(name, args...)
	command.Env, command.Stderr = t.env, &t.stderrBuf
	return command
}
// setTargetBuildArch gives cmd a copy of the target environment in which
// GOARCH is replaced by the host architecture, so the command can run
// natively, and GOARCH_TARGET is appended so the command still knows which
// architecture it is generating for. t.env itself is not modified.
func (t *target) setTargetBuildArch(cmd *exec.Cmd) {
	env := make([]string, 0, len(t.env)+1)
	for _, kv := range t.env {
		// Set GOARCH to host arch for command, so it can run natively.
		if strings.HasPrefix(kv, "GOARCH=") {
			kv = "GOARCH=" + BuildArch
		}
		env = append(env, kv)
	}
	// Set GOARCH_TARGET so command knows what GOARCH is..
	cmd.Env = append(env, "GOARCH_TARGET="+t.GoArch)
}
// commandFormatOutput runs the command, pipes its output to a formatter, and
// pipes that to an output file: mainCmd | fmtCmd > outputFile.
//
// The names "mksyscall" and "mksysnum" are run as "go run" of the matching
// generator source with GOARCH switched to the host architecture. The
// formatter "mkpost" is likewise run through "go run mkpost.go"; "gofmt2"
// means plain gofmt with the main command executed inside a per-architecture
// "mkerrors" directory under TempDir, which is created if needed.
//
// The returned error is the main command's error if it failed, otherwise the
// formatter's, otherwise the error from closing the output file. Collected
// stderr is printed before returning.
func (t *target) commandFormatOutput(formatter string, outputFile string,
	name string, args ...string) (err error) {
	var mainCmd *exec.Cmd
	switch name {
	case "mksyscall":
		mainCmd = t.makeCommand("go", append([]string{"run", "mksyscall.go"}, args...)...)
		t.setTargetBuildArch(mainCmd)
	case "mksysnum":
		mainCmd = t.makeCommand("go", append([]string{"run", "linux/mksysnum.go"}, args...)...)
		t.setTargetBuildArch(mainCmd)
	default:
		mainCmd = t.makeCommand(name, args...)
	}

	var fmtCmd *exec.Cmd
	switch formatter {
	case "mkpost":
		fmtCmd = t.makeCommand("go", "run", "mkpost.go")
		t.setTargetBuildArch(fmtCmd)
	case "gofmt2":
		fmtCmd = t.makeCommand("gofmt")
		mainCmd.Dir = filepath.Join(TempDir, t.GoArch, "mkerrors")
		if err = os.MkdirAll(mainCmd.Dir, os.ModePerm); err != nil {
			return err
		}
	default:
		fmtCmd = t.makeCommand(formatter)
	}

	defer t.printAndResetBuilder()

	// mainCmd | fmtCmd > outputFile
	if fmtCmd.Stdin, err = mainCmd.StdoutPipe(); err != nil {
		return err
	}
	var out *os.File
	if out, err = os.Create(outputFile); err != nil {
		return err
	}
	// Registered before the Wait below, so the file is closed only after the
	// formatter has exited (deferred calls run in reverse order).
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()
	fmtCmd.Stdout = out

	// Make sure the formatter eventually closes
	if err = fmtCmd.Start(); err != nil {
		return err
	}
	defer func() {
		fmtErr := fmtCmd.Wait()
		if err == nil {
			err = fmtErr
		}
	}()

	return mainCmd.Run()
}
// setupEnvironment fills in t.env and t.compiler for the target.
//
// The environment starts from the process environment plus GOOS and GOARCH.
// x86 targets use the native "gcc". Every other target uses the cross compiler
// "<GNUArch>-gcc", which must be found in PATH (otherwise the lookup error is
// returned), and gets GORUN set to the matching "qemu-<arch>" emulator name.
func (t *target) setupEnvironment() error {
	// Setup environment variables
	t.env = append(os.Environ(), "GOOS="+GOOS, "GOARCH="+t.GoArch)

	// Get appropriate compiler and emulator (unless on x86)
	if t.LinuxArch == "x86" {
		t.compiler = "gcc"
		t.env = append(t.env, "CC=gcc")
		return nil
	}

	// Check/Setup cross compiler
	t.compiler = t.GNUArch + "-gcc"
	if _, err := exec.LookPath(t.compiler); err != nil {
		return err
	}
	t.env = append(t.env, "CC="+t.compiler)

	// Check/Setup emulator (usually first component of GNUArch). Cut yields
	// the whole string when GNUArch has no "-", where slicing at the index of
	// "-" would panic.
	qemuArchName, _, _ := strings.Cut(t.GNUArch, "-")
	if t.LinuxArch == "powerpc" {
		qemuArchName = t.GoArch
	}
	// Fake uname for QEMU to allow running on Host kernel version < 4.15
	if t.LinuxArch == "riscv" {
		t.env = append(t.env, "QEMU_UNAME=4.15")
	}
	t.env = append(t.env, "GORUN=qemu-"+qemuArchName)
	return nil
}
// generateFiles generates all the files for a Linux target: zsysnum, zsyscall,
// ztypes and zerrors, in that order. It stops at the first step that fails and
// returns that step's error wrapped with the name of the file being made, so
// callers can still inspect the cause with errors.Is / errors.As. Progress is
// printed after every successful step.
func (t *target) generateFiles() error {
	steps := []struct {
		name string
		run  func() error
	}{
		{"zsysnum", t.makeZSysnumFile},
		{"zsyscall", t.makeZSyscallFile},
		{"ztypes", t.makeZTypesFile},
		{"zerrors", t.makeZErrorsFile},
	}
	for _, step := range steps {
		if err := step.run(); err != nil {
			return fmt.Errorf("could not make %s file: %w", step.name, err)
		}
		fmt.Printf("arch %s: %s file generated\n", t.GoArch, step.name)
	}
	return nil
}
// makeHeaders creates the Linux, glibc and ABI (C compiler convention) headers
// in the target's include directory under TempDir.
//
// The kernel headers come from "make headers_install" run in LinuxDir. The
// glibc headers come from running glibc's configure script and
// "make install-headers" in a temporary build directory, which is removed
// again on return. An empty gnu/stubs.h is then created and finally the ABI
// headers are generated. Collected stderr is printed before returning.
func (t *target) makeHeaders() error {
	defer t.printAndResetBuilder()

	archRoot := filepath.Join(TempDir, t.GoArch)

	// Make the Linux headers we need for this architecture
	linuxMake := t.makeCommand("make", "headers_install", "ARCH="+t.LinuxArch, "INSTALL_HDR_PATH="+archRoot)
	linuxMake.Dir = LinuxDir
	if err := linuxMake.Run(); err != nil {
		return err
	}

	// A Temporary build directory for glibc
	buildDir := filepath.Join(archRoot, "build")
	if err := os.MkdirAll(buildDir, os.ModePerm); err != nil {
		return err
	}
	defer os.RemoveAll(buildDir)

	// The minimum version requirement of the Loongarch for the kernel in glibc
	// is 5.19, if --enable-kernel is less than 5.19, glibc handles errors
	minKernel := MinKernel
	if t.LinuxArch == "loongarch" {
		minKernel = "5.19.0"
	}

	// Make the glibc headers we need for this architecture
	glibcConf := t.makeCommand(filepath.Join(GlibcDir, "configure"),
		"--prefix="+archRoot,
		"--host="+t.GNUArch,
		"--enable-kernel="+minKernel,
	)
	glibcConf.Dir = buildDir
	if err := glibcConf.Run(); err != nil {
		return err
	}
	glibcMake := t.makeCommand("make", "install-headers")
	glibcMake.Dir = buildDir
	if err := glibcMake.Run(); err != nil {
		return err
	}

	// We only need an empty stubs file
	stubs, err := os.Create(filepath.Join(archRoot, "include", "gnu", "stubs.h"))
	if err != nil {
		return err
	}
	stubs.Close()

	// ABI headers will specify C compiler behavior for the target platform.
	return t.makeABIHeaders()
}
// makeABIHeaders writes include/abi/abi.h for this target, describing how the
// target's C compiler lays out bit fields.
//
// Formal ABIs exist for many platforms, but what the dominant C compilers emit
// is the de-facto convention, so the values are mined from a throwaway object
// file built with t.compiler. The result is a C header rather than a Go file so
// that it can be referenced from Cgo.
//
// The abi directory must not exist yet (os.Mkdir is used), and t.compiler must
// have been set. A failure to close abi.h is reported only if nothing else
// failed first.
func (t *target) makeABIHeaders() (err error) {
	archDir := filepath.Join(TempDir, t.GoArch)
	dir := filepath.Join(archDir, "include", "abi")
	if err = os.Mkdir(dir, os.ModePerm); err != nil {
		return err
	}
	if t.compiler == "" {
		return errors.New("CC (compiler) env var not set")
	}

	// Build a sacrificial ELF file, to mine for C compiler behavior.
	objPath := filepath.Join(archDir, "tmp_abi.o")
	obj, err := t.buildELF(t.compiler, cCode, objPath)
	if err != nil {
		return fmt.Errorf("cannot build ELF to analyze: %v", err)
	}
	defer obj.Close()
	defer os.Remove(objPath)

	// Right now, we put everything in abi.h, but we may change this later.
	header, err := os.Create(filepath.Join(dir, "abi.h"))
	if err != nil {
		return err
	}
	defer func() {
		closeErr := header.Close()
		if err == nil && closeErr != nil {
			err = closeErr
		}
	}()

	if writeErr := t.writeBitFieldMasks(obj, header); writeErr != nil {
		return fmt.Errorf("cannot write bitfield masks: %v", writeErr)
	}
	return nil
}
// buildELF compiles the C source text src with compiler cc into an unlinked
// object file at path and opens that file as ELF.
//
// The source is fed on stdin ("-x c -c -"). The object is deliberately not
// linked, so that .data section offsets can be found from the symbol values.
// The caller owns the returned file and must close it.
func (t *target) buildELF(cc, src, path string) (*elf.File, error) {
	compile := t.makeCommand(cc, "-o", path, "-gdwarf", "-x", "c", "-c", "-")
	compile.Stdin, compile.Stdout = strings.NewReader(src), os.Stdout
	defer t.printAndResetBuilder()

	if err := compile.Run(); err != nil {
		return nil, fmt.Errorf("compiler error: %v", err)
	}

	obj, err := elf.Open(path)
	if err != nil {
		return nil, fmt.Errorf("cannot read ELF file %s: %v", path, err)
	}
	return obj, nil
}
// writeBitFieldMasks emits one "#define BITFIELD_MASK_<i> <value>ULL" line to
// out for each of the 64 entries of the "masks" array found in bin's .data
// section.
//
// bin is expected to be the unlinked object built from cCode, so that the
// "masks" symbol value is an offset into .data. Each entry is decoded with the
// target's byte order (t.BigEndian), not the host's.
//
// An error is returned if the symbol table or .data section cannot be read,
// if "masks" is missing, if the 64 entries would run past the end of .data,
// or if writing to out fails.
func (t *target) writeBitFieldMasks(bin *elf.File, out io.Writer) error {
	const (
		maskCount = 64 // one mask per bit of a uint64
		maskSize  = 8  // each mask is stored as a uint64
	)

	symbols, err := bin.Symbols()
	if err != nil {
		return fmt.Errorf("getting ELF symbols: %v", err)
	}

	// Point at the slice element, not at the range variable: before Go 1.22
	// the range variable is reused across iterations, so its address would end
	// up referring to the last symbol instead of "masks".
	var masksSym *elf.Symbol
	for i := range symbols {
		if symbols[i].Name == "masks" {
			masksSym = &symbols[i]
		}
	}
	if masksSym == nil {
		return errors.New("could not find the 'masks' symbol in ELF symtab")
	}

	dataSection := bin.Section(".data")
	if dataSection == nil {
		return errors.New("ELF file has no .data section")
	}
	data, err := dataSection.Data()
	if err != nil {
		return fmt.Errorf("could not read .data section: %v", err)
	}

	// Reject a symbol whose array would extend past the section instead of
	// panicking on the slice expression below.
	start := masksSym.Value
	end := start + maskCount*maskSize
	if end < start || end > uint64(len(data)) {
		return fmt.Errorf("'masks' symbol at offset %d does not fit in .data section of %d bytes", start, len(data))
	}

	var bo binary.ByteOrder
	if t.BigEndian {
		bo = binary.BigEndian
	} else {
		bo = binary.LittleEndian
	}

	// 64 bit masks of type uint64 are stored in the data section starting at masks.Value.
	// Here we are running on AMD64, but these values may be big endian or little endian,
	// depending on target architecture.
	for i := uint64(0); i < maskCount; i++ {
		off := start + i*maskSize
		// Decode each mask using the target's byte order.
		if _, err := fmt.Fprintf(out, "#define BITFIELD_MASK_%d %dULL\n", i, bo.Uint64(data[off:off+maskSize])); err != nil {
			return fmt.Errorf("writing bitfield mask %d: %v", i, err)
		}
	}
	return nil
}
// makeZSysnumFile generates zsysnum_linux_$GOARCH.go by running mksysnum over
// the target's installed asm/unistd.h with the target's C flags and formatting
// the result with gofmt.
func (t *target) makeZSysnumFile() error {
	unistd := filepath.Join(TempDir, t.GoArch, "include", "asm", "unistd.h")
	output := fmt.Sprintf("zsysnum_linux_%s.go", t.GoArch)
	return t.commandFormatOutput("gofmt", output, "mksysnum", append(t.cFlags(), unistd)...)
}
// makeZSyscallFile generates zsyscall_linux_$GOARCH.go by running mksyscall
// over syscall_linux.go, the architecture's syscall file and any extra files
// whose build constraints match the target, then formatting with gofmt.
func (t *target) makeZSyscallFile() error {
	// Use syscall_linux_$GOARCH.go when it exists. Otherwise fall back to the
	// name with a trailing "le" removed and "x" appended.
	archSource := fmt.Sprintf("syscall_linux_%s.go", t.GoArch)
	if _, err := os.Stat(archSource); os.IsNotExist(err) {
		archSource = fmt.Sprintf("syscall_linux_%sx.go", strings.TrimSuffix(t.GoArch, "le"))
	}

	args := t.mksyscallFlags()
	args = append(args, "-tags", "linux,"+t.GoArch, "syscall_linux.go", archSource)

	extra, err := t.archMksyscallFiles()
	if err != nil {
		return fmt.Errorf("failed to check GOARCH-specific mksyscall files: %v", err)
	}
	args = append(args, extra...)

	output := fmt.Sprintf("zsyscall_linux_%s.go", t.GoArch)
	return t.commandFormatOutput("gofmt", output, "mksyscall", args...)
}
// archMksyscallFiles returns the extra mksyscall input files whose build
// constraints are satisfied by this target. The result is nil when none match.
func (t *target) archMksyscallFiles() ([]string, error) {
	// These files do not follow the usual GOOS/GOARCH file naming, so whether
	// they apply to a target is decided from the build constraints they declare.
	//
	// TODO(mdlayher): it should be possible to generalize this approach to work
	// over all of syscall_linux_* rather than hard-coding a few special files.
	// Investigate this.
	candidates := []string{
		// GOARCH: all except arm* and riscv.
		"syscall_linux_alarm.go",
	}

	var matched []string
	for _, name := range candidates {
		applies, err := t.matchesMksyscallFile(name)
		if err != nil {
			return nil, fmt.Errorf("failed to parse file %q: %v", name, err)
		}
		if !applies {
			continue
		}
		// Constraints match, use for this target's code generation.
		matched = append(matched, name)
	}
	return matched, nil
}
// matchesMksyscallFile reports whether the first build constraint line found
// in file is satisfied by this target, treating GOOS and t.GoArch as the only
// tags that are set.
//
// It returns an error if the file cannot be opened or read, or if no line of
// the file parses as a build constraint.
func (t *target) matchesMksyscallFile(file string) (bool, error) {
	f, err := os.Open(file)
	if err != nil {
		return false, err
	}
	defer f.Close()

	var (
		expr  constraint.Expr
		found bool
	)
	// Take the first line that parses as a constraint; that is enough for
	// single-line //go:build constraints.
	lines := bufio.NewScanner(f)
	for lines.Scan() {
		parsed, parseErr := constraint.Parse(lines.Text())
		if parseErr != nil {
			continue
		}
		expr, found = parsed, true
		break
	}
	if err := lines.Err(); err != nil {
		return false, err
	}
	if !found {
		return false, errors.New("no build constraints found")
	}

	// Do the defined constraints match target's GOOS/GOARCH?
	hasTag := func(tag string) bool {
		if tag == GOOS {
			return true
		}
		return tag == t.GoArch
	}
	return expr.Eval(hasTag), nil
}
// makeZErrorsFile generates zerrors_linux_$GOARCH.go by running the
// mkerrors.sh script from /build/unix with the target's C flags.
func (t *target) makeZErrorsFile() error {
	script := "/" + filepath.Join("build", "unix", "mkerrors.sh")
	output := fmt.Sprintf("zerrors_linux_%s.go", t.GoArch)
	return t.commandFormatOutput("gofmt2", output, script, t.cFlags()...)
}
// makeZTypesFile generates ztypes_linux_$GOARCH.go by running
// "go tool cgo -godefs" over linux/types.go with the target's C flags and
// post-processing the output with mkpost. cgo's object directory is
// TempDir/<GoArch>/cgo, created here if needed.
func (t *target) makeZTypesFile() error {
	objDir := filepath.Join(TempDir, t.GoArch, "cgo")
	if err := os.MkdirAll(objDir, os.ModePerm); err != nil {
		return err
	}

	cgoArgs := append([]string{"tool", "cgo", "-godefs", "-objdir=" + objDir, "--"}, t.cFlags()...)
	cgoArgs = append(cgoArgs, "linux/types.go")

	output := fmt.Sprintf("ztypes_linux_%s.go", t.GoArch)
	return t.commandFormatOutput("mkpost", output, "go", cgoArgs...)
}
// cFlags returns the flags to pass to gcc and cgo for this target: warnings
// as errors, static compilation, the target's include directory, plus
// -fsigned-char when t.SignedChar is set and -m<Bits> on x86.
func (t *target) cFlags() []string {
	includeDir := filepath.Join(TempDir, t.GoArch, "include")

	// Compile statically to avoid cross-architecture dynamic linking.
	flags := []string{"-Wall", "-Werror", "-static", "-I" + includeDir}

	// Architecture-specific flags.
	if t.SignedChar {
		flags = append(flags, "-fsigned-char")
	}
	if t.LinuxArch == "x86" {
		flags = append(flags, fmt.Sprintf("-m%d", t.Bits))
	}
	return flags
}
// mksyscallFlags returns the flags to pass to mksyscall for this target:
// -b32 or -l32 on 32-bit targets depending on endianness, and -arm for arm
// and 32-bit mips. The result is nil when no flag applies.
func (t *target) mksyscallFlags() (flags []string) {
	is32Bit := t.Bits == 32

	if is32Bit {
		wordFlag := "-l32"
		if t.BigEndian {
			wordFlag = "-b32"
		}
		flags = append(flags, wordFlag)
	}

	// This flag means a 64-bit value should use (even, odd)-pair.
	if t.GoArch == "arm" || (is32Bit && t.LinuxArch == "mips") {
		flags = append(flags, "-arm")
	}
	return flags
}
// mergeFiles merges the per-architecture generated files into one file per
// kind (zerrors, zsyscall, zsysnum, ztypes) by running ./internal/mkmerge.
//
// GOOS and GOARCH are reset in this process's environment to the host values
// before the merge commands run. The first failing merge aborts with a wrapped
// error; each successful one is announced on stdout.
func mergeFiles() error {
	// Setup environment variables.
	os.Setenv("GOOS", runtime.GOOS)
	os.Setenv("GOARCH", runtime.GOARCH)

	// Merge each of the four type of files.
	kinds := []string{"zerrors", "zsyscall", "zsysnum", "ztypes"}
	for _, kind := range kinds {
		merged := fmt.Sprintf("%s_%s.go", kind, GOOS)
		perArch := fmt.Sprintf("%s_%s_*.go", kind, GOOS)

		merge := exec.Command("go", "run", "./internal/mkmerge", "-out", merged, perArch)
		merge.Stderr = os.Stderr
		if err := merge.Run(); err != nil {
			return fmt.Errorf("could not merge %s files: %w", kind, err)
		}
		fmt.Printf("%s files merged\n", kind)
	}
	return nil
}
// generatePtracePair takes a pair of GOARCH values that can run each
// other's binaries, such as 386 and amd64. It extracts the PtraceRegs
// type for each one from the corresponding ztypes_linux_<arch>.go file and
// writes zptrace_<archName>_linux.go, defining the types PtraceRegs<Arch1>
// and PtraceRegs<Arch2> and the corresponding functions
// Ptrace{Get,Set}Regs<Arch1> and Ptrace{Get,Set}Regs<Arch2>. This permits
// debugging the other binary on a native system. archName is the combined
// name of arch1 and arch2 and is used only in the file name.
//
// It returns an error if either PtraceRegs definition cannot be extracted or
// if the output file cannot be created, written or closed. The output file is
// closed on every path.
func generatePtracePair(arch1, arch2, archName string) error {
	def1, err := ptraceDef(arch1)
	if err != nil {
		return err
	}
	def2, err := ptraceDef(arch2)
	if err != nil {
		return err
	}
	f, err := os.Create(fmt.Sprintf("zptrace_%s_linux.go", archName))
	if err != nil {
		return err
	}
	// Write errors are sticky in bufio.Writer and surface from Flush below.
	buf := bufio.NewWriter(f)
	fmt.Fprintf(buf, "// Code generated by linux/mkall.go generatePtracePair(%q, %q). DO NOT EDIT.\n", arch1, arch2)
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "//go:build linux && (%s || %s)\n", arch1, arch2)
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "package unix\n")
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "%s\n", `import "unsafe"`)
	fmt.Fprintf(buf, "\n")
	writeOnePtrace(buf, arch1, def1)
	fmt.Fprintf(buf, "\n")
	writeOnePtrace(buf, arch2, def2)
	if err := buf.Flush(); err != nil {
		// Do not leak the descriptor; the flush error is the one to report.
		f.Close()
		return err
	}
	return f.Close()
}
// generatePtraceRegSet takes a GOARCH value and writes zptrace_linux_<arch>.go
// in the current directory, containing the functions PtraceGetRegSet<Arch> and
// PtraceSetRegSet<Arch>, where <Arch> is arch with its first byte upper-cased.
//
// It returns an error if arch is empty (before any file is created) or if the
// output file cannot be created, written or closed. The output file is closed
// on every path.
func generatePtraceRegSet(arch string) error {
	// arch[0] below would panic on an empty name, after the output file had
	// already been created; reject it up front instead.
	if arch == "" {
		return fmt.Errorf("generatePtraceRegSet: empty architecture name")
	}
	f, err := os.Create(fmt.Sprintf("zptrace_linux_%s.go", arch))
	if err != nil {
		return err
	}
	// Write errors are sticky in bufio.Writer and surface from Flush below.
	buf := bufio.NewWriter(f)
	fmt.Fprintf(buf, "// Code generated by linux/mkall.go generatePtraceRegSet(%q). DO NOT EDIT.\n", arch)
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "package unix\n")
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "%s\n", `import "unsafe"`)
	fmt.Fprintf(buf, "\n")
	uarch := string(unicode.ToUpper(rune(arch[0]))) + arch[1:]
	fmt.Fprintf(buf, "// PtraceGetRegSet%s fetches the registers used by %s binaries.\n", uarch, arch)
	fmt.Fprintf(buf, "func PtraceGetRegSet%s(pid, addr int, regsout *PtraceRegs%s) error {\n", uarch, uarch)
	fmt.Fprintf(buf, "\tiovec := Iovec{(*byte)(unsafe.Pointer(regsout)), uint64(unsafe.Sizeof(*regsout))}\n")
	fmt.Fprintf(buf, "\treturn ptracePtr(PTRACE_GETREGSET, pid, uintptr(addr), unsafe.Pointer(&iovec))\n")
	fmt.Fprintf(buf, "}\n")
	fmt.Fprintf(buf, "\n")
	fmt.Fprintf(buf, "// PtraceSetRegSet%s sets the registers used by %s binaries.\n", uarch, arch)
	fmt.Fprintf(buf, "func PtraceSetRegSet%s(pid, addr int, regs *PtraceRegs%s) error {\n", uarch, uarch)
	fmt.Fprintf(buf, "\tiovec := Iovec{(*byte)(unsafe.Pointer(regs)), uint64(unsafe.Sizeof(*regs))}\n")
	fmt.Fprintf(buf, "\treturn ptracePtr(PTRACE_SETREGSET, pid, uintptr(addr), unsafe.Pointer(&iovec))\n")
	fmt.Fprintf(buf, "}\n")
	if err := buf.Flush(); err != nil {
		// Do not leak the descriptor; the flush error is the one to report.
		f.Close()
		return err
	}
	return f.Close()
}
// ptraceDef returns the source text of the PtraceRegs struct definition found
// in ztypes_linux_<arch>.go in the current directory, from "type PtraceRegs
// struct" up to and including the closing brace (no trailing newline).
//
// It returns an error if the file cannot be read, if it has no PtraceRegs
// definition, or if no "\n}\n" terminator follows the definition's start.
func ptraceDef(arch string) (string, error) {
	const (
		opening = "type PtraceRegs struct"
		closing = "\n}\n"
	)

	filename := fmt.Sprintf("ztypes_linux_%s.go", arch)
	src, err := os.ReadFile(filename)
	if err != nil {
		return "", fmt.Errorf("reading %s: %v", filename, err)
	}

	begin := bytes.Index(src, []byte(opening))
	if begin < 0 {
		return "", fmt.Errorf("%s: no definition of PtraceRegs", filename)
	}
	def := src[begin:]

	stop := bytes.Index(def, []byte(closing))
	if stop < 0 {
		return "", fmt.Errorf("%s: can't find end of PtraceRegs definition", filename)
	}
	// Keep the newline and the closing brace, drop the newline after it.
	return string(def[:stop+len(closing)-1]), nil
}
// writeOnePtrace writes to w the Go source for one architecture's ptrace
// register support: the struct definition def with its first "PtraceRegs"
// renamed to PtraceRegs<Arch>, followed by the PtraceGetRegs<Arch> and
// PtraceSetRegs<Arch> functions. <Arch> is arch with its first byte
// upper-cased, so arch must not be empty. Write errors are not reported.
func writeOnePtrace(w io.Writer, arch, def string) {
	suffix := string(unicode.ToUpper(rune(arch[0]))) + arch[1:]
	typeName := "PtraceRegs" + suffix

	// emit issues one formatted write to w.
	emit := func(format string, args ...interface{}) {
		fmt.Fprintf(w, format, args...)
	}

	emit("// PtraceRegs%s is the registers used by %s binaries.\n", suffix, arch)
	emit("%s\n", strings.Replace(def, "PtraceRegs", typeName, 1))
	emit("\n")

	emit("// PtraceGetRegs%s fetches the registers used by %s binaries.\n", suffix, arch)
	emit("func PtraceGetRegs%s(pid int, regsout *PtraceRegs%s) error {\n", suffix, suffix)
	emit("\treturn ptracePtr(PTRACE_GETREGS, pid, 0, unsafe.Pointer(regsout))\n")
	emit("}\n")
	emit("\n")

	emit("// PtraceSetRegs%s sets the registers used by %s binaries.\n", suffix, arch)
	emit("func PtraceSetRegs%s(pid int, regs *PtraceRegs%s) error {\n", suffix, suffix)
	emit("\treturn ptracePtr(PTRACE_SETREGS, pid, 0, unsafe.Pointer(regs))\n")
	emit("}\n")
}
// cCode is compiled for the target architecture, and the resulting data section is carved for
// the statically initialized bit masks.
//
// It defines a 64-entry array named "masks"; entry i has only bit field
// u64_bit_i set, so reading entry i back as a uint64 yields the mask the
// target's C compiler assigns to the i-th declared bit field. <stdint.h> is
// required for uint64_t.
const cCode = `
// Bit fields are used in some system calls and other ABIs, but their memory layout is
// implementation-defined [1]. Even with formal ABIs, bit fields are a source of subtle bugs [2].
// Here we generate the offsets for all 64 bits in an uint64.
// 1: http://en.cppreference.com/w/c/language/bit_field
// 2: https://lwn.net/Articles/478657/
#include <stdint.h>
struct bitfield {
union {
uint64_t val;
struct {
uint64_t u64_bit_0 : 1;
uint64_t u64_bit_1 : 1;
uint64_t u64_bit_2 : 1;
uint64_t u64_bit_3 : 1;
uint64_t u64_bit_4 : 1;
uint64_t u64_bit_5 : 1;
uint64_t u64_bit_6 : 1;
uint64_t u64_bit_7 : 1;
uint64_t u64_bit_8 : 1;
uint64_t u64_bit_9 : 1;
uint64_t u64_bit_10 : 1;
uint64_t u64_bit_11 : 1;
uint64_t u64_bit_12 : 1;
uint64_t u64_bit_13 : 1;
uint64_t u64_bit_14 : 1;
uint64_t u64_bit_15 : 1;
uint64_t u64_bit_16 : 1;
uint64_t u64_bit_17 : 1;
uint64_t u64_bit_18 : 1;
uint64_t u64_bit_19 : 1;
uint64_t u64_bit_20 : 1;
uint64_t u64_bit_21 : 1;
uint64_t u64_bit_22 : 1;
uint64_t u64_bit_23 : 1;
uint64_t u64_bit_24 : 1;
uint64_t u64_bit_25 : 1;
uint64_t u64_bit_26 : 1;
uint64_t u64_bit_27 : 1;
uint64_t u64_bit_28 : 1;
uint64_t u64_bit_29 : 1;
uint64_t u64_bit_30 : 1;
uint64_t u64_bit_31 : 1;
uint64_t u64_bit_32 : 1;
uint64_t u64_bit_33 : 1;
uint64_t u64_bit_34 : 1;
uint64_t u64_bit_35 : 1;
uint64_t u64_bit_36 : 1;
uint64_t u64_bit_37 : 1;
uint64_t u64_bit_38 : 1;
uint64_t u64_bit_39 : 1;
uint64_t u64_bit_40 : 1;
uint64_t u64_bit_41 : 1;
uint64_t u64_bit_42 : 1;
uint64_t u64_bit_43 : 1;
uint64_t u64_bit_44 : 1;
uint64_t u64_bit_45 : 1;
uint64_t u64_bit_46 : 1;
uint64_t u64_bit_47 : 1;
uint64_t u64_bit_48 : 1;
uint64_t u64_bit_49 : 1;
uint64_t u64_bit_50 : 1;
uint64_t u64_bit_51 : 1;
uint64_t u64_bit_52 : 1;
uint64_t u64_bit_53 : 1;
uint64_t u64_bit_54 : 1;
uint64_t u64_bit_55 : 1;
uint64_t u64_bit_56 : 1;
uint64_t u64_bit_57 : 1;
uint64_t u64_bit_58 : 1;
uint64_t u64_bit_59 : 1;
uint64_t u64_bit_60 : 1;
uint64_t u64_bit_61 : 1;
uint64_t u64_bit_62 : 1;
uint64_t u64_bit_63 : 1;
};
};
};
struct bitfield masks[] = {
{.u64_bit_0 = 1},
{.u64_bit_1 = 1},
{.u64_bit_2 = 1},
{.u64_bit_3 = 1},
{.u64_bit_4 = 1},
{.u64_bit_5 = 1},
{.u64_bit_6 = 1},
{.u64_bit_7 = 1},
{.u64_bit_8 = 1},
{.u64_bit_9 = 1},
{.u64_bit_10 = 1},
{.u64_bit_11 = 1},
{.u64_bit_12 = 1},
{.u64_bit_13 = 1},
{.u64_bit_14 = 1},
{.u64_bit_15 = 1},
{.u64_bit_16 = 1},
{.u64_bit_17 = 1},
{.u64_bit_18 = 1},
{.u64_bit_19 = 1},
{.u64_bit_20 = 1},
{.u64_bit_21 = 1},
{.u64_bit_22 = 1},
{.u64_bit_23 = 1},
{.u64_bit_24 = 1},
{.u64_bit_25 = 1},
{.u64_bit_26 = 1},
{.u64_bit_27 = 1},
{.u64_bit_28 = 1},
{.u64_bit_29 = 1},
{.u64_bit_30 = 1},
{.u64_bit_31 = 1},
{.u64_bit_32 = 1},
{.u64_bit_33 = 1},
{.u64_bit_34 = 1},
{.u64_bit_35 = 1},
{.u64_bit_36 = 1},
{.u64_bit_37 = 1},
{.u64_bit_38 = 1},
{.u64_bit_39 = 1},
{.u64_bit_40 = 1},
{.u64_bit_41 = 1},
{.u64_bit_42 = 1},
{.u64_bit_43 = 1},
{.u64_bit_44 = 1},
{.u64_bit_45 = 1},
{.u64_bit_46 = 1},
{.u64_bit_47 = 1},
{.u64_bit_48 = 1},
{.u64_bit_49 = 1},
{.u64_bit_50 = 1},
{.u64_bit_51 = 1},
{.u64_bit_52 = 1},
{.u64_bit_53 = 1},
{.u64_bit_54 = 1},
{.u64_bit_55 = 1},
{.u64_bit_56 = 1},
{.u64_bit_57 = 1},
{.u64_bit_58 = 1},
{.u64_bit_59 = 1},
{.u64_bit_60 = 1},
{.u64_bit_61 = 1},
{.u64_bit_62 = 1},
{.u64_bit_63 = 1}
};
int main(int argc, char **argv) {
struct bitfield *mask_ptr = &masks[0];
return mask_ptr->val;
}
`
================================================
FILE: gossh/sys/unix/linux/mksysnum.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
package main
import (
"bufio"
"fmt"
"os"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
)
// goos and goarch hold the target operating system and architecture used in
// the generated build tag. main fills them from the GOOS environment variable
// and from GOARCH_TARGET (falling back to GOARCH).
var (
	goos, goarch string
)
// cmdLine returns the command line that reproduces this run: the fixed
// "go run linux/mksysnum.go " prefix followed by the program's arguments
// joined with single spaces.
func cmdLine() string {
	const prefix = "go run linux/mksysnum.go "
	args := os.Args[1:]
	return prefix + strings.Join(args, " ")
}
// goBuildTags returns build tags in the go:build format.
func goBuildTags() string {
return fmt.Sprintf("%s && %s", goarch, goos)
}
// format builds the constant declaration for one syscall.
//
// It returns num+offset together with a "SYS_<NAME> = <num+offset>;" line,
// where <NAME> is name upper-cased. When num is greater than 999 it returns
// 0 and an empty declaration instead, which callers treat as "skip".
func format(name string, num int, offset int) (int, string) {
	// ignore deprecated syscalls that are no longer implemented
	// https://git.kernel.org/cgit/linux/kernel/git/torvalds/linux.git/tree/include/uapi/asm-generic/unistd.h?id=refs/heads/master#n716
	if num > 999 {
		return 0, ""
	}
	shifted := num + offset
	return shifted, fmt.Sprintf(" SYS_%s = %d;\n", strings.ToUpper(name), shifted)
}
// checkErr does nothing for a nil error. For any other error it prints the
// error to stderr and terminates the program with exit status 1.
func checkErr(err error) {
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "%v\n", err)
	os.Exit(1)
}
// re holds a source string and the submatches of the most recent Match on it.
type re struct {
	str string   // source string
	sub []string // submatches of the last Match; nil if it did not match
}

// compiledExprs caches compiled patterns by their source text so that each
// distinct pattern passed to Match is compiled only once instead of once per
// call. It is not guarded by a lock; Match must not be called from several
// goroutines at once.
var compiledExprs = map[string]*regexp.Regexp{}

// Match reports whether the regular expression exp matches r.str and stores
// the submatches (whole match first, then capture groups) in r.sub. On a
// failed match r.sub is set to nil. It panics if exp is not a valid pattern.
func (r *re) Match(exp string) bool {
	compiled, ok := compiledExprs[exp]
	if !ok {
		compiled = regexp.MustCompile(exp)
		compiledExprs[exp] = compiled
	}
	r.sub = compiled.FindStringSubmatch(r.str)
	return r.sub != nil
}
// syscallNum pairs a syscall number with the declaration line that will be
// written to the generated file for it.
type syscallNum struct {
	num         int
	declaration string
}

// syscallNums is a list of syscallNum kept in ascending order of syscall number.
type syscallNums []syscallNum

// addSyscallNum inserts a declaration into nums, keeping nums ordered by
// syscall number. Among entries with the same number, the order in which they
// were added is preserved. An empty declaration is ignored.
func (nums *syscallNums) addSyscallNum(num int, declaration string) {
	if declaration == "" {
		return
	}

	list := *nums
	// The new entry goes right before the first entry with a strictly larger
	// number, i.e. after every existing entry whose number is <= num. When no
	// such entry exists the position is len(list) and this is a plain append.
	pos := sort.Search(len(list), func(i int) bool { return list[i].num > num })

	// Grow by one, shift the tail right and drop the entry into the gap.
	list = append(list, syscallNum{})
	copy(list[pos+1:], list[pos:])
	list[pos] = syscallNum{num: num, declaration: declaration}
	*nums = list
}
// main runs the C preprocessor named by $CC with "-E -dD" plus the program's
// own arguments, scans the output for __NR_* syscall number definitions, and
// prints a Go source file declaring the matching SYS_* constants to stdout.
//
// It exits with status 1 if GOOS or GOARCH (or GOARCH_TARGET) is unset, if
// GOLANG_SYS_BUILD is not "docker", if CC is unset, if the preprocessor cannot
// be run, or if a matched number fails to parse.
func main() {
	// Get the OS and architecture (using GOARCH_TARGET if it exists)
	goos = os.Getenv("GOOS")
	goarch = os.Getenv("GOARCH_TARGET")
	if goarch == "" {
		goarch = os.Getenv("GOARCH")
	}
	// Check if GOOS and GOARCH environment variables are defined
	if goarch == "" || goos == "" {
		fmt.Fprintf(os.Stderr, "GOARCH or GOOS not defined in environment\n")
		os.Exit(1)
	}
	// Check that we are using the new build system if we should
	if os.Getenv("GOLANG_SYS_BUILD") != "docker" {
		fmt.Fprintf(os.Stderr, "In the new build system, mksysnum should not be called directly.\n")
		fmt.Fprintf(os.Stderr, "See README.md\n")
		os.Exit(1)
	}
	cc := os.Getenv("CC")
	if cc == "" {
		fmt.Fprintf(os.Stderr, "CC is not defined in environment\n")
		os.Exit(1)
	}
	// Preprocess only (-E) and keep the #define directives in the output (-dD),
	// since the syscall numbers are read from those directives below.
	args := os.Args[1:]
	args = append([]string{"-E", "-dD"}, args...)
	cmd, err := exec.Command(cc, args...).Output() // execute command and capture output
	if err != nil {
		// NOTE(review): the underlying error is not printed and the message has
		// no trailing newline.
		fmt.Fprintf(os.Stderr, "can't run %s", cc)
		os.Exit(1)
	}
	s := bufio.NewScanner(strings.NewReader(string(cmd)))
	// offset is added to every number passed to format; prev is the most recent
	// plain numeric definition and serves as the base for "(NAME + n)" forms;
	// asOffset is the base for "(__NR_arch_specific_syscall + n)" forms.
	var offset, prev, asOffset int
	var nums syscallNums
	for s.Scan() {
		t := re{str: s.Text()}
		// The generated zsysnum_linux_*.go files for some platforms (arm64, loong64, riscv64)
		// treat SYS_ARCH_SPECIFIC_SYSCALL as if it's a syscall which it isn't. It's an offset.
		// However, as this constant is already part of the public API we leave it in place.
		// Lines of type SYS_ARCH_SPECIFIC_SYSCALL = 244 are thus processed twice, once to extract
		// the offset and once to add the constant.
		if t.Match(`^#define __NR_arch_specific_syscall\s+([0-9]+)`) {
			// riscv: extract arch specific offset
			asOffset, _ = strconv.Atoi(t.sub[1]) // Make asOffset=0 if empty or non-numeric
		}
		// The branches below are tried in order and the first matching pattern
		// wins, so the more specific patterns must stay ahead of the generic
		// "__NR_<name> <number>" one.
		if t.Match(`^#define __NR_Linux\s+([0-9]+)`) {
			// mips/mips64: extract offset
			offset, _ = strconv.Atoi(t.sub[1]) // Make offset=0 if empty or non-numeric
		} else if t.Match(`^#define __NR(\w*)_SYSCALL_BASE\s+([0-9]+)`) {
			// arm: extract offset
			// NOTE(review): in this pattern t.sub[1] is the (\w*) group and the
			// digits are in t.sub[2], so Atoi fails and offset becomes 0 unless
			// the (\w*) text happens to be numeric. Confirm whether t.sub[2] was
			// intended before changing, since this affects generated constants.
			offset, _ = strconv.Atoi(t.sub[1]) // Make offset=0 if empty or non-numeric
		} else if t.Match(`^#define __NR_syscalls\s+`) {
			// ignore redefinitions of __NR_syscalls
		} else if t.Match(`^#define __NR_(\w*)Linux_syscalls\s+`) {
			// mips/mips64: ignore definitions about the number of syscalls
		} else if t.Match(`^#define __NR_(\w+)\s+([0-9]+)`) {
			// Plain "__NR_<name> <number>"; the number is also remembered in prev.
			prev, err = strconv.Atoi(t.sub[2])
			checkErr(err)
			nums.addSyscallNum(format(t.sub[1], prev, offset))
		} else if t.Match(`^#define __NR3264_(\w+)\s+([0-9]+)`) {
			// Same handling for the __NR3264_ prefixed definitions.
			prev, err = strconv.Atoi(t.sub[2])
			checkErr(err)
			nums.addSyscallNum(format(t.sub[1], prev, offset))
		} else if t.Match(`^#define __NR_(\w+)\s+\(\w+\+\s*([0-9]+)\)`) {
			// "(<identifier>+ n)": the number is taken relative to prev.
			r2, err := strconv.Atoi(t.sub[2])
			checkErr(err)
			nums.addSyscallNum(format(t.sub[1], prev+r2, offset))
		} else if t.Match(`^#define __NR_(\w+)\s+\(__NR_(?:SYSCALL_BASE|Linux) \+ ([0-9]+)`) {
			// "(__NR_SYSCALL_BASE + n)" or "(__NR_Linux + n)": n plus offset.
			r2, err := strconv.Atoi(t.sub[2])
			checkErr(err)
			nums.addSyscallNum(format(t.sub[1], r2, offset))
		} else if asOffset != 0 && t.Match(`^#define __NR_(\w+)\s+\(__NR_arch_specific_syscall \+ ([0-9]+)`) {
			// "(__NR_arch_specific_syscall + n)": n plus asOffset, only once an
			// arch specific offset has been seen.
			r2, err := strconv.Atoi(t.sub[2])
			checkErr(err)
			nums.addSyscallNum(format(t.sub[1], r2, asOffset))
		}
	}
	err = s.Err()
	checkErr(err)
	// Concatenate the declarations in syscall-number order and print the file.
	var text strings.Builder
	for _, num := range nums {
		text.WriteString(num.declaration)
	}
	fmt.Printf(template, cmdLine(), goBuildTags(), text.String())
}
// template is the Printf format for the generated file. Its three %s verbs
// are filled, in order, with the command line that produced the file, the
// //go:build expression, and the concatenated constant declarations.
const template = `// %s
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build %s
package unix
const(
%s)`
================================================
FILE: gossh/sys/unix/linux/types.go
================================================
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
/*
Input to cgo -godefs. See README.md
*/
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
// +godefs map struct___kernel_sockaddr_storage SockaddrStorage
package unix
/*
#define _LARGEFILE_SOURCE
#define _LARGEFILE64_SOURCE
#define _FILE_OFFSET_BITS 64
#define _GNU_SOURCE
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#if defined(__sparc__)
// On sparc{,64}, the kernel defines struct termios2 itself which clashes with the
// definition in glibc. Duplicate the kernel version here.
#if defined(__arch64__)
typedef unsigned int tcflag_t;
#else
typedef unsigned long tcflag_t;
#endif
struct termios2 {
tcflag_t c_iflag;
tcflag_t c_oflag;
tcflag_t c_cflag;
tcflag_t c_lflag;
unsigned char c_line;
unsigned char c_cc[19];
unsigned int c_ispeed;
unsigned int c_ospeed;
};
#else
#include
#endif
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// This is to avoid a conflict of struct sched_param being defined by
// both the kernel and the glibc (sched.h) headers.
#define sched_param kernel_sched_param
#include
#undef kernel_sched_param
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// abi/abi.h generated by mkall.go.
#include "abi/abi.h"
// On mips64, the glibc stat and kernel stat do not agree
#if (defined(__mips__) && _MIPS_SIM == _MIPS_SIM_ABI64)
// Use the stat defined by the kernel with a few modifications. These are:
// * The time fields (like st_atime and st_atimensec) use the timespec
// struct (like st_atim) for consistency with the glibc fields.
// * The padding fields get different names to not break compatibility.
// * st_blocks is signed, again for compatibility.
typedef struct {
unsigned int st_dev;
unsigned int st_pad1[3]; // Reserved for st_dev expansion
unsigned long st_ino;
mode_t st_mode;
__u32 st_nlink;
uid_t st_uid;
gid_t st_gid;
unsigned int st_rdev;
unsigned int st_pad2[3]; // Reserved for st_rdev expansion
off_t st_size;
// These are declared as separate fields in the kernel. Here we use
// the timespec struct for consistency with the other stat structs.
struct timespec st_atim;
struct timespec st_mtim;
struct timespec st_ctim;
unsigned int st_blksize;
unsigned int st_pad4;
long st_blocks;
} my_stat;
#else
typedef struct stat my_stat;
#endif
#ifdef TCSETS2
// On systems that have "struct termios2" use this as type Termios.
typedef struct termios2 termios_t;
#else
typedef struct termios termios_t;
#endif
enum {
sizeofPtr = sizeof(void*),
};
union sockaddr_all {
struct sockaddr s1; // this one gets used for fields
struct sockaddr_in s2; // these pad it out
struct sockaddr_in6 s3;
struct sockaddr_un s4;
struct sockaddr_ll s5;
struct sockaddr_nl s6;
struct sockaddr_pppox s7;
struct sockaddr_l2tpip s8;
struct sockaddr_l2tpip6 s9;
struct sockaddr_nfc s10;
struct sockaddr_nfc_llcp s11;
};
struct sockaddr_any {
struct sockaddr addr;
char pad[sizeof(union sockaddr_all) - sizeof(struct sockaddr)];
};
// copied from /usr/include/bluetooth/hci.h
struct sockaddr_hci {
sa_family_t hci_family;
unsigned short hci_dev;
unsigned short hci_channel;
};
// copied from /usr/include/bluetooth/bluetooth.h
#define BDADDR_BREDR 0x00
#define BDADDR_LE_PUBLIC 0x01
#define BDADDR_LE_RANDOM 0x02
// copied from /usr/include/bluetooth/l2cap.h
struct sockaddr_l2 {
sa_family_t l2_family;
unsigned short l2_psm;
uint8_t l2_bdaddr[6];
unsigned short l2_cid;
uint8_t l2_bdaddr_type;
};
// copied from /usr/include/net/bluetooth/rfcomm.h
struct sockaddr_rc {
sa_family_t rc_family;
uint8_t rc_bdaddr[6];
uint8_t rc_channel;
};
// copied from /usr/include/linux/un.h
struct my_sockaddr_un {
sa_family_t sun_family;
#if defined(__ARM_EABI__) || defined(__powerpc__) || defined(__powerpc64__) || defined(__riscv)
// on some platforms char is unsigned by default
signed char sun_path[108];
#else
char sun_path[108];
#endif
};
// copied from /usr/include/netiucv/iucv.h modified with explicit signed chars.
struct sockaddr_iucv {
sa_family_t siucv_family;
unsigned short siucv_port;
unsigned int siucv_addr;
signed char siucv_nodeid[8];
signed char siucv_user_id[8];
signed char siucv_name[8];
};
// copied from /usr/include/linux/nfc.h modified with explicit unsigned chars.
struct my_sockaddr_nfc_llcp {
sa_family_t sa_family;
uint32_t dev_idx;
uint32_t target_idx;
uint32_t nfc_protocol;
uint8_t dsap;
uint8_t ssap;
uint8_t service_name[NFC_LLCP_MAX_SERVICE_NAME];
size_t service_name_len;
};
#ifdef __ARM_EABI__
typedef struct user_regs PtraceRegs;
#elif defined(__aarch64__) || defined(__loongarch64)
typedef struct user_pt_regs PtraceRegs;
#elif defined(__mips__) || defined(__powerpc__) || defined(__powerpc64__)
typedef struct pt_regs PtraceRegs;
#elif defined(__s390x__)
typedef struct _user_regs_struct PtraceRegs;
#elif defined(__sparc__)
#include
typedef struct pt_regs PtraceRegs;
#else
typedef struct user_regs_struct PtraceRegs;
#endif
#if defined(__s390x__)
typedef struct _user_psw_struct ptracePsw;
typedef struct _user_fpregs_struct ptraceFpregs;
typedef struct _user_per_struct ptracePer;
#else
typedef struct {} ptracePsw;
typedef struct {} ptraceFpregs;
typedef struct {} ptracePer;
#endif
// The real epoll_event is a union, and godefs doesn't handle it well.
struct my_epoll_event {
uint32_t events;
#if defined(__ARM_EABI__) || defined(__aarch64__) || (defined(__mips__) && _MIPS_SIM == _ABIO32)
// padding is not specified in linux/eventpoll.h but added to conform to the
// alignment requirements of EABI
int32_t padFd;
#elif defined(__powerpc__) || defined(__powerpc64__) || defined(__s390x__) || defined(__sparc__) \
|| defined(__riscv) || (defined(__mips__) && _MIPS_SIM == _MIPS_SIM_ABI64) \
|| defined(__loongarch64)
int32_t _padFd;
#endif
int32_t fd;
int32_t pad;
};
// Copied from <linux/perf_event.h> with the following modifications:
// 1) bit field after read_format redeclared as '__u64 bits' to make it
// accessible from Go
// 2) collapsed the unions, to avoid confusing godoc for the generated output
// (e.g. having to use BpAddr as an extension of Config)
struct perf_event_attr_go {
__u32 type;
__u32 size;
__u64 config;
// union {
// __u64 sample_period;
// __u64 sample_freq;
// };
__u64 sample;
__u64 sample_type;
__u64 read_format;
// Replaces the bit field. Flags are defined as constants.
__u64 bits;
// union {
// __u32 wakeup_events;
// __u32 wakeup_watermark;
// };
__u32 wakeup;
__u32 bp_type;
// union {
// __u64 bp_addr;
// __u64 config1;
// };
__u64 ext1;
// union {
// __u64 bp_len;
// __u64 config2;
// };
__u64 ext2;
__u64 branch_sample_type;
__u64 sample_regs_user;
__u32 sample_stack_user;
__s32 clockid;
__u64 sample_regs_intr;
__u32 aux_watermark;
__u16 sample_max_stack;
__u16 __reserved_2;
__u32 aux_sample_size;
__u32 __reserved_3;
__u64 sig_data;
};
// ustat is deprecated and glibc 2.28 removed ustat.h. Provide the type here for
// backwards compatibility. Copied from /usr/include/bits/ustat.h
struct ustat {
__daddr_t f_tfree;
__ino_t f_tinode;
char f_fname[6];
char f_fpack[6];
};
// my_blkpg_partition is blkpg_partition with unsigned devname & volname.
struct my_blkpg_partition {
long long start;
long long length;
int pno;
unsigned char devname[BLKPG_DEVNAMELTH];
unsigned char volname[BLKPG_VOLNAMELTH];
};
// tipc_service_name is a copied and slightly modified form of the "name"
// variant in sockaddr_tipc's union in tipc.h, so it can be exported as part of
// SockaddrTIPC's API.
//
// The struct is flat: three consecutive __u32 members with no nesting.
struct tipc_service_name {
// From tipc_service_addr.
__u32 type;
__u32 instance;
// From the union.
__u32 domain;
};
// copied from /usr/include/linux/can/netlink.h modified with explicit unsigned chars.
struct my_can_bittiming_const {
// Declared as uint8_t (rather than plain char) so the element type is
// explicitly unsigned.
uint8_t name[16];
__u32 tseg1_min;
__u32 tseg1_max;
__u32 tseg2_min;
__u32 tseg2_max;
__u32 sjw_max;
__u32 brp_min;
__u32 brp_max;
__u32 brp_inc;
};
// struct riscv_hwprobe and the RISCV_HWPROBE_* macros have to be defined on
// every architecture. On RISC-V they come from the kernel header; everywhere
// else the placeholder definitions below are used.
#if defined(__riscv)
#include <asm/hwprobe.h>
#else
// copied from /usr/include/asm/hwprobe.h
// values are not used but they need to be defined.
#define RISCV_HWPROBE_KEY_MVENDORID 0
#define RISCV_HWPROBE_KEY_MARCHID 1
#define RISCV_HWPROBE_KEY_MIMPID 2
#define RISCV_HWPROBE_KEY_BASE_BEHAVIOR 3
#define RISCV_HWPROBE_BASE_BEHAVIOR_IMA (1 << 0)
#define RISCV_HWPROBE_KEY_IMA_EXT_0 4
#define RISCV_HWPROBE_IMA_FD (1 << 0)
#define RISCV_HWPROBE_IMA_C (1 << 1)
#define RISCV_HWPROBE_IMA_V (1 << 2)
#define RISCV_HWPROBE_EXT_ZBA (1 << 3)
#define RISCV_HWPROBE_EXT_ZBB (1 << 4)
#define RISCV_HWPROBE_EXT_ZBS (1 << 5)
#define RISCV_HWPROBE_KEY_CPUPERF_0 5
#define RISCV_HWPROBE_MISALIGNED_UNKNOWN (0 << 0)
#define RISCV_HWPROBE_MISALIGNED_EMULATED (1 << 0)
#define RISCV_HWPROBE_MISALIGNED_SLOW (2 << 0)
#define RISCV_HWPROBE_MISALIGNED_FAST (3 << 0)
#define RISCV_HWPROBE_MISALIGNED_UNSUPPORTED (4 << 0)
#define RISCV_HWPROBE_MISALIGNED_MASK (7 << 0)
// Empty placeholder so that references to struct riscv_hwprobe still compile.
struct riscv_hwprobe {};
#endif
// copied from /usr/include/uapi/linux/mman.h
// cachestat_range holds two 64-bit members, off and len.
struct cachestat_range {
__u64 off;
__u64 len;
};
// cachestat holds five 64-bit nr_* counters.
struct cachestat {
__u64 nr_cache;
__u64 nr_dirty;
__u64 nr_writeback;
__u64 nr_evicted;
__u64 nr_recently_evicted;
};
*/
import "C"
// Machine characteristics
//
// Sizes of the basic C types plus the C PATH_MAX limit. Every value is read
// from the C symbol named on the right-hand side.
const (
SizeofPtr = C.sizeofPtr
SizeofShort = C.sizeof_short
SizeofInt = C.sizeof_int
SizeofLong = C.sizeof_long
SizeofLongLong = C.sizeof_longlong
PathMax = C.PATH_MAX
)
// Basic types
//
// Unexported Go names for the C integer types short, int, long and long long.
type (
_C_short C.short
_C_int C.int
_C_long C.long
_C_long_long C.longlong
)
// Time

// Timespec is the Go name for C's struct timespec.
type Timespec C.struct_timespec
// Timeval is the Go name for C's struct timeval.
type Timeval C.struct_timeval
// Timex is the Go name for C's struct timex.
type Timex C.struct_timex
// ItimerSpec is the Go name for C's struct itimerspec.
type ItimerSpec C.struct_itimerspec
// Itimerval is the Go name for C's struct itimerval.
type Itimerval C.struct_itimerval
// ADJ_* constants. Each one takes the value of the C constant of the same name.
const (
ADJ_OFFSET = C.ADJ_OFFSET
ADJ_FREQUENCY = C.ADJ_FREQUENCY
ADJ_MAXERROR = C.ADJ_MAXERROR
ADJ_ESTERROR = C.ADJ_ESTERROR
ADJ_STATUS = C.ADJ_STATUS
ADJ_TIMECONST = C.ADJ_TIMECONST
ADJ_TAI = C.ADJ_TAI
ADJ_SETOFFSET = C.ADJ_SETOFFSET
ADJ_MICRO = C.ADJ_MICRO
ADJ_NANO = C.ADJ_NANO
ADJ_TICK = C.ADJ_TICK
ADJ_OFFSET_SINGLESHOT = C.ADJ_OFFSET_SINGLESHOT
ADJ_OFFSET_SS_READ = C.ADJ_OFFSET_SS_READ
)
// STA_* constants. Each one takes the value of the C constant of the same name.
const (
STA_PLL = C.STA_PLL
STA_PPSFREQ = C.STA_PPSFREQ
STA_PPSTIME = C.STA_PPSTIME
STA_FLL = C.STA_FLL
STA_INS = C.STA_INS
STA_DEL = C.STA_DEL
STA_UNSYNC = C.STA_UNSYNC
STA_FREQHOLD = C.STA_FREQHOLD
STA_PPSSIGNAL = C.STA_PPSSIGNAL
STA_PPSJITTER = C.STA_PPSJITTER
STA_PPSWANDER = C.STA_PPSWANDER
STA_PPSERROR = C.STA_PPSERROR
STA_CLOCKERR = C.STA_CLOCKERR
STA_NANO = C.STA_NANO
STA_MODE = C.STA_MODE
STA_CLK = C.STA_CLK
)
// TIME_* constants. Each one takes the value of the C constant of the same name.
const (
TIME_OK = C.TIME_OK
TIME_INS = C.TIME_INS
TIME_DEL = C.TIME_DEL
TIME_OOP = C.TIME_OOP
TIME_WAIT = C.TIME_WAIT
TIME_ERROR = C.TIME_ERROR
TIME_BAD = C.TIME_BAD
)
// Time_t is the Go name for C's time_t.
type Time_t C.time_t
// Tms is the Go name for C's struct tms.
type Tms C.struct_tms
// Utimbuf is the Go name for C's struct utimbuf.
type Utimbuf C.struct_utimbuf
// Processes

// Rusage is the Go name for C's struct rusage.
type Rusage C.struct_rusage
// Rlimit is the Go name for C's struct rlimit.
type Rlimit C.struct_rlimit
// _Gid_t is the unexported Go name for C's gid_t.
type _Gid_t C.gid_t
// Files

// Stat_t is based on the C type my_stat (not on struct stat directly).
type Stat_t C.my_stat
type StatxTimestamp C.struct_statx_timestamp
type Statx_t C.struct_statx
type Dirent C.struct_dirent
type Fsid C.fsid_t
type Flock_t C.struct_flock
type FileCloneRange C.struct_file_clone_range
// The Raw* names below are the direct images of the C dedupe structs.
type RawFileDedupeRange C.struct_file_dedupe_range
type RawFileDedupeRangeInfo C.struct_file_dedupe_range_info
// Sizes of the two dedupe structs, followed by the FILE_DEDUPE_RANGE_*
// constants taken from C.
const (
SizeofRawFileDedupeRange = C.sizeof_struct_file_dedupe_range
SizeofRawFileDedupeRangeInfo = C.sizeof_struct_file_dedupe_range_info
FILE_DEDUPE_RANGE_SAME = C.FILE_DEDUPE_RANGE_SAME
FILE_DEDUPE_RANGE_DIFFERS = C.FILE_DEDUPE_RANGE_DIFFERS
)
// Filesystem Encryption
//
// Go names for the C fscrypt_* structs; each type below is defined directly in
// terms of the C struct named on the right-hand side.
type FscryptPolicy C.struct_fscrypt_policy
type FscryptKey C.struct_fscrypt_key
type FscryptPolicyV1 C.struct_fscrypt_policy_v1
type FscryptPolicyV2 C.struct_fscrypt_policy_v2
type FscryptGetPolicyExArg C.struct_fscrypt_get_policy_ex_arg
type FscryptKeySpecifier C.struct_fscrypt_key_specifier
type FscryptAddKeyArg C.struct_fscrypt_add_key_arg
type FscryptRemoveKeyArg C.struct_fscrypt_remove_key_arg
type FscryptGetKeyStatusArg C.struct_fscrypt_get_key_status_arg
// Device Mapper
//
// Go names for the C dm_* structs.
type DmIoctl C.struct_dm_ioctl
type DmTargetSpec C.struct_dm_target_spec
type DmTargetDeps C.struct_dm_target_deps
type DmNameList C.struct_dm_name_list
type DmTargetVersions C.struct_dm_target_versions
type DmTargetMsg C.struct_dm_target_msg
// Sizes of struct dm_ioctl and struct dm_target_spec as reported by C.
const (
SizeofDmIoctl = C.sizeof_struct_dm_ioctl
SizeofDmTargetSpec = C.sizeof_struct_dm_target_spec
)
// Structure for Keyctl

// KeyctlDHParams is the Go name for C's struct keyctl_dh_params.
type KeyctlDHParams C.struct_keyctl_dh_params
// Advice to Fadvise
//
// The Go names drop the POSIX_ prefix of the C constants they are taken from.
const (
FADV_NORMAL = C.POSIX_FADV_NORMAL
FADV_RANDOM = C.POSIX_FADV_RANDOM
FADV_SEQUENTIAL = C.POSIX_FADV_SEQUENTIAL
FADV_WILLNEED = C.POSIX_FADV_WILLNEED
FADV_DONTNEED = C.POSIX_FADV_DONTNEED
FADV_NOREUSE = C.POSIX_FADV_NOREUSE
)
// Sockets
//
// RawSockaddr* types are the Go names for the C sockaddr variants. Most are
// defined directly in terms of the C struct; exceptions are noted inline.
type RawSockaddrInet4 C.struct_sockaddr_in
type RawSockaddrInet6 C.struct_sockaddr_in6
// Based on my_sockaddr_un rather than sockaddr_un.
type RawSockaddrUnix C.struct_my_sockaddr_un
type RawSockaddrLinklayer C.struct_sockaddr_ll
type RawSockaddrNetlink C.struct_sockaddr_nl
type RawSockaddrHCI C.struct_sockaddr_hci
type RawSockaddrL2 C.struct_sockaddr_l2
type RawSockaddrRFCOMM C.struct_sockaddr_rc
type RawSockaddrCAN C.struct_sockaddr_can
type RawSockaddrALG C.struct_sockaddr_alg
type RawSockaddrVM C.struct_sockaddr_vm
type RawSockaddrXDP C.struct_sockaddr_xdp
// Declared as an opaque byte array with the size of struct sockaddr_pppox,
// not as the struct itself.
type RawSockaddrPPPoX [C.sizeof_struct_sockaddr_pppox]byte
type RawSockaddrTIPC C.struct_sockaddr_tipc
type RawSockaddrL2TPIP C.struct_sockaddr_l2tpip
type RawSockaddrL2TPIP6 C.struct_sockaddr_l2tpip6
type RawSockaddrIUCV C.struct_sockaddr_iucv
type RawSockaddrNFC C.struct_sockaddr_nfc
// Based on my_sockaddr_nfc_llcp rather than sockaddr_nfc_llcp.
type RawSockaddrNFCLLCP C.struct_my_sockaddr_nfc_llcp
type RawSockaddr C.struct_sockaddr
type RawSockaddrAny C.struct_sockaddr_any
type _Socklen C.socklen_t
// Remaining socket-related structs.
type Linger C.struct_linger
type Iovec C.struct_iovec
type IPMreq C.struct_ip_mreq
type IPMreqn C.struct_ip_mreqn
type IPv6Mreq C.struct_ipv6_mreq
type PacketMreq C.struct_packet_mreq
type Msghdr C.struct_msghdr
type Cmsghdr C.struct_cmsghdr
type Inet4Pktinfo C.struct_in_pktinfo
type Inet6Pktinfo C.struct_in6_pktinfo
type IPv6MTUInfo C.struct_ip6_mtuinfo
type ICMPv6Filter C.struct_icmp6_filter
type Ucred C.struct_ucred
type TCPInfo C.struct_tcp_info
type CanFilter C.struct_can_filter
// ifreq is unexported.
type ifreq C.struct_ifreq
type TCPRepairOpt C.struct_tcp_repair_opt
// Sizes of the socket-related C structs, as reported by C for the struct named
// on the right-hand side.
const (
SizeofSockaddrInet4 = C.sizeof_struct_sockaddr_in
SizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
SizeofSockaddrAny = C.sizeof_struct_sockaddr_any
// Note: sized from sockaddr_un, whereas RawSockaddrUnix is declared from
// my_sockaddr_un.
SizeofSockaddrUnix = C.sizeof_struct_sockaddr_un
SizeofSockaddrLinklayer = C.sizeof_struct_sockaddr_ll
SizeofSockaddrNetlink = C.sizeof_struct_sockaddr_nl
SizeofSockaddrHCI = C.sizeof_struct_sockaddr_hci
SizeofSockaddrL2 = C.sizeof_struct_sockaddr_l2
SizeofSockaddrRFCOMM = C.sizeof_struct_sockaddr_rc
SizeofSockaddrCAN = C.sizeof_struct_sockaddr_can
SizeofSockaddrALG = C.sizeof_struct_sockaddr_alg
SizeofSockaddrVM = C.sizeof_struct_sockaddr_vm
SizeofSockaddrXDP = C.sizeof_struct_sockaddr_xdp
SizeofSockaddrPPPoX = C.sizeof_struct_sockaddr_pppox
SizeofSockaddrTIPC = C.sizeof_struct_sockaddr_tipc
SizeofSockaddrL2TPIP = C.sizeof_struct_sockaddr_l2tpip
SizeofSockaddrL2TPIP6 = C.sizeof_struct_sockaddr_l2tpip6
SizeofSockaddrIUCV = C.sizeof_struct_sockaddr_iucv
SizeofSockaddrNFC = C.sizeof_struct_sockaddr_nfc
// Note: sized from sockaddr_nfc_llcp, whereas RawSockaddrNFCLLCP is declared
// from my_sockaddr_nfc_llcp.
SizeofSockaddrNFCLLCP = C.sizeof_struct_sockaddr_nfc_llcp
SizeofLinger = C.sizeof_struct_linger
SizeofIovec = C.sizeof_struct_iovec
SizeofIPMreq = C.sizeof_struct_ip_mreq
SizeofIPMreqn = C.sizeof_struct_ip_mreqn
SizeofIPv6Mreq = C.sizeof_struct_ipv6_mreq
SizeofPacketMreq = C.sizeof_struct_packet_mreq
SizeofMsghdr = C.sizeof_struct_msghdr
SizeofCmsghdr = C.sizeof_struct_cmsghdr
SizeofInet4Pktinfo = C.sizeof_struct_in_pktinfo
SizeofInet6Pktinfo = C.sizeof_struct_in6_pktinfo
SizeofIPv6MTUInfo = C.sizeof_struct_ip6_mtuinfo
SizeofICMPv6Filter = C.sizeof_struct_icmp6_filter
SizeofUcred = C.sizeof_struct_ucred
SizeofTCPInfo = C.sizeof_struct_tcp_info
SizeofCanFilter = C.sizeof_struct_can_filter
SizeofTCPRepairOpt = C.sizeof_struct_tcp_repair_opt
)
// Netlink routing and interface messages
//
// Every constant takes the value of the C constant of the same name; the
// Sizeof* entries at the end take the C size of the named struct. The group
// comments below only describe the shared name prefix.
const (
// NDA_* constants.
NDA_UNSPEC = C.NDA_UNSPEC
NDA_DST = C.NDA_DST
NDA_LLADDR = C.NDA_LLADDR
NDA_CACHEINFO = C.NDA_CACHEINFO
NDA_PROBES = C.NDA_PROBES
NDA_VLAN = C.NDA_VLAN
NDA_PORT = C.NDA_PORT
NDA_VNI = C.NDA_VNI
NDA_IFINDEX = C.NDA_IFINDEX
NDA_MASTER = C.NDA_MASTER
NDA_LINK_NETNSID = C.NDA_LINK_NETNSID
NDA_SRC_VNI = C.NDA_SRC_VNI
// NTF_* constants.
NTF_USE = C.NTF_USE
NTF_SELF = C.NTF_SELF
NTF_MASTER = C.NTF_MASTER
NTF_PROXY = C.NTF_PROXY
NTF_EXT_LEARNED = C.NTF_EXT_LEARNED
NTF_OFFLOADED = C.NTF_OFFLOADED
NTF_ROUTER = C.NTF_ROUTER
// NUD_* constants.
NUD_INCOMPLETE = C.NUD_INCOMPLETE
NUD_REACHABLE = C.NUD_REACHABLE
NUD_STALE = C.NUD_STALE
NUD_DELAY = C.NUD_DELAY
NUD_PROBE = C.NUD_PROBE
NUD_FAILED = C.NUD_FAILED
NUD_NOARP = C.NUD_NOARP
NUD_PERMANENT = C.NUD_PERMANENT
NUD_NONE = C.NUD_NONE
// IFA_* constants.
IFA_UNSPEC = C.IFA_UNSPEC
IFA_ADDRESS = C.IFA_ADDRESS
IFA_LOCAL = C.IFA_LOCAL
IFA_LABEL = C.IFA_LABEL
IFA_BROADCAST = C.IFA_BROADCAST
IFA_ANYCAST = C.IFA_ANYCAST
IFA_CACHEINFO = C.IFA_CACHEINFO
IFA_MULTICAST = C.IFA_MULTICAST
IFA_FLAGS = C.IFA_FLAGS
IFA_RT_PRIORITY = C.IFA_RT_PRIORITY
IFA_TARGET_NETNSID = C.IFA_TARGET_NETNSID
// RT_SCOPE_* constants.
RT_SCOPE_UNIVERSE = C.RT_SCOPE_UNIVERSE
RT_SCOPE_SITE = C.RT_SCOPE_SITE
RT_SCOPE_LINK = C.RT_SCOPE_LINK
RT_SCOPE_HOST = C.RT_SCOPE_HOST
RT_SCOPE_NOWHERE = C.RT_SCOPE_NOWHERE
// RT_TABLE_* constants.
RT_TABLE_UNSPEC = C.RT_TABLE_UNSPEC
RT_TABLE_COMPAT = C.RT_TABLE_COMPAT
RT_TABLE_DEFAULT = C.RT_TABLE_DEFAULT
RT_TABLE_MAIN = C.RT_TABLE_MAIN
RT_TABLE_LOCAL = C.RT_TABLE_LOCAL
RT_TABLE_MAX = C.RT_TABLE_MAX
// RTA_* constants.
RTA_UNSPEC = C.RTA_UNSPEC
RTA_DST = C.RTA_DST
RTA_SRC = C.RTA_SRC
RTA_IIF = C.RTA_IIF
RTA_OIF = C.RTA_OIF
RTA_GATEWAY = C.RTA_GATEWAY
RTA_PRIORITY = C.RTA_PRIORITY
RTA_PREFSRC = C.RTA_PREFSRC
RTA_METRICS = C.RTA_METRICS
RTA_MULTIPATH = C.RTA_MULTIPATH
RTA_FLOW = C.RTA_FLOW
RTA_CACHEINFO = C.RTA_CACHEINFO
RTA_TABLE = C.RTA_TABLE
RTA_MARK = C.RTA_MARK
RTA_MFC_STATS = C.RTA_MFC_STATS
RTA_VIA = C.RTA_VIA
RTA_NEWDST = C.RTA_NEWDST
RTA_PREF = C.RTA_PREF
RTA_ENCAP_TYPE = C.RTA_ENCAP_TYPE
RTA_ENCAP = C.RTA_ENCAP
RTA_EXPIRES = C.RTA_EXPIRES
RTA_PAD = C.RTA_PAD
RTA_UID = C.RTA_UID
RTA_TTL_PROPAGATE = C.RTA_TTL_PROPAGATE
RTA_IP_PROTO = C.RTA_IP_PROTO
RTA_SPORT = C.RTA_SPORT
RTA_DPORT = C.RTA_DPORT
// RTN_* constants.
RTN_UNSPEC = C.RTN_UNSPEC
RTN_UNICAST = C.RTN_UNICAST
RTN_LOCAL = C.RTN_LOCAL
RTN_BROADCAST = C.RTN_BROADCAST
RTN_ANYCAST = C.RTN_ANYCAST
RTN_MULTICAST = C.RTN_MULTICAST
RTN_BLACKHOLE = C.RTN_BLACKHOLE
RTN_UNREACHABLE = C.RTN_UNREACHABLE
RTN_PROHIBIT = C.RTN_PROHIBIT
RTN_THROW = C.RTN_THROW
RTN_NAT = C.RTN_NAT
RTN_XRESOLVE = C.RTN_XRESOLVE
// Sizes of the netlink message structs.
SizeofNlMsghdr = C.sizeof_struct_nlmsghdr
SizeofNlMsgerr = C.sizeof_struct_nlmsgerr
SizeofRtGenmsg = C.sizeof_struct_rtgenmsg
SizeofNlAttr = C.sizeof_struct_nlattr
SizeofRtAttr = C.sizeof_struct_rtattr
SizeofIfInfomsg = C.sizeof_struct_ifinfomsg
SizeofIfAddrmsg = C.sizeof_struct_ifaddrmsg
SizeofIfaCacheinfo = C.sizeof_struct_ifa_cacheinfo
SizeofRtMsg = C.sizeof_struct_rtmsg
SizeofRtNexthop = C.sizeof_struct_rtnexthop
SizeofNdUseroptmsg = C.sizeof_struct_nduseroptmsg
SizeofNdMsg = C.sizeof_struct_ndmsg
)
// Go names for the netlink message structs; each type is defined directly in
// terms of the C struct named on the right-hand side.
type NlMsghdr C.struct_nlmsghdr
type NlMsgerr C.struct_nlmsgerr
type RtGenmsg C.struct_rtgenmsg
type NlAttr C.struct_nlattr
type RtAttr C.struct_rtattr
type IfInfomsg C.struct_ifinfomsg
type IfAddrmsg C.struct_ifaddrmsg
type IfaCacheinfo C.struct_ifa_cacheinfo
type RtMsg C.struct_rtmsg
type RtNexthop C.struct_rtnexthop
type NdUseroptmsg C.struct_nduseroptmsg
type NdMsg C.struct_ndmsg
// ICMP socket options
//
// Each constant takes the value of the C constant of the same name.
const (
ICMP_FILTER = C.ICMP_FILTER
ICMPV6_FILTER = C.ICMPV6_FILTER
ICMPV6_FILTER_BLOCK = C.ICMPV6_FILTER_BLOCK
ICMPV6_FILTER_BLOCKOTHERS = C.ICMPV6_FILTER_BLOCKOTHERS
ICMPV6_FILTER_PASS = C.ICMPV6_FILTER_PASS
ICMPV6_FILTER_PASSONLY = C.ICMPV6_FILTER_PASSONLY
)
// Linux socket filter
//
// Sizes of struct sock_filter and struct sock_fprog, followed by the Go names
// for those two structs.
const (
SizeofSockFilter = C.sizeof_struct_sock_filter
SizeofSockFprog = C.sizeof_struct_sock_fprog
)
type SockFilter C.struct_sock_filter
type SockFprog C.struct_sock_fprog
// Inotify

// InotifyEvent is the Go name for C's struct inotify_event.
type InotifyEvent C.struct_inotify_event
// SizeofInotifyEvent is the C size of struct inotify_event.
const SizeofInotifyEvent = C.sizeof_struct_inotify_event
// Ptrace
// Register structures
// PtraceRegs is based on the C typedef PtraceRegs (not a struct tag).
type PtraceRegs C.PtraceRegs
// Structures contained in PtraceRegs on s390x (exported by mkpost.go)
// These are based on lower-case C typedefs.
type PtracePsw C.ptracePsw
type PtraceFpregs C.ptraceFpregs
type PtracePer C.ptracePer
// Misc
type FdSet C.fd_set
type Sysinfo_t C.struct_sysinfo
const SI_LOAD_SHIFT = C.SI_LOAD_SHIFT
type Utsname C.struct_utsname
// Ustat_t is based on the struct ustat declared in the C preamble of this file.
type Ustat_t C.struct_ustat
// EpollEvent is based on my_epoll_event rather than epoll_event.
type EpollEvent C.struct_my_epoll_event
// AT_*, OPEN_TREE_*, MOVE_MOUNT_*, FSOPEN_*, FSPICK_*, FSMOUNT_* and FSCONFIG_*
// constants. Each one takes the value of the C constant of the same name.
const (
AT_EMPTY_PATH = C.AT_EMPTY_PATH
AT_FDCWD = C.AT_FDCWD
AT_NO_AUTOMOUNT = C.AT_NO_AUTOMOUNT
AT_REMOVEDIR = C.AT_REMOVEDIR
AT_STATX_SYNC_AS_STAT = C.AT_STATX_SYNC_AS_STAT
AT_STATX_FORCE_SYNC = C.AT_STATX_FORCE_SYNC
AT_STATX_DONT_SYNC = C.AT_STATX_DONT_SYNC
AT_RECURSIVE = C.AT_RECURSIVE
AT_SYMLINK_FOLLOW = C.AT_SYMLINK_FOLLOW
AT_SYMLINK_NOFOLLOW = C.AT_SYMLINK_NOFOLLOW
AT_EACCESS = C.AT_EACCESS
OPEN_TREE_CLONE = C.OPEN_TREE_CLONE
OPEN_TREE_CLOEXEC = C.OPEN_TREE_CLOEXEC
MOVE_MOUNT_F_SYMLINKS = C.MOVE_MOUNT_F_SYMLINKS
MOVE_MOUNT_F_AUTOMOUNTS = C.MOVE_MOUNT_F_AUTOMOUNTS
MOVE_MOUNT_F_EMPTY_PATH = C.MOVE_MOUNT_F_EMPTY_PATH
MOVE_MOUNT_T_SYMLINKS = C.MOVE_MOUNT_T_SYMLINKS
MOVE_MOUNT_T_AUTOMOUNTS = C.MOVE_MOUNT_T_AUTOMOUNTS
MOVE_MOUNT_T_EMPTY_PATH = C.MOVE_MOUNT_T_EMPTY_PATH
MOVE_MOUNT_SET_GROUP = C.MOVE_MOUNT_SET_GROUP
FSOPEN_CLOEXEC = C.FSOPEN_CLOEXEC
FSPICK_CLOEXEC = C.FSPICK_CLOEXEC
FSPICK_SYMLINK_NOFOLLOW = C.FSPICK_SYMLINK_NOFOLLOW
FSPICK_NO_AUTOMOUNT = C.FSPICK_NO_AUTOMOUNT
FSPICK_EMPTY_PATH = C.FSPICK_EMPTY_PATH
FSMOUNT_CLOEXEC = C.FSMOUNT_CLOEXEC
FSCONFIG_SET_FLAG = C.FSCONFIG_SET_FLAG
FSCONFIG_SET_STRING = C.FSCONFIG_SET_STRING
FSCONFIG_SET_BINARY = C.FSCONFIG_SET_BINARY
FSCONFIG_SET_PATH = C.FSCONFIG_SET_PATH
FSCONFIG_SET_PATH_EMPTY = C.FSCONFIG_SET_PATH_EMPTY
FSCONFIG_SET_FD = C.FSCONFIG_SET_FD
FSCONFIG_CMD_CREATE = C.FSCONFIG_CMD_CREATE
FSCONFIG_CMD_RECONFIGURE = C.FSCONFIG_CMD_RECONFIGURE
)
// OpenHow is the Go name for C's struct open_how.
type OpenHow C.struct_open_how
// SizeofOpenHow is the C size of struct open_how.
const SizeofOpenHow = C.sizeof_struct_open_how
// RESOLVE_* constants. Each one takes the value of the C constant of the same
// name.
const (
RESOLVE_BENEATH = C.RESOLVE_BENEATH
RESOLVE_IN_ROOT = C.RESOLVE_IN_ROOT
RESOLVE_NO_MAGICLINKS = C.RESOLVE_NO_MAGICLINKS
RESOLVE_NO_SYMLINKS = C.RESOLVE_NO_SYMLINKS
RESOLVE_NO_XDEV = C.RESOLVE_NO_XDEV
)
// PollFd is the Go name for C's struct pollfd.
type PollFd C.struct_pollfd
// POLL* constants. Each one takes the value of the C constant of the same name.
const (
POLLIN = C.POLLIN
POLLPRI = C.POLLPRI
POLLOUT = C.POLLOUT
POLLRDHUP = C.POLLRDHUP
POLLERR = C.POLLERR
POLLHUP = C.POLLHUP
POLLNVAL = C.POLLNVAL
)
// Sigset_t is the Go name for C's sigset_t.
type Sigset_t C.sigset_t
// sigset_argpack is a plain Go struct (not derived from a C type) that pairs a
// *Sigset_t with the byte size of the object it points to.
type sigset_argpack struct {
ss *Sigset_t
ssLen uintptr // Size (in bytes) of object pointed to by ss.
}
// _C__NSIG takes the value of C's _NSIG.
const _C__NSIG = C._NSIG
// SIG_* constants. Each one takes the value of the C constant of the same name.
const (
SIG_BLOCK = C.SIG_BLOCK
SIG_UNBLOCK = C.SIG_UNBLOCK
SIG_SETMASK = C.SIG_SETMASK
)
// SignalfdSiginfo is the Go name for C's struct signalfd_siginfo.
type SignalfdSiginfo C.struct_signalfd_siginfo
// Siginfo is the Go name for C's siginfo_t.
type Siginfo C.siginfo_t
// Terminal handling

// Termios is based on the C typedef termios_t (not on a struct tag).
type Termios C.termios_t
// Winsize is the Go name for C's struct winsize.
type Winsize C.struct_winsize
// Taskstats and cgroup stats.
type Taskstats C.struct_taskstats
// TASKSTATS_* constants. Each one takes the value of the C constant of the same
// name.
const (
TASKSTATS_CMD_UNSPEC = C.TASKSTATS_CMD_UNSPEC
TASKSTATS_CMD_GET = C.TASKSTATS_CMD_GET
TASKSTATS_CMD_NEW = C.TASKSTATS_CMD_NEW
TASKSTATS_TYPE_UNSPEC = C.TASKSTATS_TYPE_UNSPEC
TASKSTATS_TYPE_PID = C.TASKSTATS_TYPE_PID
TASKSTATS_TYPE_TGID = C.TASKSTATS_TYPE_TGID
TASKSTATS_TYPE_STATS = C.TASKSTATS_TYPE_STATS
TASKSTATS_TYPE_AGGR_PID = C.TASKSTATS_TYPE_AGGR_PID
TASKSTATS_TYPE_AGGR_TGID = C.TASKSTATS_TYPE_AGGR_TGID
TASKSTATS_TYPE_NULL = C.TASKSTATS_TYPE_NULL
TASKSTATS_CMD_ATTR_UNSPEC = C.TASKSTATS_CMD_ATTR_UNSPEC
TASKSTATS_CMD_ATTR_PID = C.TASKSTATS_CMD_ATTR_PID
TASKSTATS_CMD_ATTR_TGID = C.TASKSTATS_CMD_ATTR_TGID
TASKSTATS_CMD_ATTR_REGISTER_CPUMASK = C.TASKSTATS_CMD_ATTR_REGISTER_CPUMASK
TASKSTATS_CMD_ATTR_DEREGISTER_CPUMASK = C.TASKSTATS_CMD_ATTR_DEREGISTER_CPUMASK
)
type CGroupStats C.struct_cgroupstats
// CGROUPSTATS_* constants.
const (
// Unlike its siblings, this one is taken from C.__TASKSTATS_CMD_MAX rather
// than from a C constant of the same name.
CGROUPSTATS_CMD_UNSPEC = C.__TASKSTATS_CMD_MAX
CGROUPSTATS_CMD_GET = C.CGROUPSTATS_CMD_GET
CGROUPSTATS_CMD_NEW = C.CGROUPSTATS_CMD_NEW
CGROUPSTATS_TYPE_UNSPEC = C.CGROUPSTATS_TYPE_UNSPEC
CGROUPSTATS_TYPE_CGROUP_STATS = C.CGROUPSTATS_TYPE_CGROUP_STATS
CGROUPSTATS_CMD_ATTR_UNSPEC = C.CGROUPSTATS_CMD_ATTR_UNSPEC
CGROUPSTATS_CMD_ATTR_FD = C.CGROUPSTATS_CMD_ATTR_FD
)
// Generic netlink
type Genlmsghdr C.struct_genlmsghdr
// Generated by:
// $ perl -nlE '/^\s*(CTRL_\w+)/ && say "$1 = C.$1"' /usr/include/linux/genetlink.h
//
// Each CTRL_* constant takes the value of the C constant of the same name.
const (
CTRL_CMD_UNSPEC = C.CTRL_CMD_UNSPEC
CTRL_CMD_NEWFAMILY = C.CTRL_CMD_NEWFAMILY
CTRL_CMD_DELFAMILY = C.CTRL_CMD_DELFAMILY
CTRL_CMD_GETFAMILY = C.CTRL_CMD_GETFAMILY
CTRL_CMD_NEWOPS = C.CTRL_CMD_NEWOPS
CTRL_CMD_DELOPS = C.CTRL_CMD_DELOPS
CTRL_CMD_GETOPS = C.CTRL_CMD_GETOPS
CTRL_CMD_NEWMCAST_GRP = C.CTRL_CMD_NEWMCAST_GRP
CTRL_CMD_DELMCAST_GRP = C.CTRL_CMD_DELMCAST_GRP
CTRL_CMD_GETMCAST_GRP = C.CTRL_CMD_GETMCAST_GRP
CTRL_CMD_GETPOLICY = C.CTRL_CMD_GETPOLICY
CTRL_ATTR_UNSPEC = C.CTRL_ATTR_UNSPEC
CTRL_ATTR_FAMILY_ID = C.CTRL_ATTR_FAMILY_ID
CTRL_ATTR_FAMILY_NAME = C.CTRL_ATTR_FAMILY_NAME
CTRL_ATTR_VERSION = C.CTRL_ATTR_VERSION
CTRL_ATTR_HDRSIZE = C.CTRL_ATTR_HDRSIZE
CTRL_ATTR_MAXATTR = C.CTRL_ATTR_MAXATTR
CTRL_ATTR_OPS = C.CTRL_ATTR_OPS
CTRL_ATTR_MCAST_GROUPS = C.CTRL_ATTR_MCAST_GROUPS
CTRL_ATTR_POLICY = C.CTRL_ATTR_POLICY
CTRL_ATTR_OP_POLICY = C.CTRL_ATTR_OP_POLICY
CTRL_ATTR_OP = C.CTRL_ATTR_OP
CTRL_ATTR_OP_UNSPEC = C.CTRL_ATTR_OP_UNSPEC
CTRL_ATTR_OP_ID = C.CTRL_ATTR_OP_ID
CTRL_ATTR_OP_FLAGS = C.CTRL_ATTR_OP_FLAGS
CTRL_ATTR_MCAST_GRP_UNSPEC = C.CTRL_ATTR_MCAST_GRP_UNSPEC
CTRL_ATTR_MCAST_GRP_NAME = C.CTRL_ATTR_MCAST_GRP_NAME
CTRL_ATTR_MCAST_GRP_ID = C.CTRL_ATTR_MCAST_GRP_ID
CTRL_ATTR_POLICY_UNSPEC = C.CTRL_ATTR_POLICY_UNSPEC
CTRL_ATTR_POLICY_DO = C.CTRL_ATTR_POLICY_DO
CTRL_ATTR_POLICY_DUMP = C.CTRL_ATTR_POLICY_DUMP
CTRL_ATTR_POLICY_DUMP_MAX = C.CTRL_ATTR_POLICY_DUMP_MAX
)
// CPU affinity

// cpuMask is the unexported Go name for C's __cpu_mask.
type cpuMask C.__cpu_mask
// Unexported copies of the C macros __CPU_SETSIZE and __NCPUBITS.
const (
_CPU_SETSIZE = C.__CPU_SETSIZE
_NCPUBITS = C.__NCPUBITS
)
// Bluetooth
//
// BDADDR_* constants. Each one takes the value of the C constant of the same
// name.
const (
BDADDR_BREDR = C.BDADDR_BREDR
BDADDR_LE_PUBLIC = C.BDADDR_LE_PUBLIC
BDADDR_LE_RANDOM = C.BDADDR_LE_RANDOM
)
// Perf subsystem

// PerfEventAttr is based on perf_event_attr_go, the modified copy of the kernel
// struct declared in the C preamble of this file, not on perf_event_attr itself.
type PerfEventAttr C.struct_perf_event_attr_go
// PerfEventMmapPage is the Go name for C's struct perf_event_mmap_page.
type PerfEventMmapPage C.struct_perf_event_mmap_page
// Bit field in struct perf_event_attr expanded as flags.
// Set these on PerfEventAttr.Bits by ORing them together.
//
// PerfBit<N> is CBitFieldMaskBit<N> for N = 0..27, in declaration order. Only
// the first constant carries an explicit uint64 type; the rest are declared
// without one.
const (
PerfBitDisabled uint64 = CBitFieldMaskBit0
PerfBitInherit = CBitFieldMaskBit1
PerfBitPinned = CBitFieldMaskBit2
PerfBitExclusive = CBitFieldMaskBit3
PerfBitExcludeUser = CBitFieldMaskBit4
PerfBitExcludeKernel = CBitFieldMaskBit5
PerfBitExcludeHv = CBitFieldMaskBit6
PerfBitExcludeIdle = CBitFieldMaskBit7
PerfBitMmap = CBitFieldMaskBit8
PerfBitComm = CBitFieldMaskBit9
PerfBitFreq = CBitFieldMaskBit10
PerfBitInheritStat = CBitFieldMaskBit11
PerfBitEnableOnExec = CBitFieldMaskBit12
PerfBitTask = CBitFieldMaskBit13
PerfBitWatermark = CBitFieldMaskBit14
PerfBitPreciseIPBit1 = CBitFieldMaskBit15
PerfBitPreciseIPBit2 = CBitFieldMaskBit16
PerfBitMmapData = CBitFieldMaskBit17
PerfBitSampleIDAll = CBitFieldMaskBit18
PerfBitExcludeHost = CBitFieldMaskBit19
PerfBitExcludeGuest = CBitFieldMaskBit20
PerfBitExcludeCallchainKernel = CBitFieldMaskBit21
PerfBitExcludeCallchainUser = CBitFieldMaskBit22
PerfBitMmap2 = CBitFieldMaskBit23
PerfBitCommExec = CBitFieldMaskBit24
PerfBitUseClockID = CBitFieldMaskBit25
PerfBitContextSwitch = CBitFieldMaskBit26
PerfBitWriteBackward = CBitFieldMaskBit27
)
// generated by:
// perl -nlE '/^\s*(PERF_\w+)/ && say "$1 = C.$1"' /usr/include/linux/perf_event.h
//
// Each PERF_* constant takes the value of the C constant of the same name. The
// group comments below only describe the shared name prefix.
const (
// PERF_TYPE_* constants.
PERF_TYPE_HARDWARE = C.PERF_TYPE_HARDWARE
PERF_TYPE_SOFTWARE = C.PERF_TYPE_SOFTWARE
PERF_TYPE_TRACEPOINT = C.PERF_TYPE_TRACEPOINT
PERF_TYPE_HW_CACHE = C.PERF_TYPE_HW_CACHE
PERF_TYPE_RAW = C.PERF_TYPE_RAW
PERF_TYPE_BREAKPOINT = C.PERF_TYPE_BREAKPOINT
PERF_TYPE_MAX = C.PERF_TYPE_MAX
// PERF_COUNT_HW_* constants.
PERF_COUNT_HW_CPU_CYCLES = C.PERF_COUNT_HW_CPU_CYCLES
PERF_COUNT_HW_INSTRUCTIONS = C.PERF_COUNT_HW_INSTRUCTIONS
PERF_COUNT_HW_CACHE_REFERENCES = C.PERF_COUNT_HW_CACHE_REFERENCES
PERF_COUNT_HW_CACHE_MISSES = C.PERF_COUNT_HW_CACHE_MISSES
PERF_COUNT_HW_BRANCH_INSTRUCTIONS = C.PERF_COUNT_HW_BRANCH_INSTRUCTIONS
PERF_COUNT_HW_BRANCH_MISSES = C.PERF_COUNT_HW_BRANCH_MISSES
PERF_COUNT_HW_BUS_CYCLES = C.PERF_COUNT_HW_BUS_CYCLES
PERF_COUNT_HW_STALLED_CYCLES_FRONTEND = C.PERF_COUNT_HW_STALLED_CYCLES_FRONTEND
PERF_COUNT_HW_STALLED_CYCLES_BACKEND = C.PERF_COUNT_HW_STALLED_CYCLES_BACKEND
PERF_COUNT_HW_REF_CPU_CYCLES = C.PERF_COUNT_HW_REF_CPU_CYCLES
PERF_COUNT_HW_MAX = C.PERF_COUNT_HW_MAX
// PERF_COUNT_HW_CACHE_* constants.
PERF_COUNT_HW_CACHE_L1D = C.PERF_COUNT_HW_CACHE_L1D
PERF_COUNT_HW_CACHE_L1I = C.PERF_COUNT_HW_CACHE_L1I
PERF_COUNT_HW_CACHE_LL = C.PERF_COUNT_HW_CACHE_LL
PERF_COUNT_HW_CACHE_DTLB = C.PERF_COUNT_HW_CACHE_DTLB
PERF_COUNT_HW_CACHE_ITLB = C.PERF_COUNT_HW_CACHE_ITLB
PERF_COUNT_HW_CACHE_BPU = C.PERF_COUNT_HW_CACHE_BPU
PERF_COUNT_HW_CACHE_NODE = C.PERF_COUNT_HW_CACHE_NODE
PERF_COUNT_HW_CACHE_MAX = C.PERF_COUNT_HW_CACHE_MAX
PERF_COUNT_HW_CACHE_OP_READ = C.PERF_COUNT_HW_CACHE_OP_READ
PERF_COUNT_HW_CACHE_OP_WRITE = C.PERF_COUNT_HW_CACHE_OP_WRITE
PERF_COUNT_HW_CACHE_OP_PREFETCH = C.PERF_COUNT_HW_CACHE_OP_PREFETCH
PERF_COUNT_HW_CACHE_OP_MAX = C.PERF_COUNT_HW_CACHE_OP_MAX
PERF_COUNT_HW_CACHE_RESULT_ACCESS = C.PERF_COUNT_HW_CACHE_RESULT_ACCESS
PERF_COUNT_HW_CACHE_RESULT_MISS = C.PERF_COUNT_HW_CACHE_RESULT_MISS
PERF_COUNT_HW_CACHE_RESULT_MAX = C.PERF_COUNT_HW_CACHE_RESULT_MAX
// PERF_COUNT_SW_* constants.
PERF_COUNT_SW_CPU_CLOCK = C.PERF_COUNT_SW_CPU_CLOCK
PERF_COUNT_SW_TASK_CLOCK = C.PERF_COUNT_SW_TASK_CLOCK
PERF_COUNT_SW_PAGE_FAULTS = C.PERF_COUNT_SW_PAGE_FAULTS
PERF_COUNT_SW_CONTEXT_SWITCHES = C.PERF_COUNT_SW_CONTEXT_SWITCHES
PERF_COUNT_SW_CPU_MIGRATIONS = C.PERF_COUNT_SW_CPU_MIGRATIONS
PERF_COUNT_SW_PAGE_FAULTS_MIN = C.PERF_COUNT_SW_PAGE_FAULTS_MIN
PERF_COUNT_SW_PAGE_FAULTS_MAJ = C.PERF_COUNT_SW_PAGE_FAULTS_MAJ
PERF_COUNT_SW_ALIGNMENT_FAULTS = C.PERF_COUNT_SW_ALIGNMENT_FAULTS
PERF_COUNT_SW_EMULATION_FAULTS = C.PERF_COUNT_SW_EMULATION_FAULTS
PERF_COUNT_SW_DUMMY = C.PERF_COUNT_SW_DUMMY
PERF_COUNT_SW_BPF_OUTPUT = C.PERF_COUNT_SW_BPF_OUTPUT
PERF_COUNT_SW_MAX = C.PERF_COUNT_SW_MAX
// PERF_SAMPLE_* constants.
PERF_SAMPLE_IP = C.PERF_SAMPLE_IP
PERF_SAMPLE_TID = C.PERF_SAMPLE_TID
PERF_SAMPLE_TIME = C.PERF_SAMPLE_TIME
PERF_SAMPLE_ADDR = C.PERF_SAMPLE_ADDR
PERF_SAMPLE_READ = C.PERF_SAMPLE_READ
PERF_SAMPLE_CALLCHAIN = C.PERF_SAMPLE_CALLCHAIN
PERF_SAMPLE_ID = C.PERF_SAMPLE_ID
PERF_SAMPLE_CPU = C.PERF_SAMPLE_CPU
PERF_SAMPLE_PERIOD = C.PERF_SAMPLE_PERIOD
PERF_SAMPLE_STREAM_ID = C.PERF_SAMPLE_STREAM_ID
PERF_SAMPLE_RAW = C.PERF_SAMPLE_RAW
PERF_SAMPLE_BRANCH_STACK = C.PERF_SAMPLE_BRANCH_STACK
PERF_SAMPLE_REGS_USER = C.PERF_SAMPLE_REGS_USER
PERF_SAMPLE_STACK_USER = C.PERF_SAMPLE_STACK_USER
PERF_SAMPLE_WEIGHT = C.PERF_SAMPLE_WEIGHT
PERF_SAMPLE_DATA_SRC = C.PERF_SAMPLE_DATA_SRC
PERF_SAMPLE_IDENTIFIER = C.PERF_SAMPLE_IDENTIFIER
PERF_SAMPLE_TRANSACTION = C.PERF_SAMPLE_TRANSACTION
PERF_SAMPLE_REGS_INTR = C.PERF_SAMPLE_REGS_INTR
PERF_SAMPLE_PHYS_ADDR = C.PERF_SAMPLE_PHYS_ADDR
PERF_SAMPLE_AUX = C.PERF_SAMPLE_AUX
PERF_SAMPLE_CGROUP = C.PERF_SAMPLE_CGROUP
PERF_SAMPLE_DATA_PAGE_SIZE = C.PERF_SAMPLE_DATA_PAGE_SIZE
PERF_SAMPLE_CODE_PAGE_SIZE = C.PERF_SAMPLE_CODE_PAGE_SIZE
PERF_SAMPLE_WEIGHT_STRUCT = C.PERF_SAMPLE_WEIGHT_STRUCT
PERF_SAMPLE_MAX = C.PERF_SAMPLE_MAX
// PERF_SAMPLE_BRANCH_*_SHIFT constants.
PERF_SAMPLE_BRANCH_USER_SHIFT = C.PERF_SAMPLE_BRANCH_USER_SHIFT
PERF_SAMPLE_BRANCH_KERNEL_SHIFT = C.PERF_SAMPLE_BRANCH_KERNEL_SHIFT
PERF_SAMPLE_BRANCH_HV_SHIFT = C.PERF_SAMPLE_BRANCH_HV_SHIFT
PERF_SAMPLE_BRANCH_ANY_SHIFT = C.PERF_SAMPLE_BRANCH_ANY_SHIFT
PERF_SAMPLE_BRANCH_ANY_CALL_SHIFT = C.PERF_SAMPLE_BRANCH_ANY_CALL_SHIFT
PERF_SAMPLE_BRANCH_ANY_RETURN_SHIFT = C.PERF_SAMPLE_BRANCH_ANY_RETURN_SHIFT
PERF_SAMPLE_BRANCH_IND_CALL_SHIFT = C.PERF_SAMPLE_BRANCH_IND_CALL_SHIFT
PERF_SAMPLE_BRANCH_ABORT_TX_SHIFT = C.PERF_SAMPLE_BRANCH_ABORT_TX_SHIFT
PERF_SAMPLE_BRANCH_IN_TX_SHIFT = C.PERF_SAMPLE_BRANCH_IN_TX_SHIFT
PERF_SAMPLE_BRANCH_NO_TX_SHIFT = C.PERF_SAMPLE_BRANCH_NO_TX_SHIFT
PERF_SAMPLE_BRANCH_COND_SHIFT = C.PERF_SAMPLE_BRANCH_COND_SHIFT
PERF_SAMPLE_BRANCH_CALL_STACK_SHIFT = C.PERF_SAMPLE_BRANCH_CALL_STACK_SHIFT
PERF_SAMPLE_BRANCH_IND_JUMP_SHIFT = C.PERF_SAMPLE_BRANCH_IND_JUMP_SHIFT
PERF_SAMPLE_BRANCH_CALL_SHIFT = C.PERF_SAMPLE_BRANCH_CALL_SHIFT
PERF_SAMPLE_BRANCH_NO_FLAGS_SHIFT = C.PERF_SAMPLE_BRANCH_NO_FLAGS_SHIFT
PERF_SAMPLE_BRANCH_NO_CYCLES_SHIFT = C.PERF_SAMPLE_BRANCH_NO_CYCLES_SHIFT
PERF_SAMPLE_BRANCH_TYPE_SAVE_SHIFT = C.PERF_SAMPLE_BRANCH_TYPE_SAVE_SHIFT
PERF_SAMPLE_BRANCH_HW_INDEX_SHIFT = C.PERF_SAMPLE_BRANCH_HW_INDEX_SHIFT
PERF_SAMPLE_BRANCH_PRIV_SAVE_SHIFT = C.PERF_SAMPLE_BRANCH_PRIV_SAVE_SHIFT
PERF_SAMPLE_BRANCH_MAX_SHIFT = C.PERF_SAMPLE_BRANCH_MAX_SHIFT
// PERF_SAMPLE_BRANCH_* constants (without the _SHIFT suffix).
PERF_SAMPLE_BRANCH_USER = C.PERF_SAMPLE_BRANCH_USER
PERF_SAMPLE_BRANCH_KERNEL = C.PERF_SAMPLE_BRANCH_KERNEL
PERF_SAMPLE_BRANCH_HV = C.PERF_SAMPLE_BRANCH_HV
PERF_SAMPLE_BRANCH_ANY = C.PERF_SAMPLE_BRANCH_ANY
PERF_SAMPLE_BRANCH_ANY_CALL = C.PERF_SAMPLE_BRANCH_ANY_CALL
PERF_SAMPLE_BRANCH_ANY_RETURN = C.PERF_SAMPLE_BRANCH_ANY_RETURN
PERF_SAMPLE_BRANCH_IND_CALL = C.PERF_SAMPLE_BRANCH_IND_CALL
PERF_SAMPLE_BRANCH_ABORT_TX = C.PERF_SAMPLE_BRANCH_ABORT_TX
PERF_SAMPLE_BRANCH_IN_TX = C.PERF_SAMPLE_BRANCH_IN_TX
PERF_SAMPLE_BRANCH_NO_TX = C.PERF_SAMPLE_BRANCH_NO_TX
PERF_SAMPLE_BRANCH_COND = C.PERF_SAMPLE_BRANCH_COND
PERF_SAMPLE_BRANCH_CALL_STACK = C.PERF_SAMPLE_BRANCH_CALL_STACK
PERF_SAMPLE_BRANCH_IND_JUMP = C.PERF_SAMPLE_BRANCH_IND_JUMP
PERF_SAMPLE_BRANCH_CALL = C.PERF_SAMPLE_BRANCH_CALL
PERF_SAMPLE_BRANCH_NO_FLAGS = C.PERF_SAMPLE_BRANCH_NO_FLAGS
PERF_SAMPLE_BRANCH_NO_CYCLES = C.PERF_SAMPLE_BRANCH_NO_CYCLES
PERF_SAMPLE_BRANCH_TYPE_SAVE = C.PERF_SAMPLE_BRANCH_TYPE_SAVE
PERF_SAMPLE_BRANCH_HW_INDEX = C.PERF_SAMPLE_BRANCH_HW_INDEX
PERF_SAMPLE_BRANCH_PRIV_SAVE = C.PERF_SAMPLE_BRANCH_PRIV_SAVE
PERF_SAMPLE_BRANCH_MAX = C.PERF_SAMPLE_BRANCH_MAX
// PERF_BR_* constants.
PERF_BR_UNKNOWN = C.PERF_BR_UNKNOWN
PERF_BR_COND = C.PERF_BR_COND
PERF_BR_UNCOND = C.PERF_BR_UNCOND
PERF_BR_IND = C.PERF_BR_IND
PERF_BR_CALL = C.PERF_BR_CALL
PERF_BR_IND_CALL = C.PERF_BR_IND_CALL
PERF_BR_RET = C.PERF_BR_RET
PERF_BR_SYSCALL = C.PERF_BR_SYSCALL
PERF_BR_SYSRET = C.PERF_BR_SYSRET
PERF_BR_COND_CALL = C.PERF_BR_COND_CALL
PERF_BR_COND_RET = C.PERF_BR_COND_RET
PERF_BR_ERET = C.PERF_BR_ERET
PERF_BR_IRQ = C.PERF_BR_IRQ
PERF_BR_SERROR = C.PERF_BR_SERROR
PERF_BR_NO_TX = C.PERF_BR_NO_TX
PERF_BR_EXTEND_ABI = C.PERF_BR_EXTEND_ABI
PERF_BR_MAX = C.PERF_BR_MAX
// PERF_SAMPLE_REGS_ABI_* constants.
PERF_SAMPLE_REGS_ABI_NONE = C.PERF_SAMPLE_REGS_ABI_NONE
PERF_SAMPLE_REGS_ABI_32 = C.PERF_SAMPLE_REGS_ABI_32
PERF_SAMPLE_REGS_ABI_64 = C.PERF_SAMPLE_REGS_ABI_64
// PERF_TXN_* constants.
PERF_TXN_ELISION = C.PERF_TXN_ELISION
PERF_TXN_TRANSACTION = C.PERF_TXN_TRANSACTION
PERF_TXN_SYNC = C.PERF_TXN_SYNC
PERF_TXN_ASYNC = C.PERF_TXN_ASYNC
PERF_TXN_RETRY = C.PERF_TXN_RETRY
PERF_TXN_CONFLICT = C.PERF_TXN_CONFLICT
PERF_TXN_CAPACITY_WRITE = C.PERF_TXN_CAPACITY_WRITE
PERF_TXN_CAPACITY_READ = C.PERF_TXN_CAPACITY_READ
PERF_TXN_MAX = C.PERF_TXN_MAX
PERF_TXN_ABORT_MASK = C.PERF_TXN_ABORT_MASK
PERF_TXN_ABORT_SHIFT = C.PERF_TXN_ABORT_SHIFT
// PERF_FORMAT_* and PERF_IOC_* constants.
PERF_FORMAT_TOTAL_TIME_ENABLED = C.PERF_FORMAT_TOTAL_TIME_ENABLED
PERF_FORMAT_TOTAL_TIME_RUNNING = C.PERF_FORMAT_TOTAL_TIME_RUNNING
PERF_FORMAT_ID = C.PERF_FORMAT_ID
PERF_FORMAT_GROUP = C.PERF_FORMAT_GROUP
PERF_FORMAT_LOST = C.PERF_FORMAT_LOST
PERF_FORMAT_MAX = C.PERF_FORMAT_MAX
PERF_IOC_FLAG_GROUP = C.PERF_IOC_FLAG_GROUP
// PERF_RECORD_* constants.
PERF_RECORD_MMAP = C.PERF_RECORD_MMAP
PERF_RECORD_LOST = C.PERF_RECORD_LOST
PERF_RECORD_COMM = C.PERF_RECORD_COMM
PERF_RECORD_EXIT = C.PERF_RECORD_EXIT
PERF_RECORD_THROTTLE = C.PERF_RECORD_THROTTLE
PERF_RECORD_UNTHROTTLE = C.PERF_RECORD_UNTHROTTLE
PERF_RECORD_FORK = C.PERF_RECORD_FORK
PERF_RECORD_READ = C.PERF_RECORD_READ
PERF_RECORD_SAMPLE = C.PERF_RECORD_SAMPLE
PERF_RECORD_MMAP2 = C.PERF_RECORD_MMAP2
PERF_RECORD_AUX = C.PERF_RECORD_AUX
PERF_RECORD_ITRACE_START = C.PERF_RECORD_ITRACE_START
PERF_RECORD_LOST_SAMPLES = C.PERF_RECORD_LOST_SAMPLES
PERF_RECORD_SWITCH = C.PERF_RECORD_SWITCH
PERF_RECORD_SWITCH_CPU_WIDE = C.PERF_RECORD_SWITCH_CPU_WIDE
PERF_RECORD_NAMESPACES = C.PERF_RECORD_NAMESPACES
PERF_RECORD_KSYMBOL = C.PERF_RECORD_KSYMBOL
PERF_RECORD_BPF_EVENT = C.PERF_RECORD_BPF_EVENT
PERF_RECORD_CGROUP = C.PERF_RECORD_CGROUP
PERF_RECORD_TEXT_POKE = C.PERF_RECORD_TEXT_POKE
PERF_RECORD_AUX_OUTPUT_HW_ID = C.PERF_RECORD_AUX_OUTPUT_HW_ID
PERF_RECORD_MAX = C.PERF_RECORD_MAX
PERF_RECORD_KSYMBOL_TYPE_UNKNOWN = C.PERF_RECORD_KSYMBOL_TYPE_UNKNOWN
PERF_RECORD_KSYMBOL_TYPE_BPF = C.PERF_RECORD_KSYMBOL_TYPE_BPF
PERF_RECORD_KSYMBOL_TYPE_OOL = C.PERF_RECORD_KSYMBOL_TYPE_OOL
PERF_RECORD_KSYMBOL_TYPE_MAX = C.PERF_RECORD_KSYMBOL_TYPE_MAX
// PERF_BPF_EVENT_* constants.
PERF_BPF_EVENT_UNKNOWN = C.PERF_BPF_EVENT_UNKNOWN
PERF_BPF_EVENT_PROG_LOAD = C.PERF_BPF_EVENT_PROG_LOAD
PERF_BPF_EVENT_PROG_UNLOAD = C.PERF_BPF_EVENT_PROG_UNLOAD
PERF_BPF_EVENT_MAX = C.PERF_BPF_EVENT_MAX
// PERF_CONTEXT_* constants.
PERF_CONTEXT_HV = C.PERF_CONTEXT_HV
PERF_CONTEXT_KERNEL = C.PERF_CONTEXT_KERNEL
PERF_CONTEXT_USER = C.PERF_CONTEXT_USER
PERF_CONTEXT_GUEST = C.PERF_CONTEXT_GUEST
PERF_CONTEXT_GUEST_KERNEL = C.PERF_CONTEXT_GUEST_KERNEL
PERF_CONTEXT_GUEST_USER = C.PERF_CONTEXT_GUEST_USER
PERF_CONTEXT_MAX = C.PERF_CONTEXT_MAX
)
// Platform ABI and calling convention
// Bit field masks for interoperability with C code that uses bit fields.
// Each mask corresponds to a single bit set - bit field behavior can be replicated by combining
// the masks with bitwise OR.
//
// CBitFieldMaskBit<N> is taken from the C symbol BITFIELD_MASK_<N>, for
// N = 0..63.
const (
CBitFieldMaskBit0 = C.BITFIELD_MASK_0
CBitFieldMaskBit1 = C.BITFIELD_MASK_1
CBitFieldMaskBit2 = C.BITFIELD_MASK_2
CBitFieldMaskBit3 = C.BITFIELD_MASK_3
CBitFieldMaskBit4 = C.BITFIELD_MASK_4
CBitFieldMaskBit5 = C.BITFIELD_MASK_5
CBitFieldMaskBit6 = C.BITFIELD_MASK_6
CBitFieldMaskBit7 = C.BITFIELD_MASK_7
CBitFieldMaskBit8 = C.BITFIELD_MASK_8
CBitFieldMaskBit9 = C.BITFIELD_MASK_9
CBitFieldMaskBit10 = C.BITFIELD_MASK_10
CBitFieldMaskBit11 = C.BITFIELD_MASK_11
CBitFieldMaskBit12 = C.BITFIELD_MASK_12
CBitFieldMaskBit13 = C.BITFIELD_MASK_13
CBitFieldMaskBit14 = C.BITFIELD_MASK_14
CBitFieldMaskBit15 = C.BITFIELD_MASK_15
CBitFieldMaskBit16 = C.BITFIELD_MASK_16
CBitFieldMaskBit17 = C.BITFIELD_MASK_17
CBitFieldMaskBit18 = C.BITFIELD_MASK_18
CBitFieldMaskBit19 = C.BITFIELD_MASK_19
CBitFieldMaskBit20 = C.BITFIELD_MASK_20
CBitFieldMaskBit21 = C.BITFIELD_MASK_21
CBitFieldMaskBit22 = C.BITFIELD_MASK_22
CBitFieldMaskBit23 = C.BITFIELD_MASK_23
CBitFieldMaskBit24 = C.BITFIELD_MASK_24
CBitFieldMaskBit25 = C.BITFIELD_MASK_25
CBitFieldMaskBit26 = C.BITFIELD_MASK_26
CBitFieldMaskBit27 = C.BITFIELD_MASK_27
CBitFieldMaskBit28 = C.BITFIELD_MASK_28
CBitFieldMaskBit29 = C.BITFIELD_MASK_29
CBitFieldMaskBit30 = C.BITFIELD_MASK_30
CBitFieldMaskBit31 = C.BITFIELD_MASK_31
CBitFieldMaskBit32 = C.BITFIELD_MASK_32
CBitFieldMaskBit33 = C.BITFIELD_MASK_33
CBitFieldMaskBit34 = C.BITFIELD_MASK_34
CBitFieldMaskBit35 = C.BITFIELD_MASK_35
CBitFieldMaskBit36 = C.BITFIELD_MASK_36
CBitFieldMaskBit37 = C.BITFIELD_MASK_37
CBitFieldMaskBit38 = C.BITFIELD_MASK_38
CBitFieldMaskBit39 = C.BITFIELD_MASK_39
CBitFieldMaskBit40 = C.BITFIELD_MASK_40
CBitFieldMaskBit41 = C.BITFIELD_MASK_41
CBitFieldMaskBit42 = C.BITFIELD_MASK_42
CBitFieldMaskBit43 = C.BITFIELD_MASK_43
CBitFieldMaskBit44 = C.BITFIELD_MASK_44
CBitFieldMaskBit45 = C.BITFIELD_MASK_45
CBitFieldMaskBit46 = C.BITFIELD_MASK_46
CBitFieldMaskBit47 = C.BITFIELD_MASK_47
CBitFieldMaskBit48 = C.BITFIELD_MASK_48
CBitFieldMaskBit49 = C.BITFIELD_MASK_49
CBitFieldMaskBit50 = C.BITFIELD_MASK_50
CBitFieldMaskBit51 = C.BITFIELD_MASK_51
CBitFieldMaskBit52 = C.BITFIELD_MASK_52
CBitFieldMaskBit53 = C.BITFIELD_MASK_53
CBitFieldMaskBit54 = C.BITFIELD_MASK_54
CBitFieldMaskBit55 = C.BITFIELD_MASK_55
CBitFieldMaskBit56 = C.BITFIELD_MASK_56
CBitFieldMaskBit57 = C.BITFIELD_MASK_57
CBitFieldMaskBit58 = C.BITFIELD_MASK_58
CBitFieldMaskBit59 = C.BITFIELD_MASK_59
CBitFieldMaskBit60 = C.BITFIELD_MASK_60
CBitFieldMaskBit61 = C.BITFIELD_MASK_61
CBitFieldMaskBit62 = C.BITFIELD_MASK_62
CBitFieldMaskBit63 = C.BITFIELD_MASK_63
)
// TCP-MD5 signature.

// SockaddrStorage is the Go name for C's struct sockaddr_storage.
type SockaddrStorage C.struct_sockaddr_storage
// TCPMD5Sig is the Go name for C's struct tcp_md5sig.
type TCPMD5Sig C.struct_tcp_md5sig
// Disk drive operations.

// HDDriveCmdHdr is the Go name for C's struct hd_drive_cmd_hdr.
type HDDriveCmdHdr C.struct_hd_drive_cmd_hdr
// HDGeometry is the Go name for C's struct hd_geometry.
type HDGeometry C.struct_hd_geometry
// HDDriveID is the Go name for C's struct hd_driveid.
type HDDriveID C.struct_hd_driveid
// Statfs
type Statfs_t C.struct_statfs
// ST_* constants: each Go name takes the value of the identically named
// C constant.
// NOTE(review): presumably these are the flag bits reported alongside
// Statfs_t — confirm against the C headers.
const (
ST_MANDLOCK = C.ST_MANDLOCK
ST_NOATIME = C.ST_NOATIME
ST_NODEV = C.ST_NODEV
ST_NODIRATIME = C.ST_NODIRATIME
ST_NOEXEC = C.ST_NOEXEC
ST_NOSUID = C.ST_NOSUID
ST_RDONLY = C.ST_RDONLY
ST_RELATIME = C.ST_RELATIME
ST_SYNCHRONOUS = C.ST_SYNCHRONOUS
)
// TPacket
//
// Each type below is the Go name for the C tpacket struct named on its
// right-hand side; no fields are declared on the Go side.
type TpacketHdr C.struct_tpacket_hdr
type Tpacket2Hdr C.struct_tpacket2_hdr
type Tpacket3Hdr C.struct_tpacket3_hdr
type TpacketHdrVariant1 C.struct_tpacket_hdr_variant1
type TpacketBlockDesc C.struct_tpacket_block_desc
type TpacketBDTS C.struct_tpacket_bd_ts
type TpacketHdrV1 C.struct_tpacket_hdr_v1
type TpacketReq C.struct_tpacket_req
type TpacketReq3 C.struct_tpacket_req3
type TpacketStats C.struct_tpacket_stats
type TpacketStatsV3 C.struct_tpacket_stats_v3
type TpacketAuxdata C.struct_tpacket_auxdata
// TPacket versions: Go names for the C constants TPACKET_V1, TPACKET_V2
// and TPACKET_V3.
const (
TPACKET_V1 = C.TPACKET_V1
TPACKET_V2 = C.TPACKET_V2
TPACKET_V3 = C.TPACKET_V3
)
// Sizes of the TPacket header and statistics structs, obtained from the C
// side through cgo's sizeof_struct_<name> pseudo-constants (i.e. the C
// sizeof of each struct).
const (
SizeofTpacketHdr = C.sizeof_struct_tpacket_hdr
SizeofTpacket2Hdr = C.sizeof_struct_tpacket2_hdr
SizeofTpacket3Hdr = C.sizeof_struct_tpacket3_hdr
SizeofTpacketStats = C.sizeof_struct_tpacket_stats
SizeofTpacketStatsV3 = C.sizeof_struct_tpacket_stats_v3
)
// IFLA_* and NETKIT_* constants. Each Go name takes the value of the
// identically named C constant.
//
// The list is emitted mechanically: the one-liner below scans the header
// line by line and, for every line that starts (after optional whitespace)
// with an identifier beginning with IFLA or NETKIT, prints
// "NAME = C.NAME". Entries therefore appear in header order rather than in
// any Go-side grouping.
//
// generated by:
// perl -nlE '/^\s*((IFLA|NETKIT)\w+)/ && say "$1 = C.$1"' /usr/include/linux/if_link.h
const (
IFLA_UNSPEC = C.IFLA_UNSPEC
IFLA_ADDRESS = C.IFLA_ADDRESS
IFLA_BROADCAST = C.IFLA_BROADCAST
IFLA_IFNAME = C.IFLA_IFNAME
IFLA_MTU = C.IFLA_MTU
IFLA_LINK = C.IFLA_LINK
IFLA_QDISC = C.IFLA_QDISC
IFLA_STATS = C.IFLA_STATS
IFLA_COST = C.IFLA_COST
IFLA_PRIORITY = C.IFLA_PRIORITY
IFLA_MASTER = C.IFLA_MASTER
IFLA_WIRELESS = C.IFLA_WIRELESS
IFLA_PROTINFO = C.IFLA_PROTINFO
IFLA_TXQLEN = C.IFLA_TXQLEN
IFLA_MAP = C.IFLA_MAP
IFLA_WEIGHT = C.IFLA_WEIGHT
IFLA_OPERSTATE = C.IFLA_OPERSTATE
IFLA_LINKMODE = C.IFLA_LINKMODE
IFLA_LINKINFO = C.IFLA_LINKINFO
IFLA_NET_NS_PID = C.IFLA_NET_NS_PID
IFLA_IFALIAS = C.IFLA_IFALIAS
IFLA_NUM_VF = C.IFLA_NUM_VF
IFLA_VFINFO_LIST = C.IFLA_VFINFO_LIST
IFLA_STATS64 = C.IFLA_STATS64
IFLA_VF_PORTS = C.IFLA_VF_PORTS
IFLA_PORT_SELF = C.IFLA_PORT_SELF
IFLA_AF_SPEC = C.IFLA_AF_SPEC
IFLA_GROUP = C.IFLA_GROUP
IFLA_NET_NS_FD = C.IFLA_NET_NS_FD
IFLA_EXT_MASK = C.IFLA_EXT_MASK
IFLA_PROMISCUITY = C.IFLA_PROMISCUITY
IFLA_NUM_TX_QUEUES = C.IFLA_NUM_TX_QUEUES
IFLA_NUM_RX_QUEUES = C.IFLA_NUM_RX_QUEUES
IFLA_CARRIER = C.IFLA_CARRIER
IFLA_PHYS_PORT_ID = C.IFLA_PHYS_PORT_ID
IFLA_CARRIER_CHANGES = C.IFLA_CARRIER_CHANGES
IFLA_PHYS_SWITCH_ID = C.IFLA_PHYS_SWITCH_ID
IFLA_LINK_NETNSID = C.IFLA_LINK_NETNSID
IFLA_PHYS_PORT_NAME = C.IFLA_PHYS_PORT_NAME
IFLA_PROTO_DOWN = C.IFLA_PROTO_DOWN
IFLA_GSO_MAX_SEGS = C.IFLA_GSO_MAX_SEGS
IFLA_GSO_MAX_SIZE = C.IFLA_GSO_MAX_SIZE
IFLA_PAD = C.IFLA_PAD
IFLA_XDP = C.IFLA_XDP
IFLA_EVENT = C.IFLA_EVENT
IFLA_NEW_NETNSID = C.IFLA_NEW_NETNSID
IFLA_IF_NETNSID = C.IFLA_IF_NETNSID
IFLA_TARGET_NETNSID = C.IFLA_TARGET_NETNSID
IFLA_CARRIER_UP_COUNT = C.IFLA_CARRIER_UP_COUNT
IFLA_CARRIER_DOWN_COUNT = C.IFLA_CARRIER_DOWN_COUNT
IFLA_NEW_IFINDEX = C.IFLA_NEW_IFINDEX
IFLA_MIN_MTU = C.IFLA_MIN_MTU
IFLA_MAX_MTU = C.IFLA_MAX_MTU
IFLA_PROP_LIST = C.IFLA_PROP_LIST
IFLA_ALT_IFNAME = C.IFLA_ALT_IFNAME
IFLA_PERM_ADDRESS = C.IFLA_PERM_ADDRESS
IFLA_PROTO_DOWN_REASON = C.IFLA_PROTO_DOWN_REASON
IFLA_PARENT_DEV_NAME = C.IFLA_PARENT_DEV_NAME
IFLA_PARENT_DEV_BUS_NAME = C.IFLA_PARENT_DEV_BUS_NAME
IFLA_GRO_MAX_SIZE = C.IFLA_GRO_MAX_SIZE
IFLA_TSO_MAX_SIZE = C.IFLA_TSO_MAX_SIZE
IFLA_TSO_MAX_SEGS = C.IFLA_TSO_MAX_SEGS
IFLA_ALLMULTI = C.IFLA_ALLMULTI
IFLA_DEVLINK_PORT = C.IFLA_DEVLINK_PORT
IFLA_GSO_IPV4_MAX_SIZE = C.IFLA_GSO_IPV4_MAX_SIZE
IFLA_GRO_IPV4_MAX_SIZE = C.IFLA_GRO_IPV4_MAX_SIZE
IFLA_DPLL_PIN = C.IFLA_DPLL_PIN
IFLA_PROTO_DOWN_REASON_UNSPEC = C.IFLA_PROTO_DOWN_REASON_UNSPEC
IFLA_PROTO_DOWN_REASON_MASK = C.IFLA_PROTO_DOWN_REASON_MASK
IFLA_PROTO_DOWN_REASON_VALUE = C.IFLA_PROTO_DOWN_REASON_VALUE
IFLA_PROTO_DOWN_REASON_MAX = C.IFLA_PROTO_DOWN_REASON_MAX
IFLA_INET_UNSPEC = C.IFLA_INET_UNSPEC
IFLA_INET_CONF = C.IFLA_INET_CONF
IFLA_INET6_UNSPEC = C.IFLA_INET6_UNSPEC
IFLA_INET6_FLAGS = C.IFLA_INET6_FLAGS
IFLA_INET6_CONF = C.IFLA_INET6_CONF
IFLA_INET6_STATS = C.IFLA_INET6_STATS
IFLA_INET6_MCAST = C.IFLA_INET6_MCAST
IFLA_INET6_CACHEINFO = C.IFLA_INET6_CACHEINFO
IFLA_INET6_ICMP6STATS = C.IFLA_INET6_ICMP6STATS
IFLA_INET6_TOKEN = C.IFLA_INET6_TOKEN
IFLA_INET6_ADDR_GEN_MODE = C.IFLA_INET6_ADDR_GEN_MODE
IFLA_INET6_RA_MTU = C.IFLA_INET6_RA_MTU
IFLA_BR_UNSPEC = C.IFLA_BR_UNSPEC
IFLA_BR_FORWARD_DELAY = C.IFLA_BR_FORWARD_DELAY
IFLA_BR_HELLO_TIME = C.IFLA_BR_HELLO_TIME
IFLA_BR_MAX_AGE = C.IFLA_BR_MAX_AGE
IFLA_BR_AGEING_TIME = C.IFLA_BR_AGEING_TIME
IFLA_BR_STP_STATE = C.IFLA_BR_STP_STATE
IFLA_BR_PRIORITY = C.IFLA_BR_PRIORITY
IFLA_BR_VLAN_FILTERING = C.IFLA_BR_VLAN_FILTERING
IFLA_BR_VLAN_PROTOCOL = C.IFLA_BR_VLAN_PROTOCOL
IFLA_BR_GROUP_FWD_MASK = C.IFLA_BR_GROUP_FWD_MASK
IFLA_BR_ROOT_ID = C.IFLA_BR_ROOT_ID
IFLA_BR_BRIDGE_ID = C.IFLA_BR_BRIDGE_ID
IFLA_BR_ROOT_PORT = C.IFLA_BR_ROOT_PORT
IFLA_BR_ROOT_PATH_COST = C.IFLA_BR_ROOT_PATH_COST
IFLA_BR_TOPOLOGY_CHANGE = C.IFLA_BR_TOPOLOGY_CHANGE
IFLA_BR_TOPOLOGY_CHANGE_DETECTED = C.IFLA_BR_TOPOLOGY_CHANGE_DETECTED
IFLA_BR_HELLO_TIMER = C.IFLA_BR_HELLO_TIMER
IFLA_BR_TCN_TIMER = C.IFLA_BR_TCN_TIMER
IFLA_BR_TOPOLOGY_CHANGE_TIMER = C.IFLA_BR_TOPOLOGY_CHANGE_TIMER
IFLA_BR_GC_TIMER = C.IFLA_BR_GC_TIMER
IFLA_BR_GROUP_ADDR = C.IFLA_BR_GROUP_ADDR
IFLA_BR_FDB_FLUSH = C.IFLA_BR_FDB_FLUSH
IFLA_BR_MCAST_ROUTER = C.IFLA_BR_MCAST_ROUTER
IFLA_BR_MCAST_SNOOPING = C.IFLA_BR_MCAST_SNOOPING
IFLA_BR_MCAST_QUERY_USE_IFADDR = C.IFLA_BR_MCAST_QUERY_USE_IFADDR
IFLA_BR_MCAST_QUERIER = C.IFLA_BR_MCAST_QUERIER
IFLA_BR_MCAST_HASH_ELASTICITY = C.IFLA_BR_MCAST_HASH_ELASTICITY
IFLA_BR_MCAST_HASH_MAX = C.IFLA_BR_MCAST_HASH_MAX
IFLA_BR_MCAST_LAST_MEMBER_CNT = C.IFLA_BR_MCAST_LAST_MEMBER_CNT
IFLA_BR_MCAST_STARTUP_QUERY_CNT = C.IFLA_BR_MCAST_STARTUP_QUERY_CNT
IFLA_BR_MCAST_LAST_MEMBER_INTVL = C.IFLA_BR_MCAST_LAST_MEMBER_INTVL
IFLA_BR_MCAST_MEMBERSHIP_INTVL = C.IFLA_BR_MCAST_MEMBERSHIP_INTVL
IFLA_BR_MCAST_QUERIER_INTVL = C.IFLA_BR_MCAST_QUERIER_INTVL
IFLA_BR_MCAST_QUERY_INTVL = C.IFLA_BR_MCAST_QUERY_INTVL
IFLA_BR_MCAST_QUERY_RESPONSE_INTVL = C.IFLA_BR_MCAST_QUERY_RESPONSE_INTVL
IFLA_BR_MCAST_STARTUP_QUERY_INTVL = C.IFLA_BR_MCAST_STARTUP_QUERY_INTVL
IFLA_BR_NF_CALL_IPTABLES = C.IFLA_BR_NF_CALL_IPTABLES
IFLA_BR_NF_CALL_IP6TABLES = C.IFLA_BR_NF_CALL_IP6TABLES
IFLA_BR_NF_CALL_ARPTABLES = C.IFLA_BR_NF_CALL_ARPTABLES
IFLA_BR_VLAN_DEFAULT_PVID = C.IFLA_BR_VLAN_DEFAULT_PVID
IFLA_BR_PAD = C.IFLA_BR_PAD
IFLA_BR_VLAN_STATS_ENABLED = C.IFLA_BR_VLAN_STATS_ENABLED
IFLA_BR_MCAST_STATS_ENABLED = C.IFLA_BR_MCAST_STATS_ENABLED
IFLA_BR_MCAST_IGMP_VERSION = C.IFLA_BR_MCAST_IGMP_VERSION
IFLA_BR_MCAST_MLD_VERSION = C.IFLA_BR_MCAST_MLD_VERSION
IFLA_BR_VLAN_STATS_PER_PORT = C.IFLA_BR_VLAN_STATS_PER_PORT
IFLA_BR_MULTI_BOOLOPT = C.IFLA_BR_MULTI_BOOLOPT
IFLA_BR_MCAST_QUERIER_STATE = C.IFLA_BR_MCAST_QUERIER_STATE
IFLA_BR_FDB_N_LEARNED = C.IFLA_BR_FDB_N_LEARNED
IFLA_BR_FDB_MAX_LEARNED = C.IFLA_BR_FDB_MAX_LEARNED
IFLA_BRPORT_UNSPEC = C.IFLA_BRPORT_UNSPEC
IFLA_BRPORT_STATE = C.IFLA_BRPORT_STATE
IFLA_BRPORT_PRIORITY = C.IFLA_BRPORT_PRIORITY
IFLA_BRPORT_COST = C.IFLA_BRPORT_COST
IFLA_BRPORT_MODE = C.IFLA_BRPORT_MODE
IFLA_BRPORT_GUARD = C.IFLA_BRPORT_GUARD
IFLA_BRPORT_PROTECT = C.IFLA_BRPORT_PROTECT
IFLA_BRPORT_FAST_LEAVE = C.IFLA_BRPORT_FAST_LEAVE
IFLA_BRPORT_LEARNING = C.IFLA_BRPORT_LEARNING
IFLA_BRPORT_UNICAST_FLOOD = C.IFLA_BRPORT_UNICAST_FLOOD
IFLA_BRPORT_PROXYARP = C.IFLA_BRPORT_PROXYARP
IFLA_BRPORT_LEARNING_SYNC = C.IFLA_BRPORT_LEARNING_SYNC
IFLA_BRPORT_PROXYARP_WIFI = C.IFLA_BRPORT_PROXYARP_WIFI
IFLA_BRPORT_ROOT_ID = C.IFLA_BRPORT_ROOT_ID
IFLA_BRPORT_BRIDGE_ID = C.IFLA_BRPORT_BRIDGE_ID
IFLA_BRPORT_DESIGNATED_PORT = C.IFLA_BRPORT_DESIGNATED_PORT
IFLA_BRPORT_DESIGNATED_COST = C.IFLA_BRPORT_DESIGNATED_COST
IFLA_BRPORT_ID = C.IFLA_BRPORT_ID
IFLA_BRPORT_NO = C.IFLA_BRPORT_NO
IFLA_BRPORT_TOPOLOGY_CHANGE_ACK = C.IFLA_BRPORT_TOPOLOGY_CHANGE_ACK
IFLA_BRPORT_CONFIG_PENDING = C.IFLA_BRPORT_CONFIG_PENDING
IFLA_BRPORT_MESSAGE_AGE_TIMER = C.IFLA_BRPORT_MESSAGE_AGE_TIMER
IFLA_BRPORT_FORWARD_DELAY_TIMER = C.IFLA_BRPORT_FORWARD_DELAY_TIMER
IFLA_BRPORT_HOLD_TIMER = C.IFLA_BRPORT_HOLD_TIMER
IFLA_BRPORT_FLUSH = C.IFLA_BRPORT_FLUSH
IFLA_BRPORT_MULTICAST_ROUTER = C.IFLA_BRPORT_MULTICAST_ROUTER
IFLA_BRPORT_PAD = C.IFLA_BRPORT_PAD
IFLA_BRPORT_MCAST_FLOOD = C.IFLA_BRPORT_MCAST_FLOOD
IFLA_BRPORT_MCAST_TO_UCAST = C.IFLA_BRPORT_MCAST_TO_UCAST
IFLA_BRPORT_VLAN_TUNNEL = C.IFLA_BRPORT_VLAN_TUNNEL
IFLA_BRPORT_BCAST_FLOOD = C.IFLA_BRPORT_BCAST_FLOOD
IFLA_BRPORT_GROUP_FWD_MASK = C.IFLA_BRPORT_GROUP_FWD_MASK
IFLA_BRPORT_NEIGH_SUPPRESS = C.IFLA_BRPORT_NEIGH_SUPPRESS
IFLA_BRPORT_ISOLATED = C.IFLA_BRPORT_ISOLATED
IFLA_BRPORT_BACKUP_PORT = C.IFLA_BRPORT_BACKUP_PORT
IFLA_BRPORT_MRP_RING_OPEN = C.IFLA_BRPORT_MRP_RING_OPEN
IFLA_BRPORT_MRP_IN_OPEN = C.IFLA_BRPORT_MRP_IN_OPEN
IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT = C.IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT
IFLA_BRPORT_MCAST_EHT_HOSTS_CNT = C.IFLA_BRPORT_MCAST_EHT_HOSTS_CNT
IFLA_BRPORT_LOCKED = C.IFLA_BRPORT_LOCKED
IFLA_BRPORT_MAB = C.IFLA_BRPORT_MAB
IFLA_BRPORT_MCAST_N_GROUPS = C.IFLA_BRPORT_MCAST_N_GROUPS
IFLA_BRPORT_MCAST_MAX_GROUPS = C.IFLA_BRPORT_MCAST_MAX_GROUPS
IFLA_BRPORT_NEIGH_VLAN_SUPPRESS = C.IFLA_BRPORT_NEIGH_VLAN_SUPPRESS
IFLA_BRPORT_BACKUP_NHID = C.IFLA_BRPORT_BACKUP_NHID
IFLA_INFO_UNSPEC = C.IFLA_INFO_UNSPEC
IFLA_INFO_KIND = C.IFLA_INFO_KIND
IFLA_INFO_DATA = C.IFLA_INFO_DATA
IFLA_INFO_XSTATS = C.IFLA_INFO_XSTATS
IFLA_INFO_SLAVE_KIND = C.IFLA_INFO_SLAVE_KIND
IFLA_INFO_SLAVE_DATA = C.IFLA_INFO_SLAVE_DATA
IFLA_VLAN_UNSPEC = C.IFLA_VLAN_UNSPEC
IFLA_VLAN_ID = C.IFLA_VLAN_ID
IFLA_VLAN_FLAGS = C.IFLA_VLAN_FLAGS
IFLA_VLAN_EGRESS_QOS = C.IFLA_VLAN_EGRESS_QOS
IFLA_VLAN_INGRESS_QOS = C.IFLA_VLAN_INGRESS_QOS
IFLA_VLAN_PROTOCOL = C.IFLA_VLAN_PROTOCOL
IFLA_VLAN_QOS_UNSPEC = C.IFLA_VLAN_QOS_UNSPEC
IFLA_VLAN_QOS_MAPPING = C.IFLA_VLAN_QOS_MAPPING
IFLA_MACVLAN_UNSPEC = C.IFLA_MACVLAN_UNSPEC
IFLA_MACVLAN_MODE = C.IFLA_MACVLAN_MODE
IFLA_MACVLAN_FLAGS = C.IFLA_MACVLAN_FLAGS
IFLA_MACVLAN_MACADDR_MODE = C.IFLA_MACVLAN_MACADDR_MODE
IFLA_MACVLAN_MACADDR = C.IFLA_MACVLAN_MACADDR
IFLA_MACVLAN_MACADDR_DATA = C.IFLA_MACVLAN_MACADDR_DATA
IFLA_MACVLAN_MACADDR_COUNT = C.IFLA_MACVLAN_MACADDR_COUNT
IFLA_MACVLAN_BC_QUEUE_LEN = C.IFLA_MACVLAN_BC_QUEUE_LEN
IFLA_MACVLAN_BC_QUEUE_LEN_USED = C.IFLA_MACVLAN_BC_QUEUE_LEN_USED
IFLA_MACVLAN_BC_CUTOFF = C.IFLA_MACVLAN_BC_CUTOFF
IFLA_VRF_UNSPEC = C.IFLA_VRF_UNSPEC
IFLA_VRF_TABLE = C.IFLA_VRF_TABLE
IFLA_VRF_PORT_UNSPEC = C.IFLA_VRF_PORT_UNSPEC
IFLA_VRF_PORT_TABLE = C.IFLA_VRF_PORT_TABLE
IFLA_MACSEC_UNSPEC = C.IFLA_MACSEC_UNSPEC
IFLA_MACSEC_SCI = C.IFLA_MACSEC_SCI
IFLA_MACSEC_PORT = C.IFLA_MACSEC_PORT
IFLA_MACSEC_ICV_LEN = C.IFLA_MACSEC_ICV_LEN
IFLA_MACSEC_CIPHER_SUITE = C.IFLA_MACSEC_CIPHER_SUITE
IFLA_MACSEC_WINDOW = C.IFLA_MACSEC_WINDOW
IFLA_MACSEC_ENCODING_SA = C.IFLA_MACSEC_ENCODING_SA
IFLA_MACSEC_ENCRYPT = C.IFLA_MACSEC_ENCRYPT
IFLA_MACSEC_PROTECT = C.IFLA_MACSEC_PROTECT
IFLA_MACSEC_INC_SCI = C.IFLA_MACSEC_INC_SCI
IFLA_MACSEC_ES = C.IFLA_MACSEC_ES
IFLA_MACSEC_SCB = C.IFLA_MACSEC_SCB
IFLA_MACSEC_REPLAY_PROTECT = C.IFLA_MACSEC_REPLAY_PROTECT
IFLA_MACSEC_VALIDATION = C.IFLA_MACSEC_VALIDATION
IFLA_MACSEC_PAD = C.IFLA_MACSEC_PAD
IFLA_MACSEC_OFFLOAD = C.IFLA_MACSEC_OFFLOAD
IFLA_XFRM_UNSPEC = C.IFLA_XFRM_UNSPEC
IFLA_XFRM_LINK = C.IFLA_XFRM_LINK
IFLA_XFRM_IF_ID = C.IFLA_XFRM_IF_ID
IFLA_XFRM_COLLECT_METADATA = C.IFLA_XFRM_COLLECT_METADATA
IFLA_IPVLAN_UNSPEC = C.IFLA_IPVLAN_UNSPEC
IFLA_IPVLAN_MODE = C.IFLA_IPVLAN_MODE
IFLA_IPVLAN_FLAGS = C.IFLA_IPVLAN_FLAGS
NETKIT_NEXT = C.NETKIT_NEXT
NETKIT_PASS = C.NETKIT_PASS
NETKIT_DROP = C.NETKIT_DROP
NETKIT_REDIRECT = C.NETKIT_REDIRECT
NETKIT_L2 = C.NETKIT_L2
NETKIT_L3 = C.NETKIT_L3
IFLA_NETKIT_UNSPEC = C.IFLA_NETKIT_UNSPEC
IFLA_NETKIT_PEER_INFO = C.IFLA_NETKIT_PEER_INFO
IFLA_NETKIT_PRIMARY = C.IFLA_NETKIT_PRIMARY
IFLA_NETKIT_POLICY = C.IFLA_NETKIT_POLICY
IFLA_NETKIT_PEER_POLICY = C.IFLA_NETKIT_PEER_POLICY
IFLA_NETKIT_MODE = C.IFLA_NETKIT_MODE
IFLA_VXLAN_UNSPEC = C.IFLA_VXLAN_UNSPEC
IFLA_VXLAN_ID = C.IFLA_VXLAN_ID
IFLA_VXLAN_GROUP = C.IFLA_VXLAN_GROUP
IFLA_VXLAN_LINK = C.IFLA_VXLAN_LINK
IFLA_VXLAN_LOCAL = C.IFLA_VXLAN_LOCAL
IFLA_VXLAN_TTL = C.IFLA_VXLAN_TTL
IFLA_VXLAN_TOS = C.IFLA_VXLAN_TOS
IFLA_VXLAN_LEARNING = C.IFLA_VXLAN_LEARNING
IFLA_VXLAN_AGEING = C.IFLA_VXLAN_AGEING
IFLA_VXLAN_LIMIT = C.IFLA_VXLAN_LIMIT
IFLA_VXLAN_PORT_RANGE = C.IFLA_VXLAN_PORT_RANGE
IFLA_VXLAN_PROXY = C.IFLA_VXLAN_PROXY
IFLA_VXLAN_RSC = C.IFLA_VXLAN_RSC
IFLA_VXLAN_L2MISS = C.IFLA_VXLAN_L2MISS
IFLA_VXLAN_L3MISS = C.IFLA_VXLAN_L3MISS
IFLA_VXLAN_PORT = C.IFLA_VXLAN_PORT
IFLA_VXLAN_GROUP6 = C.IFLA_VXLAN_GROUP6
IFLA_VXLAN_LOCAL6 = C.IFLA_VXLAN_LOCAL6
IFLA_VXLAN_UDP_CSUM = C.IFLA_VXLAN_UDP_CSUM
IFLA_VXLAN_UDP_ZERO_CSUM6_TX = C.IFLA_VXLAN_UDP_ZERO_CSUM6_TX
IFLA_VXLAN_UDP_ZERO_CSUM6_RX = C.IFLA_VXLAN_UDP_ZERO_CSUM6_RX
IFLA_VXLAN_REMCSUM_TX = C.IFLA_VXLAN_REMCSUM_TX
IFLA_VXLAN_REMCSUM_RX = C.IFLA_VXLAN_REMCSUM_RX
IFLA_VXLAN_GBP = C.IFLA_VXLAN_GBP
IFLA_VXLAN_REMCSUM_NOPARTIAL = C.IFLA_VXLAN_REMCSUM_NOPARTIAL
IFLA_VXLAN_COLLECT_METADATA = C.IFLA_VXLAN_COLLECT_METADATA
IFLA_VXLAN_LABEL = C.IFLA_VXLAN_LABEL
IFLA_VXLAN_GPE = C.IFLA_VXLAN_GPE
IFLA_VXLAN_TTL_INHERIT = C.IFLA_VXLAN_TTL_INHERIT
IFLA_VXLAN_DF = C.IFLA_VXLAN_DF
IFLA_VXLAN_VNIFILTER = C.IFLA_VXLAN_VNIFILTER
IFLA_VXLAN_LOCALBYPASS = C.IFLA_VXLAN_LOCALBYPASS
IFLA_GENEVE_UNSPEC = C.IFLA_GENEVE_UNSPEC
IFLA_GENEVE_ID = C.IFLA_GENEVE_ID
IFLA_GENEVE_REMOTE = C.IFLA_GENEVE_REMOTE
IFLA_GENEVE_TTL = C.IFLA_GENEVE_TTL
IFLA_GENEVE_TOS = C.IFLA_GENEVE_TOS
IFLA_GENEVE_PORT = C.IFLA_GENEVE_PORT
IFLA_GENEVE_COLLECT_METADATA = C.IFLA_GENEVE_COLLECT_METADATA
IFLA_GENEVE_REMOTE6 = C.IFLA_GENEVE_REMOTE6
IFLA_GENEVE_UDP_CSUM = C.IFLA_GENEVE_UDP_CSUM
IFLA_GENEVE_UDP_ZERO_CSUM6_TX = C.IFLA_GENEVE_UDP_ZERO_CSUM6_TX
IFLA_GENEVE_UDP_ZERO_CSUM6_RX = C.IFLA_GENEVE_UDP_ZERO_CSUM6_RX
IFLA_GENEVE_LABEL = C.IFLA_GENEVE_LABEL
IFLA_GENEVE_TTL_INHERIT = C.IFLA_GENEVE_TTL_INHERIT
IFLA_GENEVE_DF = C.IFLA_GENEVE_DF
IFLA_GENEVE_INNER_PROTO_INHERIT = C.IFLA_GENEVE_INNER_PROTO_INHERIT
IFLA_BAREUDP_UNSPEC = C.IFLA_BAREUDP_UNSPEC
IFLA_BAREUDP_PORT = C.IFLA_BAREUDP_PORT
IFLA_BAREUDP_ETHERTYPE = C.IFLA_BAREUDP_ETHERTYPE
IFLA_BAREUDP_SRCPORT_MIN = C.IFLA_BAREUDP_SRCPORT_MIN
IFLA_BAREUDP_MULTIPROTO_MODE = C.IFLA_BAREUDP_MULTIPROTO_MODE
IFLA_PPP_UNSPEC = C.IFLA_PPP_UNSPEC
IFLA_PPP_DEV_FD = C.IFLA_PPP_DEV_FD
IFLA_GTP_UNSPEC = C.IFLA_GTP_UNSPEC
IFLA_GTP_FD0 = C.IFLA_GTP_FD0
IFLA_GTP_FD1 = C.IFLA_GTP_FD1
IFLA_GTP_PDP_HASHSIZE = C.IFLA_GTP_PDP_HASHSIZE
IFLA_GTP_ROLE = C.IFLA_GTP_ROLE
IFLA_GTP_CREATE_SOCKETS = C.IFLA_GTP_CREATE_SOCKETS
IFLA_GTP_RESTART_COUNT = C.IFLA_GTP_RESTART_COUNT
IFLA_BOND_UNSPEC = C.IFLA_BOND_UNSPEC
IFLA_BOND_MODE = C.IFLA_BOND_MODE
IFLA_BOND_ACTIVE_SLAVE = C.IFLA_BOND_ACTIVE_SLAVE
IFLA_BOND_MIIMON = C.IFLA_BOND_MIIMON
IFLA_BOND_UPDELAY = C.IFLA_BOND_UPDELAY
IFLA_BOND_DOWNDELAY = C.IFLA_BOND_DOWNDELAY
IFLA_BOND_USE_CARRIER = C.IFLA_BOND_USE_CARRIER
IFLA_BOND_ARP_INTERVAL = C.IFLA_BOND_ARP_INTERVAL
IFLA_BOND_ARP_IP_TARGET = C.IFLA_BOND_ARP_IP_TARGET
IFLA_BOND_ARP_VALIDATE = C.IFLA_BOND_ARP_VALIDATE
IFLA_BOND_ARP_ALL_TARGETS = C.IFLA_BOND_ARP_ALL_TARGETS
IFLA_BOND_PRIMARY = C.IFLA_BOND_PRIMARY
IFLA_BOND_PRIMARY_RESELECT = C.IFLA_BOND_PRIMARY_RESELECT
IFLA_BOND_FAIL_OVER_MAC = C.IFLA_BOND_FAIL_OVER_MAC
IFLA_BOND_XMIT_HASH_POLICY = C.IFLA_BOND_XMIT_HASH_POLICY
IFLA_BOND_RESEND_IGMP = C.IFLA_BOND_RESEND_IGMP
IFLA_BOND_NUM_PEER_NOTIF = C.IFLA_BOND_NUM_PEER_NOTIF
IFLA_BOND_ALL_SLAVES_ACTIVE = C.IFLA_BOND_ALL_SLAVES_ACTIVE
IFLA_BOND_MIN_LINKS = C.IFLA_BOND_MIN_LINKS
IFLA_BOND_LP_INTERVAL = C.IFLA_BOND_LP_INTERVAL
IFLA_BOND_PACKETS_PER_SLAVE = C.IFLA_BOND_PACKETS_PER_SLAVE
IFLA_BOND_AD_LACP_RATE = C.IFLA_BOND_AD_LACP_RATE
IFLA_BOND_AD_SELECT = C.IFLA_BOND_AD_SELECT
IFLA_BOND_AD_INFO = C.IFLA_BOND_AD_INFO
IFLA_BOND_AD_ACTOR_SYS_PRIO = C.IFLA_BOND_AD_ACTOR_SYS_PRIO
IFLA_BOND_AD_USER_PORT_KEY = C.IFLA_BOND_AD_USER_PORT_KEY
IFLA_BOND_AD_ACTOR_SYSTEM = C.IFLA_BOND_AD_ACTOR_SYSTEM
IFLA_BOND_TLB_DYNAMIC_LB = C.IFLA_BOND_TLB_DYNAMIC_LB
IFLA_BOND_PEER_NOTIF_DELAY = C.IFLA_BOND_PEER_NOTIF_DELAY
IFLA_BOND_AD_LACP_ACTIVE = C.IFLA_BOND_AD_LACP_ACTIVE
IFLA_BOND_MISSED_MAX = C.IFLA_BOND_MISSED_MAX
IFLA_BOND_NS_IP6_TARGET = C.IFLA_BOND_NS_IP6_TARGET
IFLA_BOND_AD_INFO_UNSPEC = C.IFLA_BOND_AD_INFO_UNSPEC
IFLA_BOND_AD_INFO_AGGREGATOR = C.IFLA_BOND_AD_INFO_AGGREGATOR
IFLA_BOND_AD_INFO_NUM_PORTS = C.IFLA_BOND_AD_INFO_NUM_PORTS
IFLA_BOND_AD_INFO_ACTOR_KEY = C.IFLA_BOND_AD_INFO_ACTOR_KEY
IFLA_BOND_AD_INFO_PARTNER_KEY = C.IFLA_BOND_AD_INFO_PARTNER_KEY
IFLA_BOND_AD_INFO_PARTNER_MAC = C.IFLA_BOND_AD_INFO_PARTNER_MAC
IFLA_BOND_SLAVE_UNSPEC = C.IFLA_BOND_SLAVE_UNSPEC
IFLA_BOND_SLAVE_STATE = C.IFLA_BOND_SLAVE_STATE
IFLA_BOND_SLAVE_MII_STATUS = C.IFLA_BOND_SLAVE_MII_STATUS
IFLA_BOND_SLAVE_LINK_FAILURE_COUNT = C.IFLA_BOND_SLAVE_LINK_FAILURE_COUNT
IFLA_BOND_SLAVE_PERM_HWADDR = C.IFLA_BOND_SLAVE_PERM_HWADDR
IFLA_BOND_SLAVE_QUEUE_ID = C.IFLA_BOND_SLAVE_QUEUE_ID
IFLA_BOND_SLAVE_AD_AGGREGATOR_ID = C.IFLA_BOND_SLAVE_AD_AGGREGATOR_ID
IFLA_BOND_SLAVE_AD_ACTOR_OPER_PORT_STATE = C.IFLA_BOND_SLAVE_AD_ACTOR_OPER_PORT_STATE
IFLA_BOND_SLAVE_AD_PARTNER_OPER_PORT_STATE = C.IFLA_BOND_SLAVE_AD_PARTNER_OPER_PORT_STATE
IFLA_BOND_SLAVE_PRIO = C.IFLA_BOND_SLAVE_PRIO
IFLA_VF_INFO_UNSPEC = C.IFLA_VF_INFO_UNSPEC
IFLA_VF_INFO = C.IFLA_VF_INFO
IFLA_VF_UNSPEC = C.IFLA_VF_UNSPEC
IFLA_VF_MAC = C.IFLA_VF_MAC
IFLA_VF_VLAN = C.IFLA_VF_VLAN
IFLA_VF_TX_RATE = C.IFLA_VF_TX_RATE
IFLA_VF_SPOOFCHK = C.IFLA_VF_SPOOFCHK
IFLA_VF_LINK_STATE = C.IFLA_VF_LINK_STATE
IFLA_VF_RATE = C.IFLA_VF_RATE
IFLA_VF_RSS_QUERY_EN = C.IFLA_VF_RSS_QUERY_EN
IFLA_VF_STATS = C.IFLA_VF_STATS
IFLA_VF_TRUST = C.IFLA_VF_TRUST
IFLA_VF_IB_NODE_GUID = C.IFLA_VF_IB_NODE_GUID
IFLA_VF_IB_PORT_GUID = C.IFLA_VF_IB_PORT_GUID
IFLA_VF_VLAN_LIST = C.IFLA_VF_VLAN_LIST
IFLA_VF_BROADCAST = C.IFLA_VF_BROADCAST
IFLA_VF_VLAN_INFO_UNSPEC = C.IFLA_VF_VLAN_INFO_UNSPEC
IFLA_VF_VLAN_INFO = C.IFLA_VF_VLAN_INFO
IFLA_VF_LINK_STATE_AUTO = C.IFLA_VF_LINK_STATE_AUTO
IFLA_VF_LINK_STATE_ENABLE = C.IFLA_VF_LINK_STATE_ENABLE
IFLA_VF_LINK_STATE_DISABLE = C.IFLA_VF_LINK_STATE_DISABLE
IFLA_VF_STATS_RX_PACKETS = C.IFLA_VF_STATS_RX_PACKETS
IFLA_VF_STATS_TX_PACKETS = C.IFLA_VF_STATS_TX_PACKETS
IFLA_VF_STATS_RX_BYTES = C.IFLA_VF_STATS_RX_BYTES
IFLA_VF_STATS_TX_BYTES = C.IFLA_VF_STATS_TX_BYTES
IFLA_VF_STATS_BROADCAST = C.IFLA_VF_STATS_BROADCAST
IFLA_VF_STATS_MULTICAST = C.IFLA_VF_STATS_MULTICAST
IFLA_VF_STATS_PAD = C.IFLA_VF_STATS_PAD
IFLA_VF_STATS_RX_DROPPED = C.IFLA_VF_STATS_RX_DROPPED
IFLA_VF_STATS_TX_DROPPED = C.IFLA_VF_STATS_TX_DROPPED
IFLA_VF_PORT_UNSPEC = C.IFLA_VF_PORT_UNSPEC
IFLA_VF_PORT = C.IFLA_VF_PORT
IFLA_PORT_UNSPEC = C.IFLA_PORT_UNSPEC
IFLA_PORT_VF = C.IFLA_PORT_VF
IFLA_PORT_PROFILE = C.IFLA_PORT_PROFILE
IFLA_PORT_VSI_TYPE = C.IFLA_PORT_VSI_TYPE
IFLA_PORT_INSTANCE_UUID = C.IFLA_PORT_INSTANCE_UUID
IFLA_PORT_HOST_UUID = C.IFLA_PORT_HOST_UUID
IFLA_PORT_REQUEST = C.IFLA_PORT_REQUEST
IFLA_PORT_RESPONSE = C.IFLA_PORT_RESPONSE
IFLA_IPOIB_UNSPEC = C.IFLA_IPOIB_UNSPEC
IFLA_IPOIB_PKEY = C.IFLA_IPOIB_PKEY
IFLA_IPOIB_MODE = C.IFLA_IPOIB_MODE
IFLA_IPOIB_UMCAST = C.IFLA_IPOIB_UMCAST
IFLA_HSR_UNSPEC = C.IFLA_HSR_UNSPEC
IFLA_HSR_SLAVE1 = C.IFLA_HSR_SLAVE1
IFLA_HSR_SLAVE2 = C.IFLA_HSR_SLAVE2
IFLA_HSR_MULTICAST_SPEC = C.IFLA_HSR_MULTICAST_SPEC
IFLA_HSR_SUPERVISION_ADDR = C.IFLA_HSR_SUPERVISION_ADDR
IFLA_HSR_SEQ_NR = C.IFLA_HSR_SEQ_NR
IFLA_HSR_VERSION = C.IFLA_HSR_VERSION
IFLA_HSR_PROTOCOL = C.IFLA_HSR_PROTOCOL
IFLA_STATS_UNSPEC = C.IFLA_STATS_UNSPEC
IFLA_STATS_LINK_64 = C.IFLA_STATS_LINK_64
IFLA_STATS_LINK_XSTATS = C.IFLA_STATS_LINK_XSTATS
IFLA_STATS_LINK_XSTATS_SLAVE = C.IFLA_STATS_LINK_XSTATS_SLAVE
IFLA_STATS_LINK_OFFLOAD_XSTATS = C.IFLA_STATS_LINK_OFFLOAD_XSTATS
IFLA_STATS_AF_SPEC = C.IFLA_STATS_AF_SPEC
IFLA_STATS_GETSET_UNSPEC = C.IFLA_STATS_GETSET_UNSPEC
IFLA_STATS_GET_FILTERS = C.IFLA_STATS_GET_FILTERS
IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS = C.IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS
IFLA_OFFLOAD_XSTATS_UNSPEC = C.IFLA_OFFLOAD_XSTATS_UNSPEC
IFLA_OFFLOAD_XSTATS_CPU_HIT = C.IFLA_OFFLOAD_XSTATS_CPU_HIT
IFLA_OFFLOAD_XSTATS_HW_S_INFO = C.IFLA_OFFLOAD_XSTATS_HW_S_INFO
IFLA_OFFLOAD_XSTATS_L3_STATS = C.IFLA_OFFLOAD_XSTATS_L3_STATS
IFLA_OFFLOAD_XSTATS_HW_S_INFO_UNSPEC = C.IFLA_OFFLOAD_XSTATS_HW_S_INFO_UNSPEC
IFLA_OFFLOAD_XSTATS_HW_S_INFO_REQUEST = C.IFLA_OFFLOAD_XSTATS_HW_S_INFO_REQUEST
IFLA_OFFLOAD_XSTATS_HW_S_INFO_USED = C.IFLA_OFFLOAD_XSTATS_HW_S_INFO_USED
IFLA_XDP_UNSPEC = C.IFLA_XDP_UNSPEC
IFLA_XDP_FD = C.IFLA_XDP_FD
IFLA_XDP_ATTACHED = C.IFLA_XDP_ATTACHED
IFLA_XDP_FLAGS = C.IFLA_XDP_FLAGS
IFLA_XDP_PROG_ID = C.IFLA_XDP_PROG_ID
IFLA_XDP_DRV_PROG_ID = C.IFLA_XDP_DRV_PROG_ID
IFLA_XDP_SKB_PROG_ID = C.IFLA_XDP_SKB_PROG_ID
IFLA_XDP_HW_PROG_ID = C.IFLA_XDP_HW_PROG_ID
IFLA_XDP_EXPECTED_FD = C.IFLA_XDP_EXPECTED_FD
IFLA_EVENT_NONE = C.IFLA_EVENT_NONE
IFLA_EVENT_REBOOT = C.IFLA_EVENT_REBOOT
IFLA_EVENT_FEATURES = C.IFLA_EVENT_FEATURES
IFLA_EVENT_BONDING_FAILOVER = C.IFLA_EVENT_BONDING_FAILOVER
IFLA_EVENT_NOTIFY_PEERS = C.IFLA_EVENT_NOTIFY_PEERS
IFLA_EVENT_IGMP_RESEND = C.IFLA_EVENT_IGMP_RESEND
IFLA_EVENT_BONDING_OPTIONS = C.IFLA_EVENT_BONDING_OPTIONS
IFLA_TUN_UNSPEC = C.IFLA_TUN_UNSPEC
IFLA_TUN_OWNER = C.IFLA_TUN_OWNER
IFLA_TUN_GROUP = C.IFLA_TUN_GROUP
IFLA_TUN_TYPE = C.IFLA_TUN_TYPE
IFLA_TUN_PI = C.IFLA_TUN_PI
IFLA_TUN_VNET_HDR = C.IFLA_TUN_VNET_HDR
IFLA_TUN_PERSIST = C.IFLA_TUN_PERSIST
IFLA_TUN_MULTI_QUEUE = C.IFLA_TUN_MULTI_QUEUE
IFLA_TUN_NUM_QUEUES = C.IFLA_TUN_NUM_QUEUES
IFLA_TUN_NUM_DISABLED_QUEUES = C.IFLA_TUN_NUM_DISABLED_QUEUES
IFLA_RMNET_UNSPEC = C.IFLA_RMNET_UNSPEC
IFLA_RMNET_MUX_ID = C.IFLA_RMNET_MUX_ID
IFLA_RMNET_FLAGS = C.IFLA_RMNET_FLAGS
IFLA_MCTP_UNSPEC = C.IFLA_MCTP_UNSPEC
IFLA_MCTP_NET = C.IFLA_MCTP_NET
IFLA_DSA_UNSPEC = C.IFLA_DSA_UNSPEC
IFLA_DSA_CONDUIT = C.IFLA_DSA_CONDUIT
IFLA_DSA_MASTER = C.IFLA_DSA_MASTER
)
// netfilter
//
// NF_INET_* constants: each Go name takes the value of the identically
// named C constant. The one-liner below prints "NAME = C.NAME" for every
// header line that starts (after optional whitespace) with an identifier
// beginning with NF.
//
// generated using:
// perl -nlE '/^\s*(NF\w+)/ && say "$1 = C.$1"' /usr/include/linux/netfilter.h
const (
NF_INET_PRE_ROUTING = C.NF_INET_PRE_ROUTING
NF_INET_LOCAL_IN = C.NF_INET_LOCAL_IN
NF_INET_FORWARD = C.NF_INET_FORWARD
NF_INET_LOCAL_OUT = C.NF_INET_LOCAL_OUT
NF_INET_POST_ROUTING = C.NF_INET_POST_ROUTING
NF_INET_NUMHOOKS = C.NF_INET_NUMHOOKS
)
// NF_NETDEV_* constants: each Go name takes the value of the identically
// named C constant.
const (
NF_NETDEV_INGRESS = C.NF_NETDEV_INGRESS
NF_NETDEV_EGRESS = C.NF_NETDEV_EGRESS
NF_NETDEV_NUMHOOKS = C.NF_NETDEV_NUMHOOKS
)
// NFPROTO_* constants: each Go name takes the value of the identically
// named C constant.
const (
NFPROTO_UNSPEC = C.NFPROTO_UNSPEC
NFPROTO_INET = C.NFPROTO_INET
NFPROTO_IPV4 = C.NFPROTO_IPV4
NFPROTO_ARP = C.NFPROTO_ARP
NFPROTO_NETDEV = C.NFPROTO_NETDEV
NFPROTO_BRIDGE = C.NFPROTO_BRIDGE
NFPROTO_IPV6 = C.NFPROTO_IPV6
NFPROTO_DECNET = C.NFPROTO_DECNET
NFPROTO_NUMPROTO = C.NFPROTO_NUMPROTO
)
// SO_ORIGINAL_DST takes the value of the C constant SO_ORIGINAL_DST.
// NOTE(review): presumably the socket option for reading a connection's
// pre-NAT destination address — confirm against the netfilter headers.
const SO_ORIGINAL_DST = C.SO_ORIGINAL_DST
// netfilter nfnetlink

// Nfgenmsg is the Go name for the C struct nfgenmsg.
type Nfgenmsg C.struct_nfgenmsg
// NFNL_BATCH_* constants: each Go name takes the value of the identically
// named C constant.
const (
NFNL_BATCH_UNSPEC = C.NFNL_BATCH_UNSPEC
NFNL_BATCH_GENID = C.NFNL_BATCH_GENID
)
// netfilter nf_tables
// generated using:
// perl -nlE '/^\s*(NFT\w+)/ && say "$1 = C.$1"' /usr/include/linux/netfilter/nf_tables.h
const (
NFT_REG_VERDICT = C.NFT_REG_VERDICT
NFT_REG_1 = C.NFT_REG_1
NFT_REG_2 = C.NFT_REG_2
NFT_REG_3 = C.NFT_REG_3
NFT_REG_4 = C.NFT_REG_4
NFT_REG32_00 = C.NFT_REG32_00
NFT_REG32_01 = C.NFT_REG32_01
NFT_REG32_02 = C.NFT_REG32_02
NFT_REG32_03 = C.NFT_REG32_03
NFT_REG32_04 = C.NFT_REG32_04
NFT_REG32_05 = C.NFT_REG32_05
NFT_REG32_06 = C.NFT_REG32_06
NFT_REG32_07 = C.NFT_REG32_07
NFT_REG32_08 = C.NFT_REG32_08
NFT_REG32_09 = C.NFT_REG32_09
NFT_REG32_10 = C.NFT_REG32_10
NFT_REG32_11 = C.NFT_REG32_11
NFT_REG32_12 = C.NFT_REG32_12
NFT_REG32_13 = C.NFT_REG32_13
NFT_REG32_14 = C.NFT_REG32_14
NFT_REG32_15 = C.NFT_REG32_15
NFT_CONTINUE = C.NFT_CONTINUE
NFT_BREAK = C.NFT_BREAK
NFT_JUMP = C.NFT_JUMP
NFT_GOTO = C.NFT_GOTO
NFT_RETURN = C.NFT_RETURN
NFT_MSG_NEWTABLE = C.NFT_MSG_NEWTABLE
NFT_MSG_GETTABLE = C.NFT_MSG_GETTABLE
NFT_MSG_DELTABLE = C.NFT_MSG_DELTABLE
NFT_MSG_NEWCHAIN = C.NFT_MSG_NEWCHAIN
NFT_MSG_GETCHAIN = C.NFT_MSG_GETCHAIN
NFT_MSG_DELCHAIN = C.NFT_MSG_DELCHAIN
NFT_MSG_NEWRULE = C.NFT_MSG_NEWRULE
NFT_MSG_GETRULE = C.NFT_MSG_GETRULE
NFT_MSG_DELRULE = C.NFT_MSG_DELRULE
NFT_MSG_NEWSET = C.NFT_MSG_NEWSET
NFT_MSG_GETSET = C.NFT_MSG_GETSET
NFT_MSG_DELSET = C.NFT_MSG_DELSET
NFT_MSG_NEWSETELEM = C.NFT_MSG_NEWSETELEM
NFT_MSG_GETSETELEM = C.NFT_MSG_GETSETELEM
NFT_MSG_DELSETELEM = C.NFT_MSG_DELSETELEM
NFT_MSG_NEWGEN = C.NFT_MSG_NEWGEN
NFT_MSG_GETGEN = C.NFT_MSG_GETGEN
NFT_MSG_TRACE = C.NFT_MSG_TRACE
NFT_MSG_NEWOBJ = C.NFT_MSG_NEWOBJ
NFT_MSG_GETOBJ = C.NFT_MSG_GETOBJ
NFT_MSG_DELOBJ = C.NFT_MSG_DELOBJ
NFT_MSG_GETOBJ_RESET = C.NFT_MSG_GETOBJ_RESET
NFT_MSG_NEWFLOWTABLE = C.NFT_MSG_NEWFLOWTABLE
NFT_MSG_GETFLOWTABLE = C.NFT_MSG_GETFLOWTABLE
NFT_MSG_DELFLOWTABLE = C.NFT_MSG_DELFLOWTABLE
NFT_MSG_GETRULE_RESET = C.NFT_MSG_GETRULE_RESET
NFT_MSG_MAX = C.NFT_MSG_MAX
NFTA_LIST_UNSPEC = C.NFTA_LIST_UNSPEC
NFTA_LIST_ELEM = C.NFTA_LIST_ELEM
NFTA_HOOK_UNSPEC = C.NFTA_HOOK_UNSPEC
NFTA_HOOK_HOOKNUM = C.NFTA_HOOK_HOOKNUM
NFTA_HOOK_PRIORITY = C.NFTA_HOOK_PRIORITY
NFTA_HOOK_DEV = C.NFTA_HOOK_DEV
NFT_TABLE_F_DORMANT = C.NFT_TABLE_F_DORMANT
NFTA_TABLE_UNSPEC = C.NFTA_TABLE_UNSPEC
NFTA_TABLE_NAME = C.NFTA_TABLE_NAME
NFTA_TABLE_FLAGS = C.NFTA_TABLE_FLAGS
NFTA_TABLE_USE = C.NFTA_TABLE_USE
NFTA_CHAIN_UNSPEC = C.NFTA_CHAIN_UNSPEC
NFTA_CHAIN_TABLE = C.NFTA_CHAIN_TABLE
NFTA_CHAIN_HANDLE = C.NFTA_CHAIN_HANDLE
NFTA_CHAIN_NAME = C.NFTA_CHAIN_NAME
NFTA_CHAIN_HOOK = C.NFTA_CHAIN_HOOK
NFTA_CHAIN_POLICY = C.NFTA_CHAIN_POLICY
NFTA_CHAIN_USE = C.NFTA_CHAIN_USE
NFTA_CHAIN_TYPE = C.NFTA_CHAIN_TYPE
NFTA_CHAIN_COUNTERS = C.NFTA_CHAIN_COUNTERS
NFTA_CHAIN_PAD = C.NFTA_CHAIN_PAD
NFTA_RULE_UNSPEC = C.NFTA_RULE_UNSPEC
NFTA_RULE_TABLE = C.NFTA_RULE_TABLE
NFTA_RULE_CHAIN = C.NFTA_RULE_CHAIN
NFTA_RULE_HANDLE = C.NFTA_RULE_HANDLE
NFTA_RULE_EXPRESSIONS = C.NFTA_RULE_EXPRESSIONS
NFTA_RULE_COMPAT = C.NFTA_RULE_COMPAT
NFTA_RULE_POSITION = C.NFTA_RULE_POSITION
NFTA_RULE_USERDATA = C.NFTA_RULE_USERDATA
NFTA_RULE_PAD = C.NFTA_RULE_PAD
NFTA_RULE_ID = C.NFTA_RULE_ID
NFT_RULE_COMPAT_F_INV = C.NFT_RULE_COMPAT_F_INV
NFT_RULE_COMPAT_F_MASK = C.NFT_RULE_COMPAT_F_MASK
NFTA_RULE_COMPAT_UNSPEC = C.NFTA_RULE_COMPAT_UNSPEC
NFTA_RULE_COMPAT_PROTO = C.NFTA_RULE_COMPAT_PROTO
NFTA_RULE_COMPAT_FLAGS = C.NFTA_RULE_COMPAT_FLAGS
NFT_SET_ANONYMOUS = C.NFT_SET_ANONYMOUS
NFT_SET_CONSTANT = C.NFT_SET_CONSTANT
NFT_SET_INTERVAL = C.NFT_SET_INTERVAL
NFT_SET_MAP = C.NFT_SET_MAP
NFT_SET_TIMEOUT = C.NFT_SET_TIMEOUT
NFT_SET_EVAL = C.NFT_SET_EVAL
NFT_SET_OBJECT = C.NFT_SET_OBJECT
NFT_SET_POL_PERFORMANCE = C.NFT_SET_POL_PERFORMANCE
NFT_SET_POL_MEMORY = C.NFT_SET_POL_MEMORY
NFTA_SET_DESC_UNSPEC = C.NFTA_SET_DESC_UNSPEC
NFTA_SET_DESC_SIZE = C.NFTA_SET_DESC_SIZE
NFTA_SET_UNSPEC = C.NFTA_SET_UNSPEC
NFTA_SET_TABLE = C.NFTA_SET_TABLE
NFTA_SET_NAME = C.NFTA_SET_NAME
NFTA_SET_FLAGS = C.NFTA_SET_FLAGS
NFTA_SET_KEY_TYPE = C.NFTA_SET_KEY_TYPE
NFTA_SET_KEY_LEN = C.NFTA_SET_KEY_LEN
NFTA_SET_DATA_TYPE = C.NFTA_SET_DATA_TYPE
NFTA_SET_DATA_LEN = C.NFTA_SET_DATA_LEN
NFTA_SET_POLICY = C.NFTA_SET_POLICY
NFTA_SET_DESC = C.NFTA_SET_DESC
NFTA_SET_ID = C.NFTA_SET_ID
NFTA_SET_TIMEOUT = C.NFTA_SET_TIMEOUT
NFTA_SET_GC_INTERVAL = C.NFTA_SET_GC_INTERVAL
NFTA_SET_USERDATA = C.NFTA_SET_USERDATA
NFTA_SET_PAD = C.NFTA_SET_PAD
NFTA_SET_OBJ_TYPE = C.NFTA_SET_OBJ_TYPE
NFT_SET_ELEM_INTERVAL_END = C.NFT_SET_ELEM_INTERVAL_END
NFTA_SET_ELEM_UNSPEC = C.NFTA_SET_ELEM_UNSPEC
NFTA_SET_ELEM_KEY = C.NFTA_SET_ELEM_KEY
NFTA_SET_ELEM_DATA = C.NFTA_SET_ELEM_DATA
NFTA_SET_ELEM_FLAGS = C.NFTA_SET_ELEM_FLAGS
NFTA_SET_ELEM_TIMEOUT = C.NFTA_SET_ELEM_TIMEOUT
NFTA_SET_ELEM_EXPIRATION = C.NFTA_SET_ELEM_EXPIRATION
NFTA_SET_ELEM_USERDATA = C.NFTA_SET_ELEM_USERDATA
NFTA_SET_ELEM_EXPR = C.NFTA_SET_ELEM_EXPR
NFTA_SET_ELEM_PAD = C.NFTA_SET_ELEM_PAD
NFTA_SET_ELEM_OBJREF = C.NFTA_SET_ELEM_OBJREF
NFTA_SET_ELEM_LIST_UNSPEC = C.NFTA_SET_ELEM_LIST_UNSPEC
NFTA_SET_ELEM_LIST_TABLE = C.NFTA_SET_ELEM_LIST_TABLE
NFTA_SET_ELEM_LIST_SET = C.NFTA_SET_ELEM_LIST_SET
NFTA_SET_ELEM_LIST_ELEMENTS = C.NFTA_SET_ELEM_LIST_ELEMENTS
NFTA_SET_ELEM_LIST_SET_ID = C.NFTA_SET_ELEM_LIST_SET_ID
NFT_DATA_VALUE = C.NFT_DATA_VALUE
NFT_DATA_VERDICT = C.NFT_DATA_VERDICT
NFTA_DATA_UNSPEC = C.NFTA_DATA_UNSPEC
NFTA_DATA_VALUE = C.NFTA_DATA_VALUE
NFTA_DATA_VERDICT = C.NFTA_DATA_VERDICT
NFTA_VERDICT_UNSPEC = C.NFTA_VERDICT_UNSPEC
NFTA_VERDICT_CODE = C.NFTA_VERDICT_CODE
NFTA_VERDICT_CHAIN = C.NFTA_VERDICT_CHAIN
NFTA_EXPR_UNSPEC = C.NFTA_EXPR_UNSPEC
NFTA_EXPR_NAME = C.NFTA_EXPR_NAME
NFTA_EXPR_DATA = C.NFTA_EXPR_DATA
NFTA_IMMEDIATE_UNSPEC = C.NFTA_IMMEDIATE_UNSPEC
NFTA_IMMEDIATE_DREG = C.NFTA_IMMEDIATE_DREG
NFTA_IMMEDIATE_DATA = C.NFTA_IMMEDIATE_DATA
NFTA_BITWISE_UNSPEC = C.NFTA_BITWISE_UNSPEC
NFTA_BITWISE_SREG = C.NFTA_BITWISE_SREG
NFTA_BITWISE_DREG = C.NFTA_BITWISE_DREG
NFTA_BITWISE_LEN = C.NFTA_BITWISE_LEN
NFTA_BITWISE_MASK = C.NFTA_BITWISE_MASK
NFTA_BITWISE_XOR = C.NFTA_BITWISE_XOR
NFT_BYTEORDER_NTOH = C.NFT_BYTEORDER_NTOH
NFT_BYTEORDER_HTON = C.NFT_BYTEORDER_HTON
NFTA_BYTEORDER_UNSPEC = C.NFTA_BYTEORDER_UNSPEC
NFTA_BYTEORDER_SREG = C.NFTA_BYTEORDER_SREG
NFTA_BYTEORDER_DREG = C.NFTA_BYTEORDER_DREG
NFTA_BYTEORDER_OP = C.NFTA_BYTEORDER_OP
NFTA_BYTEORDER_LEN = C.NFTA_BYTEORDER_LEN
NFTA_BYTEORDER_SIZE = C.NFTA_BYTEORDER_SIZE
NFT_CMP_EQ = C.NFT_CMP_EQ
NFT_CMP_NEQ = C.NFT_CMP_NEQ
NFT_CMP_LT = C.NFT_CMP_LT
NFT_CMP_LTE = C.NFT_CMP_LTE
NFT_CMP_GT = C.NFT_CMP_GT
NFT_CMP_GTE = C.NFT_CMP_GTE
NFTA_CMP_UNSPEC = C.NFTA_CMP_UNSPEC
NFTA_CMP_SREG = C.NFTA_CMP_SREG
NFTA_CMP_OP = C.NFTA_CMP_OP
NFTA_CMP_DATA = C.NFTA_CMP_DATA
NFT_RANGE_EQ = C.NFT_RANGE_EQ
NFT_RANGE_NEQ = C.NFT_RANGE_NEQ
NFTA_RANGE_UNSPEC = C.NFTA_RANGE_UNSPEC
NFTA_RANGE_SREG = C.NFTA_RANGE_SREG
NFTA_RANGE_OP = C.NFTA_RANGE_OP
NFTA_RANGE_FROM_DATA = C.NFTA_RANGE_FROM_DATA
NFTA_RANGE_TO_DATA = C.NFTA_RANGE_TO_DATA
NFT_LOOKUP_F_INV = C.NFT_LOOKUP_F_INV
NFTA_LOOKUP_UNSPEC = C.NFTA_LOOKUP_UNSPEC
NFTA_LOOKUP_SET = C.NFTA_LOOKUP_SET
NFTA_LOOKUP_SREG = C.NFTA_LOOKUP_SREG
NFTA_LOOKUP_DREG = C.NFTA_LOOKUP_DREG
NFTA_LOOKUP_SET_ID = C.NFTA_LOOKUP_SET_ID
NFTA_LOOKUP_FLAGS = C.NFTA_LOOKUP_FLAGS
NFT_DYNSET_OP_ADD = C.NFT_DYNSET_OP_ADD
NFT_DYNSET_OP_UPDATE = C.NFT_DYNSET_OP_UPDATE
NFT_DYNSET_F_INV = C.NFT_DYNSET_F_INV
NFTA_DYNSET_UNSPEC = C.NFTA_DYNSET_UNSPEC
NFTA_DYNSET_SET_NAME = C.NFTA_DYNSET_SET_NAME
NFTA_DYNSET_SET_ID = C.NFTA_DYNSET_SET_ID
NFTA_DYNSET_OP = C.NFTA_DYNSET_OP
NFTA_DYNSET_SREG_KEY = C.NFTA_DYNSET_SREG_KEY
NFTA_DYNSET_SREG_DATA = C.NFTA_DYNSET_SREG_DATA
NFTA_DYNSET_TIMEOUT = C.NFTA_DYNSET_TIMEOUT
NFTA_DYNSET_EXPR = C.NFTA_DYNSET_EXPR
NFTA_DYNSET_PAD = C.NFTA_DYNSET_PAD
NFTA_DYNSET_FLAGS = C.NFTA_DYNSET_FLAGS
NFT_PAYLOAD_LL_HEADER = C.NFT_PAYLOAD_LL_HEADER
NFT_PAYLOAD_NETWORK_HEADER = C.NFT_PAYLOAD_NETWORK_HEADER
NFT_PAYLOAD_TRANSPORT_HEADER = C.NFT_PAYLOAD_TRANSPORT_HEADER
NFT_PAYLOAD_CSUM_NONE = C.NFT_PAYLOAD_CSUM_NONE
NFT_PAYLOAD_CSUM_INET = C.NFT_PAYLOAD_CSUM_INET
NFT_PAYLOAD_L4CSUM_PSEUDOHDR = C.NFT_PAYLOAD_L4CSUM_PSEUDOHDR
NFTA_PAYLOAD_UNSPEC = C.NFTA_PAYLOAD_UNSPEC
NFTA_PAYLOAD_DREG = C.NFTA_PAYLOAD_DREG
NFTA_PAYLOAD_BASE = C.NFTA_PAYLOAD_BASE
NFTA_PAYLOAD_OFFSET = C.NFTA_PAYLOAD_OFFSET
NFTA_PAYLOAD_LEN = C.NFTA_PAYLOAD_LEN
NFTA_PAYLOAD_SREG = C.NFTA_PAYLOAD_SREG
NFTA_PAYLOAD_CSUM_TYPE = C.NFTA_PAYLOAD_CSUM_TYPE
NFTA_PAYLOAD_CSUM_OFFSET = C.NFTA_PAYLOAD_CSUM_OFFSET
NFTA_PAYLOAD_CSUM_FLAGS = C.NFTA_PAYLOAD_CSUM_FLAGS
NFT_EXTHDR_F_PRESENT = C.NFT_EXTHDR_F_PRESENT
NFT_EXTHDR_OP_IPV6 = C.NFT_EXTHDR_OP_IPV6
NFT_EXTHDR_OP_TCPOPT = C.NFT_EXTHDR_OP_TCPOPT
NFTA_EXTHDR_UNSPEC = C.NFTA_EXTHDR_UNSPEC
NFTA_EXTHDR_DREG = C.NFTA_EXTHDR_DREG
NFTA_EXTHDR_TYPE = C.NFTA_EXTHDR_TYPE
NFTA_EXTHDR_OFFSET = C.NFTA_EXTHDR_OFFSET
NFTA_EXTHDR_LEN = C.NFTA_EXTHDR_LEN
NFTA_EXTHDR_FLAGS = C.NFTA_EXTHDR_FLAGS
NFTA_EXTHDR_OP = C.NFTA_EXTHDR_OP
NFTA_EXTHDR_SREG = C.NFTA_EXTHDR_SREG
NFT_META_LEN = C.NFT_META_LEN
NFT_META_PROTOCOL = C.NFT_META_PROTOCOL
NFT_META_PRIORITY = C.NFT_META_PRIORITY
NFT_META_MARK = C.NFT_META_MARK
NFT_META_IIF = C.NFT_META_IIF
NFT_META_OIF = C.NFT_META_OIF
NFT_META_IIFNAME = C.NFT_META_IIFNAME
NFT_META_OIFNAME = C.NFT_META_OIFNAME
NFT_META_IIFTYPE = C.NFT_META_IIFTYPE
NFT_META_OIFTYPE = C.NFT_META_OIFTYPE
NFT_META_SKUID = C.NFT_META_SKUID
NFT_META_SKGID = C.NFT_META_SKGID
NFT_META_NFTRACE = C.NFT_META_NFTRACE
NFT_META_RTCLASSID = C.NFT_META_RTCLASSID
NFT_META_SECMARK = C.NFT_META_SECMARK
NFT_META_NFPROTO = C.NFT_META_NFPROTO
NFT_META_L4PROTO = C.NFT_META_L4PROTO
NFT_META_BRI_IIFNAME = C.NFT_META_BRI_IIFNAME
NFT_META_BRI_OIFNAME = C.NFT_META_BRI_OIFNAME
NFT_META_PKTTYPE = C.NFT_META_PKTTYPE
NFT_META_CPU = C.NFT_META_CPU
NFT_META_IIFGROUP = C.NFT_META_IIFGROUP
NFT_META_OIFGROUP = C.NFT_META_OIFGROUP
NFT_META_CGROUP = C.NFT_META_CGROUP
NFT_META_PRANDOM = C.NFT_META_PRANDOM
NFT_RT_CLASSID = C.NFT_RT_CLASSID
NFT_RT_NEXTHOP4 = C.NFT_RT_NEXTHOP4
NFT_RT_NEXTHOP6 = C.NFT_RT_NEXTHOP6
NFT_RT_TCPMSS = C.NFT_RT_TCPMSS
NFT_HASH_JENKINS = C.NFT_HASH_JENKINS
NFT_HASH_SYM = C.NFT_HASH_SYM
NFTA_HASH_UNSPEC = C.NFTA_HASH_UNSPEC
NFTA_HASH_SREG = C.NFTA_HASH_SREG
NFTA_HASH_DREG = C.NFTA_HASH_DREG
NFTA_HASH_LEN = C.NFTA_HASH_LEN
NFTA_HASH_MODULUS = C.NFTA_HASH_MODULUS
NFTA_HASH_SEED = C.NFTA_HASH_SEED
NFTA_HASH_OFFSET = C.NFTA_HASH_OFFSET
NFTA_HASH_TYPE = C.NFTA_HASH_TYPE
NFTA_META_UNSPEC = C.NFTA_META_UNSPEC
NFTA_META_DREG = C.NFTA_META_DREG
NFTA_META_KEY = C.NFTA_META_KEY
NFTA_META_SREG = C.NFTA_META_SREG
NFTA_RT_UNSPEC = C.NFTA_RT_UNSPEC
NFTA_RT_DREG = C.NFTA_RT_DREG
NFTA_RT_KEY = C.NFTA_RT_KEY
NFT_CT_STATE = C.NFT_CT_STATE
NFT_CT_DIRECTION = C.NFT_CT_DIRECTION
NFT_CT_STATUS = C.NFT_CT_STATUS
NFT_CT_MARK = C.NFT_CT_MARK
NFT_CT_SECMARK = C.NFT_CT_SECMARK
NFT_CT_EXPIRATION = C.NFT_CT_EXPIRATION
NFT_CT_HELPER = C.NFT_CT_HELPER
NFT_CT_L3PROTOCOL = C.NFT_CT_L3PROTOCOL
NFT_CT_SRC = C.NFT_CT_SRC
NFT_CT_DST = C.NFT_CT_DST
NFT_CT_PROTOCOL = C.NFT_CT_PROTOCOL
NFT_CT_PROTO_SRC = C.NFT_CT_PROTO_SRC
NFT_CT_PROTO_DST = C.NFT_CT_PROTO_DST
NFT_CT_LABELS = C.NFT_CT_LABELS
NFT_CT_PKTS = C.NFT_CT_PKTS
NFT_CT_BYTES = C.NFT_CT_BYTES
NFT_CT_AVGPKT = C.NFT_CT_AVGPKT
NFT_CT_ZONE = C.NFT_CT_ZONE
NFT_CT_EVENTMASK = C.NFT_CT_EVENTMASK
NFTA_CT_UNSPEC = C.NFTA_CT_UNSPEC
NFTA_CT_DREG = C.NFTA_CT_DREG
NFTA_CT_KEY = C.NFTA_CT_KEY
NFTA_CT_DIRECTION = C.NFTA_CT_DIRECTION
NFTA_CT_SREG = C.NFTA_CT_SREG
NFT_LIMIT_PKTS = C.NFT_LIMIT_PKTS
NFT_LIMIT_PKT_BYTES = C.NFT_LIMIT_PKT_BYTES
NFT_LIMIT_F_INV = C.NFT_LIMIT_F_INV
NFTA_LIMIT_UNSPEC = C.NFTA_LIMIT_UNSPEC
NFTA_LIMIT_RATE = C.NFTA_LIMIT_RATE
NFTA_LIMIT_UNIT = C.NFTA_LIMIT_UNIT
NFTA_LIMIT_BURST = C.NFTA_LIMIT_BURST
NFTA_LIMIT_TYPE = C.NFTA_LIMIT_TYPE
NFTA_LIMIT_FLAGS = C.NFTA_LIMIT_FLAGS
NFTA_LIMIT_PAD = C.NFTA_LIMIT_PAD
NFTA_COUNTER_UNSPEC = C.NFTA_COUNTER_UNSPEC
NFTA_COUNTER_BYTES = C.NFTA_COUNTER_BYTES
NFTA_COUNTER_PACKETS = C.NFTA_COUNTER_PACKETS
NFTA_COUNTER_PAD = C.NFTA_COUNTER_PAD
NFTA_LOG_UNSPEC = C.NFTA_LOG_UNSPEC
NFTA_LOG_GROUP = C.NFTA_LOG_GROUP
NFTA_LOG_PREFIX = C.NFTA_LOG_PREFIX
NFTA_LOG_SNAPLEN = C.NFTA_LOG_SNAPLEN
NFTA_LOG_QTHRESHOLD = C.NFTA_LOG_QTHRESHOLD
NFTA_LOG_LEVEL = C.NFTA_LOG_LEVEL
NFTA_LOG_FLAGS = C.NFTA_LOG_FLAGS
NFTA_QUEUE_UNSPEC = C.NFTA_QUEUE_UNSPEC
NFTA_QUEUE_NUM = C.NFTA_QUEUE_NUM
NFTA_QUEUE_TOTAL = C.NFTA_QUEUE_TOTAL
NFTA_QUEUE_FLAGS = C.NFTA_QUEUE_FLAGS
NFTA_QUEUE_SREG_QNUM = C.NFTA_QUEUE_SREG_QNUM
NFT_QUOTA_F_INV = C.NFT_QUOTA_F_INV
NFT_QUOTA_F_DEPLETED = C.NFT_QUOTA_F_DEPLETED
NFTA_QUOTA_UNSPEC = C.NFTA_QUOTA_UNSPEC
NFTA_QUOTA_BYTES = C.NFTA_QUOTA_BYTES
NFTA_QUOTA_FLAGS = C.NFTA_QUOTA_FLAGS
NFTA_QUOTA_PAD = C.NFTA_QUOTA_PAD
NFTA_QUOTA_CONSUMED = C.NFTA_QUOTA_CONSUMED
NFT_REJECT_ICMP_UNREACH = C.NFT_REJECT_ICMP_UNREACH
NFT_REJECT_TCP_RST = C.NFT_REJECT_TCP_RST
NFT_REJECT_ICMPX_UNREACH = C.NFT_REJECT_ICMPX_UNREACH
NFT_REJECT_ICMPX_NO_ROUTE = C.NFT_REJECT_ICMPX_NO_ROUTE
NFT_REJECT_ICMPX_PORT_UNREACH = C.NFT_REJECT_ICMPX_PORT_UNREACH
NFT_REJECT_ICMPX_HOST_UNREACH = C.NFT_REJECT_ICMPX_HOST_UNREACH
NFT_REJECT_ICMPX_ADMIN_PROHIBITED = C.NFT_REJECT_ICMPX_ADMIN_PROHIBITED
NFTA_REJECT_UNSPEC = C.NFTA_REJECT_UNSPEC
NFTA_REJECT_TYPE = C.NFTA_REJECT_TYPE
NFTA_REJECT_ICMP_CODE = C.NFTA_REJECT_ICMP_CODE
NFT_NAT_SNAT = C.NFT_NAT_SNAT
NFT_NAT_DNAT = C.NFT_NAT_DNAT
NFTA_NAT_UNSPEC = C.NFTA_NAT_UNSPEC
NFTA_NAT_TYPE = C.NFTA_NAT_TYPE
NFTA_NAT_FAMILY = C.NFTA_NAT_FAMILY
NFTA_NAT_REG_ADDR_MIN = C.NFTA_NAT_REG_ADDR_MIN
NFTA_NAT_REG_ADDR_MAX = C.NFTA_NAT_REG_ADDR_MAX
NFTA_NAT_REG_PROTO_MIN = C.NFTA_NAT_REG_PROTO_MIN
NFTA_NAT_REG_PROTO_MAX = C.NFTA_NAT_REG_PROTO_MAX
NFTA_NAT_FLAGS = C.NFTA_NAT_FLAGS
NFTA_MASQ_UNSPEC = C.NFTA_MASQ_UNSPEC
NFTA_MASQ_FLAGS = C.NFTA_MASQ_FLAGS
NFTA_MASQ_REG_PROTO_MIN = C.NFTA_MASQ_REG_PROTO_MIN
NFTA_MASQ_REG_PROTO_MAX = C.NFTA_MASQ_REG_PROTO_MAX
NFTA_REDIR_UNSPEC = C.NFTA_REDIR_UNSPEC
NFTA_REDIR_REG_PROTO_MIN = C.NFTA_REDIR_REG_PROTO_MIN
NFTA_REDIR_REG_PROTO_MAX = C.NFTA_REDIR_REG_PROTO_MAX
NFTA_REDIR_FLAGS = C.NFTA_REDIR_FLAGS
NFTA_DUP_UNSPEC = C.NFTA_DUP_UNSPEC
NFTA_DUP_SREG_ADDR = C.NFTA_DUP_SREG_ADDR
NFTA_DUP_SREG_DEV = C.NFTA_DUP_SREG_DEV
NFTA_FWD_UNSPEC = C.NFTA_FWD_UNSPEC
NFTA_FWD_SREG_DEV = C.NFTA_FWD_SREG_DEV
NFTA_OBJREF_UNSPEC = C.NFTA_OBJREF_UNSPEC
NFTA_OBJREF_IMM_TYPE = C.NFTA_OBJREF_IMM_TYPE
NFTA_OBJREF_IMM_NAME = C.NFTA_OBJREF_IMM_NAME
NFTA_OBJREF_SET_SREG = C.NFTA_OBJREF_SET_SREG
NFTA_OBJREF_SET_NAME = C.NFTA_OBJREF_SET_NAME
NFTA_OBJREF_SET_ID = C.NFTA_OBJREF_SET_ID
NFTA_GEN_UNSPEC = C.NFTA_GEN_UNSPEC
NFTA_GEN_ID = C.NFTA_GEN_ID
NFTA_GEN_PROC_PID = C.NFTA_GEN_PROC_PID
NFTA_GEN_PROC_NAME = C.NFTA_GEN_PROC_NAME
NFTA_FIB_UNSPEC = C.NFTA_FIB_UNSPEC
NFTA_FIB_DREG = C.NFTA_FIB_DREG
NFTA_FIB_RESULT = C.NFTA_FIB_RESULT
NFTA_FIB_FLAGS = C.NFTA_FIB_FLAGS
NFT_FIB_RESULT_UNSPEC = C.NFT_FIB_RESULT_UNSPEC
NFT_FIB_RESULT_OIF = C.NFT_FIB_RESULT_OIF
NFT_FIB_RESULT_OIFNAME = C.NFT_FIB_RESULT_OIFNAME
NFT_FIB_RESULT_ADDRTYPE = C.NFT_FIB_RESULT_ADDRTYPE
NFTA_FIB_F_SADDR = C.NFTA_FIB_F_SADDR
NFTA_FIB_F_DADDR = C.NFTA_FIB_F_DADDR
NFTA_FIB_F_MARK = C.NFTA_FIB_F_MARK
NFTA_FIB_F_IIF = C.NFTA_FIB_F_IIF
NFTA_FIB_F_OIF = C.NFTA_FIB_F_OIF
NFTA_FIB_F_PRESENT = C.NFTA_FIB_F_PRESENT
NFTA_CT_HELPER_UNSPEC = C.NFTA_CT_HELPER_UNSPEC
NFTA_CT_HELPER_NAME = C.NFTA_CT_HELPER_NAME
NFTA_CT_HELPER_L3PROTO = C.NFTA_CT_HELPER_L3PROTO
NFTA_CT_HELPER_L4PROTO = C.NFTA_CT_HELPER_L4PROTO
NFTA_OBJ_UNSPEC = C.NFTA_OBJ_UNSPEC
NFTA_OBJ_TABLE = C.NFTA_OBJ_TABLE
NFTA_OBJ_NAME = C.NFTA_OBJ_NAME
NFTA_OBJ_TYPE = C.NFTA_OBJ_TYPE
NFTA_OBJ_DATA = C.NFTA_OBJ_DATA
NFTA_OBJ_USE = C.NFTA_OBJ_USE
NFTA_TRACE_UNSPEC = C.NFTA_TRACE_UNSPEC
NFTA_TRACE_TABLE = C.NFTA_TRACE_TABLE
NFTA_TRACE_CHAIN = C.NFTA_TRACE_CHAIN
NFTA_TRACE_RULE_HANDLE = C.NFTA_TRACE_RULE_HANDLE
NFTA_TRACE_TYPE = C.NFTA_TRACE_TYPE
NFTA_TRACE_VERDICT = C.NFTA_TRACE_VERDICT
NFTA_TRACE_ID = C.NFTA_TRACE_ID
NFTA_TRACE_LL_HEADER = C.NFTA_TRACE_LL_HEADER
NFTA_TRACE_NETWORK_HEADER = C.NFTA_TRACE_NETWORK_HEADER
NFTA_TRACE_TRANSPORT_HEADER = C.NFTA_TRACE_TRANSPORT_HEADER
NFTA_TRACE_IIF = C.NFTA_TRACE_IIF
NFTA_TRACE_IIFTYPE = C.NFTA_TRACE_IIFTYPE
NFTA_TRACE_OIF = C.NFTA_TRACE_OIF
NFTA_TRACE_OIFTYPE = C.NFTA_TRACE_OIFTYPE
NFTA_TRACE_MARK = C.NFTA_TRACE_MARK
NFTA_TRACE_NFPROTO = C.NFTA_TRACE_NFPROTO
NFTA_TRACE_POLICY = C.NFTA_TRACE_POLICY
NFTA_TRACE_PAD = C.NFTA_TRACE_PAD
NFT_TRACETYPE_UNSPEC = C.NFT_TRACETYPE_UNSPEC
NFT_TRACETYPE_POLICY = C.NFT_TRACETYPE_POLICY
NFT_TRACETYPE_RETURN = C.NFT_TRACETYPE_RETURN
NFT_TRACETYPE_RULE = C.NFT_TRACETYPE_RULE
NFTA_NG_UNSPEC = C.NFTA_NG_UNSPEC
NFTA_NG_DREG = C.NFTA_NG_DREG
NFTA_NG_MODULUS = C.NFTA_NG_MODULUS
NFTA_NG_TYPE = C.NFTA_NG_TYPE
NFTA_NG_OFFSET = C.NFTA_NG_OFFSET
NFT_NG_INCREMENTAL = C.NFT_NG_INCREMENTAL
NFT_NG_RANDOM = C.NFT_NG_RANDOM
)
// Netfilter nf_tables compat attribute identifiers (NFTA_TARGET_*,
// NFTA_MATCH_* and NFTA_COMPAT_*). Each Go constant takes its value from the
// C identifier of the same name.
//
// The list was generated using:
// perl -nlE '/^\s*(NFT\w+)/ && say "$1 = C.$1"' /usr/include/linux/netfilter/nf_tables_compat.h
const (
// NFTA_TARGET_* identifiers.
NFTA_TARGET_UNSPEC = C.NFTA_TARGET_UNSPEC
NFTA_TARGET_NAME = C.NFTA_TARGET_NAME
NFTA_TARGET_REV = C.NFTA_TARGET_REV
NFTA_TARGET_INFO = C.NFTA_TARGET_INFO
// NFTA_MATCH_* identifiers.
NFTA_MATCH_UNSPEC = C.NFTA_MATCH_UNSPEC
NFTA_MATCH_NAME = C.NFTA_MATCH_NAME
NFTA_MATCH_REV = C.NFTA_MATCH_REV
NFTA_MATCH_INFO = C.NFTA_MATCH_INFO
// NFTA_COMPAT_* identifiers.
NFTA_COMPAT_UNSPEC = C.NFTA_COMPAT_UNSPEC
NFTA_COMPAT_NAME = C.NFTA_COMPAT_NAME
NFTA_COMPAT_REV = C.NFTA_COMPAT_REV
NFTA_COMPAT_TYPE = C.NFTA_COMPAT_TYPE
)
// RTCTime is the Go name for C struct rtc_time.
type RTCTime C.struct_rtc_time
// RTCWkAlrm is the Go name for C struct rtc_wkalrm.
type RTCWkAlrm C.struct_rtc_wkalrm
// RTCPLLInfo is the Go name for C struct rtc_pll_info.
type RTCPLLInfo C.struct_rtc_pll_info
// BLKPG ioctl: argument types and operation constants.

// BlkpgIoctlArg is the Go name for C struct blkpg_ioctl_arg.
type BlkpgIoctlArg C.struct_blkpg_ioctl_arg
// BlkpgPartition is the Go name for C struct my_blkpg_partition.
// NOTE(review): the "my_" prefix suggests a wrapper struct declared in this
// file's cgo preamble rather than one taken directly from a system header;
// the preamble is outside this section, so confirm there.
type BlkpgPartition C.struct_my_blkpg_partition
// BLKPG and the BLKPG_*_PARTITION constants, each taken from the C identifier
// of the same name.
const (
BLKPG = C.BLKPG
BLKPG_ADD_PARTITION = C.BLKPG_ADD_PARTITION
BLKPG_DEL_PARTITION = C.BLKPG_DEL_PARTITION
BLKPG_RESIZE_PARTITION = C.BLKPG_RESIZE_PARTITION
)
// Netlink network-namespace identifiers (NETNSA_*). Each Go constant takes
// its value from the C identifier of the same name.
//
// The list was generated using:
// perl -nlE '/^\s*(NETNSA\w+)/ && say "$1 = C.$1"' /usr/include/linux/net_namespace.h
const (
NETNSA_NONE = C.NETNSA_NONE
NETNSA_NSID = C.NETNSA_NSID
NETNSA_PID = C.NETNSA_PID
NETNSA_FD = C.NETNSA_FD
NETNSA_TARGET_NSID = C.NETNSA_TARGET_NSID
NETNSA_CURRENT_NSID = C.NETNSA_CURRENT_NSID
)
// AF_XDP: Go names for the xdp_* C structs.

// XDPRingOffset is the Go name for C struct xdp_ring_offset.
type XDPRingOffset C.struct_xdp_ring_offset
// XDPMmapOffsets is the Go name for C struct xdp_mmap_offsets.
type XDPMmapOffsets C.struct_xdp_mmap_offsets
// XDPUmemReg is the Go name for C struct xdp_umem_reg.
type XDPUmemReg C.struct_xdp_umem_reg
// XDPStatistics is the Go name for C struct xdp_statistics.
type XDPStatistics C.struct_xdp_statistics
// XDPDesc is the Go name for C struct xdp_desc.
type XDPDesc C.struct_xdp_desc
// NCSI generic netlink identifiers, grouped by prefix: NCSI_CMD_*,
// NCSI_ATTR_*, NCSI_PKG_ATTR_* and NCSI_CHANNEL_ATTR_*. Each Go constant
// takes its value from the C identifier of the same name.
const (
// NCSI_CMD_* identifiers.
NCSI_CMD_UNSPEC = C.NCSI_CMD_UNSPEC
NCSI_CMD_PKG_INFO = C.NCSI_CMD_PKG_INFO
NCSI_CMD_SET_INTERFACE = C.NCSI_CMD_SET_INTERFACE
NCSI_CMD_CLEAR_INTERFACE = C.NCSI_CMD_CLEAR_INTERFACE
// NCSI_ATTR_* identifiers.
NCSI_ATTR_UNSPEC = C.NCSI_ATTR_UNSPEC
NCSI_ATTR_IFINDEX = C.NCSI_ATTR_IFINDEX
NCSI_ATTR_PACKAGE_LIST = C.NCSI_ATTR_PACKAGE_LIST
NCSI_ATTR_PACKAGE_ID = C.NCSI_ATTR_PACKAGE_ID
NCSI_ATTR_CHANNEL_ID = C.NCSI_ATTR_CHANNEL_ID
// NCSI_PKG_ATTR_* identifiers.
NCSI_PKG_ATTR_UNSPEC = C.NCSI_PKG_ATTR_UNSPEC
NCSI_PKG_ATTR = C.NCSI_PKG_ATTR
NCSI_PKG_ATTR_ID = C.NCSI_PKG_ATTR_ID
NCSI_PKG_ATTR_FORCED = C.NCSI_PKG_ATTR_FORCED
NCSI_PKG_ATTR_CHANNEL_LIST = C.NCSI_PKG_ATTR_CHANNEL_LIST
// NCSI_CHANNEL_ATTR_* identifiers.
NCSI_CHANNEL_ATTR_UNSPEC = C.NCSI_CHANNEL_ATTR_UNSPEC
NCSI_CHANNEL_ATTR = C.NCSI_CHANNEL_ATTR
NCSI_CHANNEL_ATTR_ID = C.NCSI_CHANNEL_ATTR_ID
NCSI_CHANNEL_ATTR_VERSION_MAJOR = C.NCSI_CHANNEL_ATTR_VERSION_MAJOR
NCSI_CHANNEL_ATTR_VERSION_MINOR = C.NCSI_CHANNEL_ATTR_VERSION_MINOR
NCSI_CHANNEL_ATTR_VERSION_STR = C.NCSI_CHANNEL_ATTR_VERSION_STR
NCSI_CHANNEL_ATTR_LINK_STATE = C.NCSI_CHANNEL_ATTR_LINK_STATE
NCSI_CHANNEL_ATTR_ACTIVE = C.NCSI_CHANNEL_ATTR_ACTIVE
NCSI_CHANNEL_ATTR_FORCED = C.NCSI_CHANNEL_ATTR_FORCED
NCSI_CHANNEL_ATTR_VLAN_LIST = C.NCSI_CHANNEL_ATTR_VLAN_LIST
NCSI_CHANNEL_ATTR_VLAN_ID = C.NCSI_CHANNEL_ATTR_VLAN_ID
)
// Timestamping: the scm_timestamping struct and the SOF_TIMESTAMPING_* /
// SCM_TSTAMP_* constants.

// ScmTimestamping is the Go name for C struct scm_timestamping.
type ScmTimestamping C.struct_scm_timestamping
// Each Go constant below takes its value from the C identifier of the same
// name.
const (
// SOF_TIMESTAMPING_* identifiers.
SOF_TIMESTAMPING_TX_HARDWARE = C.SOF_TIMESTAMPING_TX_HARDWARE
SOF_TIMESTAMPING_TX_SOFTWARE = C.SOF_TIMESTAMPING_TX_SOFTWARE
SOF_TIMESTAMPING_RX_HARDWARE = C.SOF_TIMESTAMPING_RX_HARDWARE
SOF_TIMESTAMPING_RX_SOFTWARE = C.SOF_TIMESTAMPING_RX_SOFTWARE
SOF_TIMESTAMPING_SOFTWARE = C.SOF_TIMESTAMPING_SOFTWARE
SOF_TIMESTAMPING_SYS_HARDWARE = C.SOF_TIMESTAMPING_SYS_HARDWARE
SOF_TIMESTAMPING_RAW_HARDWARE = C.SOF_TIMESTAMPING_RAW_HARDWARE
SOF_TIMESTAMPING_OPT_ID = C.SOF_TIMESTAMPING_OPT_ID
SOF_TIMESTAMPING_TX_SCHED = C.SOF_TIMESTAMPING_TX_SCHED
SOF_TIMESTAMPING_TX_ACK = C.SOF_TIMESTAMPING_TX_ACK
SOF_TIMESTAMPING_OPT_CMSG = C.SOF_TIMESTAMPING_OPT_CMSG
SOF_TIMESTAMPING_OPT_TSONLY = C.SOF_TIMESTAMPING_OPT_TSONLY
SOF_TIMESTAMPING_OPT_STATS = C.SOF_TIMESTAMPING_OPT_STATS
SOF_TIMESTAMPING_OPT_PKTINFO = C.SOF_TIMESTAMPING_OPT_PKTINFO
SOF_TIMESTAMPING_OPT_TX_SWHW = C.SOF_TIMESTAMPING_OPT_TX_SWHW
SOF_TIMESTAMPING_BIND_PHC = C.SOF_TIMESTAMPING_BIND_PHC
SOF_TIMESTAMPING_OPT_ID_TCP = C.SOF_TIMESTAMPING_OPT_ID_TCP
SOF_TIMESTAMPING_LAST = C.SOF_TIMESTAMPING_LAST
SOF_TIMESTAMPING_MASK = C.SOF_TIMESTAMPING_MASK
// SCM_TSTAMP_* identifiers.
SCM_TSTAMP_SND = C.SCM_TSTAMP_SND
SCM_TSTAMP_SCHED = C.SCM_TSTAMP_SCHED
SCM_TSTAMP_ACK = C.SCM_TSTAMP_ACK
)
// Socket error queue

// SockExtendedErr is the Go name for C struct sock_extended_err.
type SockExtendedErr C.struct_sock_extended_err
// Fanotify

// FanotifyEventMetadata is the Go name for C struct fanotify_event_metadata.
type FanotifyEventMetadata C.struct_fanotify_event_metadata
// FanotifyResponse is the Go name for C struct fanotify_response.
type FanotifyResponse C.struct_fanotify_response
// Crypto user configuration API: CRYPTO_MSG_* identifiers. Each Go constant
// takes its value from the C identifier of the same name.
const (
CRYPTO_MSG_BASE = C.CRYPTO_MSG_BASE
CRYPTO_MSG_NEWALG = C.CRYPTO_MSG_NEWALG
CRYPTO_MSG_DELALG = C.CRYPTO_MSG_DELALG
CRYPTO_MSG_UPDATEALG = C.CRYPTO_MSG_UPDATEALG
CRYPTO_MSG_GETALG = C.CRYPTO_MSG_GETALG
CRYPTO_MSG_DELRNG = C.CRYPTO_MSG_DELRNG
CRYPTO_MSG_GETSTAT = C.CRYPTO_MSG_GETSTAT
)
// CRYPTOCFGA_* identifiers, in two families: CRYPTOCFGA_REPORT_* and
// CRYPTOCFGA_STAT_*. Each Go constant takes its value from the C identifier
// of the same name.
const (
CRYPTOCFGA_UNSPEC = C.CRYPTOCFGA_UNSPEC
CRYPTOCFGA_PRIORITY_VAL = C.CRYPTOCFGA_PRIORITY_VAL
// CRYPTOCFGA_REPORT_* identifiers.
CRYPTOCFGA_REPORT_LARVAL = C.CRYPTOCFGA_REPORT_LARVAL
CRYPTOCFGA_REPORT_HASH = C.CRYPTOCFGA_REPORT_HASH
CRYPTOCFGA_REPORT_BLKCIPHER = C.CRYPTOCFGA_REPORT_BLKCIPHER
CRYPTOCFGA_REPORT_AEAD = C.CRYPTOCFGA_REPORT_AEAD
CRYPTOCFGA_REPORT_COMPRESS = C.CRYPTOCFGA_REPORT_COMPRESS
CRYPTOCFGA_REPORT_RNG = C.CRYPTOCFGA_REPORT_RNG
CRYPTOCFGA_REPORT_CIPHER = C.CRYPTOCFGA_REPORT_CIPHER
CRYPTOCFGA_REPORT_AKCIPHER = C.CRYPTOCFGA_REPORT_AKCIPHER
CRYPTOCFGA_REPORT_KPP = C.CRYPTOCFGA_REPORT_KPP
CRYPTOCFGA_REPORT_ACOMP = C.CRYPTOCFGA_REPORT_ACOMP
// CRYPTOCFGA_STAT_* identifiers.
CRYPTOCFGA_STAT_LARVAL = C.CRYPTOCFGA_STAT_LARVAL
CRYPTOCFGA_STAT_HASH = C.CRYPTOCFGA_STAT_HASH
CRYPTOCFGA_STAT_BLKCIPHER = C.CRYPTOCFGA_STAT_BLKCIPHER
CRYPTOCFGA_STAT_AEAD = C.CRYPTOCFGA_STAT_AEAD
CRYPTOCFGA_STAT_COMPRESS = C.CRYPTOCFGA_STAT_COMPRESS
CRYPTOCFGA_STAT_RNG = C.CRYPTOCFGA_STAT_RNG
CRYPTOCFGA_STAT_CIPHER = C.CRYPTOCFGA_STAT_CIPHER
CRYPTOCFGA_STAT_AKCIPHER = C.CRYPTOCFGA_STAT_AKCIPHER
CRYPTOCFGA_STAT_KPP = C.CRYPTOCFGA_STAT_KPP
CRYPTOCFGA_STAT_ACOMP = C.CRYPTOCFGA_STAT_ACOMP
)
// CryptoUserAlg is the Go name for C struct crypto_user_alg.
type CryptoUserAlg C.struct_crypto_user_alg
// CryptoStat* types: Go names for the crypto_stat_* C structs.
type CryptoStatAEAD C.struct_crypto_stat_aead
type CryptoStatAKCipher C.struct_crypto_stat_akcipher
type CryptoStatCipher C.struct_crypto_stat_cipher
type CryptoStatCompress C.struct_crypto_stat_compress
type CryptoStatHash C.struct_crypto_stat_hash
type CryptoStatKPP C.struct_crypto_stat_kpp
type CryptoStatRNG C.struct_crypto_stat_rng
type CryptoStatLarval C.struct_crypto_stat_larval
// CryptoReport* types: Go names for the crypto_report_* C structs.
type CryptoReportLarval C.struct_crypto_report_larval
type CryptoReportHash C.struct_crypto_report_hash
type CryptoReportCipher C.struct_crypto_report_cipher
type CryptoReportBlkCipher C.struct_crypto_report_blkcipher
type CryptoReportAEAD C.struct_crypto_report_aead
type CryptoReportComp C.struct_crypto_report_comp
type CryptoReportRNG C.struct_crypto_report_rng
type CryptoReportAKCipher C.struct_crypto_report_akcipher
type CryptoReportKPP C.struct_crypto_report_kpp
type CryptoReportAcomp C.struct_crypto_report_acomp
// BPF identifiers. Each Go constant takes its value from the C identifier of
// the same name. The generating expression below picks up header lines whose
// first identifier starts with BPF_ or TCP_BPF_.
//
// generated by:
// perl -nlE '/^\s*((TCP_)?BPF_\w+)/ && say "$1 = C.$1"' include/uapi/linux/bpf.h
const (
// BPF_REG_0 through BPF_REG_10.
BPF_REG_0 = C.BPF_REG_0
BPF_REG_1 = C.BPF_REG_1
BPF_REG_2 = C.BPF_REG_2
BPF_REG_3 = C.BPF_REG_3
BPF_REG_4 = C.BPF_REG_4
BPF_REG_5 = C.BPF_REG_5
BPF_REG_6 = C.BPF_REG_6
BPF_REG_7 = C.BPF_REG_7
BPF_REG_8 = C.BPF_REG_8
BPF_REG_9 = C.BPF_REG_9
BPF_REG_10 = C.BPF_REG_10
// BPF_CGROUP_ITER_* identifiers.
BPF_CGROUP_ITER_ORDER_UNSPEC = C.BPF_CGROUP_ITER_ORDER_UNSPEC
BPF_CGROUP_ITER_SELF_ONLY = C.BPF_CGROUP_ITER_SELF_ONLY
BPF_CGROUP_ITER_DESCENDANTS_PRE = C.BPF_CGROUP_ITER_DESCENDANTS_PRE
BPF_CGROUP_ITER_DESCENDANTS_POST = C.BPF_CGROUP_ITER_DESCENDANTS_POST
BPF_CGROUP_ITER_ANCESTORS_UP = C.BPF_CGROUP_ITER_ANCESTORS_UP
// Operation-style names (BPF_MAP_*, BPF_PROG_*, BPF_OBJ_*, BPF_BTF_*,
// BPF_LINK_*, ...).
// NOTE(review): presumably the bpf(2) command numbers; confirm against bpf.h.
BPF_MAP_CREATE = C.BPF_MAP_CREATE
BPF_MAP_LOOKUP_ELEM = C.BPF_MAP_LOOKUP_ELEM
BPF_MAP_UPDATE_ELEM = C.BPF_MAP_UPDATE_ELEM
BPF_MAP_DELETE_ELEM = C.BPF_MAP_DELETE_ELEM
BPF_MAP_GET_NEXT_KEY = C.BPF_MAP_GET_NEXT_KEY
BPF_PROG_LOAD = C.BPF_PROG_LOAD
BPF_OBJ_PIN = C.BPF_OBJ_PIN
BPF_OBJ_GET = C.BPF_OBJ_GET
BPF_PROG_ATTACH = C.BPF_PROG_ATTACH
BPF_PROG_DETACH = C.BPF_PROG_DETACH
// NOTE(review): both BPF_PROG_TEST_RUN and BPF_PROG_RUN are listed;
// presumably they share one value upstream; confirm against bpf.h.
BPF_PROG_TEST_RUN = C.BPF_PROG_TEST_RUN
BPF_PROG_RUN = C.BPF_PROG_RUN
BPF_PROG_GET_NEXT_ID = C.BPF_PROG_GET_NEXT_ID
BPF_MAP_GET_NEXT_ID = C.BPF_MAP_GET_NEXT_ID
BPF_PROG_GET_FD_BY_ID = C.BPF_PROG_GET_FD_BY_ID
BPF_MAP_GET_FD_BY_ID = C.BPF_MAP_GET_FD_BY_ID
BPF_OBJ_GET_INFO_BY_FD = C.BPF_OBJ_GET_INFO_BY_FD
BPF_PROG_QUERY = C.BPF_PROG_QUERY
BPF_RAW_TRACEPOINT_OPEN = C.BPF_RAW_TRACEPOINT_OPEN
BPF_BTF_LOAD = C.BPF_BTF_LOAD
BPF_BTF_GET_FD_BY_ID = C.BPF_BTF_GET_FD_BY_ID
BPF_TASK_FD_QUERY = C.BPF_TASK_FD_QUERY
BPF_MAP_LOOKUP_AND_DELETE_ELEM = C.BPF_MAP_LOOKUP_AND_DELETE_ELEM
BPF_MAP_FREEZE = C.BPF_MAP_FREEZE
BPF_BTF_GET_NEXT_ID = C.BPF_BTF_GET_NEXT_ID
BPF_MAP_LOOKUP_BATCH = C.BPF_MAP_LOOKUP_BATCH
BPF_MAP_LOOKUP_AND_DELETE_BATCH = C.BPF_MAP_LOOKUP_AND_DELETE_BATCH
BPF_MAP_UPDATE_BATCH = C.BPF_MAP_UPDATE_BATCH
BPF_MAP_DELETE_BATCH = C.BPF_MAP_DELETE_BATCH
BPF_LINK_CREATE = C.BPF_LINK_CREATE
BPF_LINK_UPDATE = C.BPF_LINK_UPDATE
BPF_LINK_GET_FD_BY_ID = C.BPF_LINK_GET_FD_BY_ID
BPF_LINK_GET_NEXT_ID = C.BPF_LINK_GET_NEXT_ID
BPF_ENABLE_STATS = C.BPF_ENABLE_STATS
BPF_ITER_CREATE = C.BPF_ITER_CREATE
BPF_LINK_DETACH = C.BPF_LINK_DETACH
BPF_PROG_BIND_MAP = C.BPF_PROG_BIND_MAP
// BPF_MAP_TYPE_* identifiers.
BPF_MAP_TYPE_UNSPEC = C.BPF_MAP_TYPE_UNSPEC
BPF_MAP_TYPE_HASH = C.BPF_MAP_TYPE_HASH
BPF_MAP_TYPE_ARRAY = C.BPF_MAP_TYPE_ARRAY
BPF_MAP_TYPE_PROG_ARRAY = C.BPF_MAP_TYPE_PROG_ARRAY
BPF_MAP_TYPE_PERF_EVENT_ARRAY = C.BPF_MAP_TYPE_PERF_EVENT_ARRAY
BPF_MAP_TYPE_PERCPU_HASH = C.BPF_MAP_TYPE_PERCPU_HASH
BPF_MAP_TYPE_PERCPU_ARRAY = C.BPF_MAP_TYPE_PERCPU_ARRAY
BPF_MAP_TYPE_STACK_TRACE = C.BPF_MAP_TYPE_STACK_TRACE
BPF_MAP_TYPE_CGROUP_ARRAY = C.BPF_MAP_TYPE_CGROUP_ARRAY
BPF_MAP_TYPE_LRU_HASH = C.BPF_MAP_TYPE_LRU_HASH
BPF_MAP_TYPE_LRU_PERCPU_HASH = C.BPF_MAP_TYPE_LRU_PERCPU_HASH
BPF_MAP_TYPE_LPM_TRIE = C.BPF_MAP_TYPE_LPM_TRIE
BPF_MAP_TYPE_ARRAY_OF_MAPS = C.BPF_MAP_TYPE_ARRAY_OF_MAPS
BPF_MAP_TYPE_HASH_OF_MAPS = C.BPF_MAP_TYPE_HASH_OF_MAPS
BPF_MAP_TYPE_DEVMAP = C.BPF_MAP_TYPE_DEVMAP
BPF_MAP_TYPE_SOCKMAP = C.BPF_MAP_TYPE_SOCKMAP
BPF_MAP_TYPE_CPUMAP = C.BPF_MAP_TYPE_CPUMAP
BPF_MAP_TYPE_XSKMAP = C.BPF_MAP_TYPE_XSKMAP
BPF_MAP_TYPE_SOCKHASH = C.BPF_MAP_TYPE_SOCKHASH
BPF_MAP_TYPE_CGROUP_STORAGE_DEPRECATED = C.BPF_MAP_TYPE_CGROUP_STORAGE_DEPRECATED
BPF_MAP_TYPE_CGROUP_STORAGE = C.BPF_MAP_TYPE_CGROUP_STORAGE
BPF_MAP_TYPE_REUSEPORT_SOCKARRAY = C.BPF_MAP_TYPE_REUSEPORT_SOCKARRAY
BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE = C.BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE
BPF_MAP_TYPE_QUEUE = C.BPF_MAP_TYPE_QUEUE
BPF_MAP_TYPE_STACK = C.BPF_MAP_TYPE_STACK
BPF_MAP_TYPE_SK_STORAGE = C.BPF_MAP_TYPE_SK_STORAGE
BPF_MAP_TYPE_DEVMAP_HASH = C.BPF_MAP_TYPE_DEVMAP_HASH
BPF_MAP_TYPE_STRUCT_OPS = C.BPF_MAP_TYPE_STRUCT_OPS
BPF_MAP_TYPE_RINGBUF = C.BPF_MAP_TYPE_RINGBUF
BPF_MAP_TYPE_INODE_STORAGE = C.BPF_MAP_TYPE_INODE_STORAGE
BPF_MAP_TYPE_TASK_STORAGE = C.BPF_MAP_TYPE_TASK_STORAGE
BPF_MAP_TYPE_BLOOM_FILTER = C.BPF_MAP_TYPE_BLOOM_FILTER
BPF_MAP_TYPE_USER_RINGBUF = C.BPF_MAP_TYPE_USER_RINGBUF
BPF_MAP_TYPE_CGRP_STORAGE = C.BPF_MAP_TYPE_CGRP_STORAGE
// BPF_PROG_TYPE_* identifiers.
BPF_PROG_TYPE_UNSPEC = C.BPF_PROG_TYPE_UNSPEC
BPF_PROG_TYPE_SOCKET_FILTER = C.BPF_PROG_TYPE_SOCKET_FILTER
BPF_PROG_TYPE_KPROBE = C.BPF_PROG_TYPE_KPROBE
BPF_PROG_TYPE_SCHED_CLS = C.BPF_PROG_TYPE_SCHED_CLS
BPF_PROG_TYPE_SCHED_ACT = C.BPF_PROG_TYPE_SCHED_ACT
BPF_PROG_TYPE_TRACEPOINT = C.BPF_PROG_TYPE_TRACEPOINT
BPF_PROG_TYPE_XDP = C.BPF_PROG_TYPE_XDP
BPF_PROG_TYPE_PERF_EVENT = C.BPF_PROG_TYPE_PERF_EVENT
BPF_PROG_TYPE_CGROUP_SKB = C.BPF_PROG_TYPE_CGROUP_SKB
BPF_PROG_TYPE_CGROUP_SOCK = C.BPF_PROG_TYPE_CGROUP_SOCK
BPF_PROG_TYPE_LWT_IN = C.BPF_PROG_TYPE_LWT_IN
BPF_PROG_TYPE_LWT_OUT = C.BPF_PROG_TYPE_LWT_OUT
BPF_PROG_TYPE_LWT_XMIT = C.BPF_PROG_TYPE_LWT_XMIT
BPF_PROG_TYPE_SOCK_OPS = C.BPF_PROG_TYPE_SOCK_OPS
BPF_PROG_TYPE_SK_SKB = C.BPF_PROG_TYPE_SK_SKB
BPF_PROG_TYPE_CGROUP_DEVICE = C.BPF_PROG_TYPE_CGROUP_DEVICE
BPF_PROG_TYPE_SK_MSG = C.BPF_PROG_TYPE_SK_MSG
BPF_PROG_TYPE_RAW_TRACEPOINT = C.BPF_PROG_TYPE_RAW_TRACEPOINT
BPF_PROG_TYPE_CGROUP_SOCK_ADDR = C.BPF_PROG_TYPE_CGROUP_SOCK_ADDR
BPF_PROG_TYPE_LWT_SEG6LOCAL = C.BPF_PROG_TYPE_LWT_SEG6LOCAL
BPF_PROG_TYPE_LIRC_MODE2 = C.BPF_PROG_TYPE_LIRC_MODE2
BPF_PROG_TYPE_SK_REUSEPORT = C.BPF_PROG_TYPE_SK_REUSEPORT
BPF_PROG_TYPE_FLOW_DISSECTOR = C.BPF_PROG_TYPE_FLOW_DISSECTOR
BPF_PROG_TYPE_CGROUP_SYSCTL = C.BPF_PROG_TYPE_CGROUP_SYSCTL
BPF_PROG_TYPE_RAW_TRACEPOINT_WRITABLE = C.BPF_PROG_TYPE_RAW_TRACEPOINT_WRITABLE
BPF_PROG_TYPE_CGROUP_SOCKOPT = C.BPF_PROG_TYPE_CGROUP_SOCKOPT
BPF_PROG_TYPE_TRACING = C.BPF_PROG_TYPE_TRACING
BPF_PROG_TYPE_STRUCT_OPS = C.BPF_PROG_TYPE_STRUCT_OPS
BPF_PROG_TYPE_EXT = C.BPF_PROG_TYPE_EXT
BPF_PROG_TYPE_LSM = C.BPF_PROG_TYPE_LSM
BPF_PROG_TYPE_SK_LOOKUP = C.BPF_PROG_TYPE_SK_LOOKUP
BPF_PROG_TYPE_SYSCALL = C.BPF_PROG_TYPE_SYSCALL
BPF_PROG_TYPE_NETFILTER = C.BPF_PROG_TYPE_NETFILTER
// Mixed-prefix run (BPF_CGROUP_*, BPF_SK_*, BPF_TRACE_*, BPF_XDP*, ...).
// NOTE(review): presumably the program attach types; confirm against bpf.h.
BPF_CGROUP_INET_INGRESS = C.BPF_CGROUP_INET_INGRESS
BPF_CGROUP_INET_EGRESS = C.BPF_CGROUP_INET_EGRESS
BPF_CGROUP_INET_SOCK_CREATE = C.BPF_CGROUP_INET_SOCK_CREATE
BPF_CGROUP_SOCK_OPS = C.BPF_CGROUP_SOCK_OPS
BPF_SK_SKB_STREAM_PARSER = C.BPF_SK_SKB_STREAM_PARSER
BPF_SK_SKB_STREAM_VERDICT = C.BPF_SK_SKB_STREAM_VERDICT
BPF_CGROUP_DEVICE = C.BPF_CGROUP_DEVICE
BPF_SK_MSG_VERDICT = C.BPF_SK_MSG_VERDICT
BPF_CGROUP_INET4_BIND = C.BPF_CGROUP_INET4_BIND
BPF_CGROUP_INET6_BIND = C.BPF_CGROUP_INET6_BIND
BPF_CGROUP_INET4_CONNECT = C.BPF_CGROUP_INET4_CONNECT
BPF_CGROUP_INET6_CONNECT = C.BPF_CGROUP_INET6_CONNECT
BPF_CGROUP_INET4_POST_BIND = C.BPF_CGROUP_INET4_POST_BIND
BPF_CGROUP_INET6_POST_BIND = C.BPF_CGROUP_INET6_POST_BIND
BPF_CGROUP_UDP4_SENDMSG = C.BPF_CGROUP_UDP4_SENDMSG
BPF_CGROUP_UDP6_SENDMSG = C.BPF_CGROUP_UDP6_SENDMSG
BPF_LIRC_MODE2 = C.BPF_LIRC_MODE2
BPF_FLOW_DISSECTOR = C.BPF_FLOW_DISSECTOR
BPF_CGROUP_SYSCTL = C.BPF_CGROUP_SYSCTL
BPF_CGROUP_UDP4_RECVMSG = C.BPF_CGROUP_UDP4_RECVMSG
BPF_CGROUP_UDP6_RECVMSG = C.BPF_CGROUP_UDP6_RECVMSG
BPF_CGROUP_GETSOCKOPT = C.BPF_CGROUP_GETSOCKOPT
BPF_CGROUP_SETSOCKOPT = C.BPF_CGROUP_SETSOCKOPT
BPF_TRACE_RAW_TP = C.BPF_TRACE_RAW_TP
BPF_TRACE_FENTRY = C.BPF_TRACE_FENTRY
BPF_TRACE_FEXIT = C.BPF_TRACE_FEXIT
BPF_MODIFY_RETURN = C.BPF_MODIFY_RETURN
BPF_LSM_MAC = C.BPF_LSM_MAC
BPF_TRACE_ITER = C.BPF_TRACE_ITER
BPF_CGROUP_INET4_GETPEERNAME = C.BPF_CGROUP_INET4_GETPEERNAME
BPF_CGROUP_INET6_GETPEERNAME = C.BPF_CGROUP_INET6_GETPEERNAME
BPF_CGROUP_INET4_GETSOCKNAME = C.BPF_CGROUP_INET4_GETSOCKNAME
BPF_CGROUP_INET6_GETSOCKNAME = C.BPF_CGROUP_INET6_GETSOCKNAME
BPF_XDP_DEVMAP = C.BPF_XDP_DEVMAP
BPF_CGROUP_INET_SOCK_RELEASE = C.BPF_CGROUP_INET_SOCK_RELEASE
BPF_XDP_CPUMAP = C.BPF_XDP_CPUMAP
BPF_SK_LOOKUP = C.BPF_SK_LOOKUP
BPF_XDP = C.BPF_XDP
BPF_SK_SKB_VERDICT = C.BPF_SK_SKB_VERDICT
BPF_SK_REUSEPORT_SELECT = C.BPF_SK_REUSEPORT_SELECT
BPF_SK_REUSEPORT_SELECT_OR_MIGRATE = C.BPF_SK_REUSEPORT_SELECT_OR_MIGRATE
BPF_PERF_EVENT = C.BPF_PERF_EVENT
BPF_TRACE_KPROBE_MULTI = C.BPF_TRACE_KPROBE_MULTI
BPF_LSM_CGROUP = C.BPF_LSM_CGROUP
BPF_STRUCT_OPS = C.BPF_STRUCT_OPS
BPF_NETFILTER = C.BPF_NETFILTER
BPF_TCX_INGRESS = C.BPF_TCX_INGRESS
BPF_TCX_EGRESS = C.BPF_TCX_EGRESS
BPF_TRACE_UPROBE_MULTI = C.BPF_TRACE_UPROBE_MULTI
// BPF_LINK_TYPE_* identifiers.
BPF_LINK_TYPE_UNSPEC = C.BPF_LINK_TYPE_UNSPEC
BPF_LINK_TYPE_RAW_TRACEPOINT = C.BPF_LINK_TYPE_RAW_TRACEPOINT
BPF_LINK_TYPE_TRACING = C.BPF_LINK_TYPE_TRACING
BPF_LINK_TYPE_CGROUP = C.BPF_LINK_TYPE_CGROUP
BPF_LINK_TYPE_ITER = C.BPF_LINK_TYPE_ITER
BPF_LINK_TYPE_NETNS = C.BPF_LINK_TYPE_NETNS
BPF_LINK_TYPE_XDP = C.BPF_LINK_TYPE_XDP
BPF_LINK_TYPE_PERF_EVENT = C.BPF_LINK_TYPE_PERF_EVENT
BPF_LINK_TYPE_KPROBE_MULTI = C.BPF_LINK_TYPE_KPROBE_MULTI
BPF_LINK_TYPE_STRUCT_OPS = C.BPF_LINK_TYPE_STRUCT_OPS
BPF_LINK_TYPE_NETFILTER = C.BPF_LINK_TYPE_NETFILTER
BPF_LINK_TYPE_TCX = C.BPF_LINK_TYPE_TCX
BPF_LINK_TYPE_UPROBE_MULTI = C.BPF_LINK_TYPE_UPROBE_MULTI
// BPF_PERF_EVENT_* identifiers.
BPF_PERF_EVENT_UNSPEC = C.BPF_PERF_EVENT_UNSPEC
BPF_PERF_EVENT_UPROBE = C.BPF_PERF_EVENT_UPROBE
BPF_PERF_EVENT_URETPROBE = C.BPF_PERF_EVENT_URETPROBE
BPF_PERF_EVENT_KPROBE = C.BPF_PERF_EVENT_KPROBE
BPF_PERF_EVENT_KRETPROBE = C.BPF_PERF_EVENT_KRETPROBE
BPF_PERF_EVENT_TRACEPOINT = C.BPF_PERF_EVENT_TRACEPOINT
BPF_PERF_EVENT_EVENT = C.BPF_PERF_EVENT_EVENT
// BPF_F_* flag-style names, together with BPF_ANY, BPF_NOEXIST and BPF_EXIST.
BPF_F_KPROBE_MULTI_RETURN = C.BPF_F_KPROBE_MULTI_RETURN
BPF_F_UPROBE_MULTI_RETURN = C.BPF_F_UPROBE_MULTI_RETURN
BPF_ANY = C.BPF_ANY
BPF_NOEXIST = C.BPF_NOEXIST
BPF_EXIST = C.BPF_EXIST
BPF_F_LOCK = C.BPF_F_LOCK
BPF_F_NO_PREALLOC = C.BPF_F_NO_PREALLOC
BPF_F_NO_COMMON_LRU = C.BPF_F_NO_COMMON_LRU
BPF_F_NUMA_NODE = C.BPF_F_NUMA_NODE
BPF_F_RDONLY = C.BPF_F_RDONLY
BPF_F_WRONLY = C.BPF_F_WRONLY
BPF_F_STACK_BUILD_ID = C.BPF_F_STACK_BUILD_ID
BPF_F_ZERO_SEED = C.BPF_F_ZERO_SEED
BPF_F_RDONLY_PROG = C.BPF_F_RDONLY_PROG
BPF_F_WRONLY_PROG = C.BPF_F_WRONLY_PROG
BPF_F_CLONE = C.BPF_F_CLONE
BPF_F_MMAPABLE = C.BPF_F_MMAPABLE
BPF_F_PRESERVE_ELEMS = C.BPF_F_PRESERVE_ELEMS
BPF_F_INNER_MAP = C.BPF_F_INNER_MAP
BPF_F_LINK = C.BPF_F_LINK
BPF_F_PATH_FD = C.BPF_F_PATH_FD
// BPF_STATS_* and BPF_STACK_BUILD_ID_* identifiers.
BPF_STATS_RUN_TIME = C.BPF_STATS_RUN_TIME
BPF_STACK_BUILD_ID_EMPTY = C.BPF_STACK_BUILD_ID_EMPTY
BPF_STACK_BUILD_ID_VALID = C.BPF_STACK_BUILD_ID_VALID
BPF_STACK_BUILD_ID_IP = C.BPF_STACK_BUILD_ID_IP
// Further BPF_F_* names, interleaved with BPF_CSUM_LEVEL_*, BPF_ADJ_ROOM_*
// and storage-related identifiers.
BPF_F_RECOMPUTE_CSUM = C.BPF_F_RECOMPUTE_CSUM
BPF_F_INVALIDATE_HASH = C.BPF_F_INVALIDATE_HASH
BPF_F_HDR_FIELD_MASK = C.BPF_F_HDR_FIELD_MASK
BPF_F_PSEUDO_HDR = C.BPF_F_PSEUDO_HDR
BPF_F_MARK_MANGLED_0 = C.BPF_F_MARK_MANGLED_0
BPF_F_MARK_ENFORCE = C.BPF_F_MARK_ENFORCE
BPF_F_INGRESS = C.BPF_F_INGRESS
BPF_F_TUNINFO_IPV6 = C.BPF_F_TUNINFO_IPV6
BPF_F_SKIP_FIELD_MASK = C.BPF_F_SKIP_FIELD_MASK
BPF_F_USER_STACK = C.BPF_F_USER_STACK
BPF_F_FAST_STACK_CMP = C.BPF_F_FAST_STACK_CMP
BPF_F_REUSE_STACKID = C.BPF_F_REUSE_STACKID
BPF_F_USER_BUILD_ID = C.BPF_F_USER_BUILD_ID
BPF_F_ZERO_CSUM_TX = C.BPF_F_ZERO_CSUM_TX
BPF_F_DONT_FRAGMENT = C.BPF_F_DONT_FRAGMENT
BPF_F_SEQ_NUMBER = C.BPF_F_SEQ_NUMBER
BPF_F_NO_TUNNEL_KEY = C.BPF_F_NO_TUNNEL_KEY
BPF_F_TUNINFO_FLAGS = C.BPF_F_TUNINFO_FLAGS
BPF_F_INDEX_MASK = C.BPF_F_INDEX_MASK
BPF_F_CURRENT_CPU = C.BPF_F_CURRENT_CPU
BPF_F_CTXLEN_MASK = C.BPF_F_CTXLEN_MASK
BPF_F_CURRENT_NETNS = C.BPF_F_CURRENT_NETNS
BPF_CSUM_LEVEL_QUERY = C.BPF_CSUM_LEVEL_QUERY
BPF_CSUM_LEVEL_INC = C.BPF_CSUM_LEVEL_INC
BPF_CSUM_LEVEL_DEC = C.BPF_CSUM_LEVEL_DEC
BPF_CSUM_LEVEL_RESET = C.BPF_CSUM_LEVEL_RESET
BPF_F_ADJ_ROOM_FIXED_GSO = C.BPF_F_ADJ_ROOM_FIXED_GSO
BPF_F_ADJ_ROOM_ENCAP_L3_IPV4 = C.BPF_F_ADJ_ROOM_ENCAP_L3_IPV4
BPF_F_ADJ_ROOM_ENCAP_L3_IPV6 = C.BPF_F_ADJ_ROOM_ENCAP_L3_IPV6
BPF_F_ADJ_ROOM_ENCAP_L4_GRE = C.BPF_F_ADJ_ROOM_ENCAP_L4_GRE
BPF_F_ADJ_ROOM_ENCAP_L4_UDP = C.BPF_F_ADJ_ROOM_ENCAP_L4_UDP
BPF_F_ADJ_ROOM_NO_CSUM_RESET = C.BPF_F_ADJ_ROOM_NO_CSUM_RESET
BPF_F_ADJ_ROOM_ENCAP_L2_ETH = C.BPF_F_ADJ_ROOM_ENCAP_L2_ETH
BPF_F_ADJ_ROOM_DECAP_L3_IPV4 = C.BPF_F_ADJ_ROOM_DECAP_L3_IPV4
BPF_F_ADJ_ROOM_DECAP_L3_IPV6 = C.BPF_F_ADJ_ROOM_DECAP_L3_IPV6
BPF_ADJ_ROOM_ENCAP_L2_MASK = C.BPF_ADJ_ROOM_ENCAP_L2_MASK
BPF_ADJ_ROOM_ENCAP_L2_SHIFT = C.BPF_ADJ_ROOM_ENCAP_L2_SHIFT
BPF_F_SYSCTL_BASE_NAME = C.BPF_F_SYSCTL_BASE_NAME
BPF_LOCAL_STORAGE_GET_F_CREATE = C.BPF_LOCAL_STORAGE_GET_F_CREATE
BPF_SK_STORAGE_GET_F_CREATE = C.BPF_SK_STORAGE_GET_F_CREATE
BPF_F_GET_BRANCH_RECORDS_SIZE = C.BPF_F_GET_BRANCH_RECORDS_SIZE
// BPF_RB_* and BPF_RINGBUF_* identifiers.
BPF_RB_NO_WAKEUP = C.BPF_RB_NO_WAKEUP
BPF_RB_FORCE_WAKEUP = C.BPF_RB_FORCE_WAKEUP
BPF_RB_AVAIL_DATA = C.BPF_RB_AVAIL_DATA
BPF_RB_RING_SIZE = C.BPF_RB_RING_SIZE
BPF_RB_CONS_POS = C.BPF_RB_CONS_POS
BPF_RB_PROD_POS = C.BPF_RB_PROD_POS
BPF_RINGBUF_BUSY_BIT = C.BPF_RINGBUF_BUSY_BIT
BPF_RINGBUF_DISCARD_BIT = C.BPF_RINGBUF_DISCARD_BIT
BPF_RINGBUF_HDR_SZ = C.BPF_RINGBUF_HDR_SZ
// Miscellaneous small families: BPF_SK_LOOKUP_F_*, BPF_ADJ_ROOM_*,
// BPF_HDR_START_*, BPF_LWT_*, BPF_SKB_TSTAMP_* and return-code-style names.
BPF_SK_LOOKUP_F_REPLACE = C.BPF_SK_LOOKUP_F_REPLACE
BPF_SK_LOOKUP_F_NO_REUSEPORT = C.BPF_SK_LOOKUP_F_NO_REUSEPORT
BPF_ADJ_ROOM_NET = C.BPF_ADJ_ROOM_NET
BPF_ADJ_ROOM_MAC = C.BPF_ADJ_ROOM_MAC
BPF_HDR_START_MAC = C.BPF_HDR_START_MAC
BPF_HDR_START_NET = C.BPF_HDR_START_NET
BPF_LWT_ENCAP_SEG6 = C.BPF_LWT_ENCAP_SEG6
BPF_LWT_ENCAP_SEG6_INLINE = C.BPF_LWT_ENCAP_SEG6_INLINE
BPF_LWT_ENCAP_IP = C.BPF_LWT_ENCAP_IP
BPF_F_BPRM_SECUREEXEC = C.BPF_F_BPRM_SECUREEXEC
BPF_F_BROADCAST = C.BPF_F_BROADCAST
BPF_F_EXCLUDE_INGRESS = C.BPF_F_EXCLUDE_INGRESS
BPF_SKB_TSTAMP_UNSPEC = C.BPF_SKB_TSTAMP_UNSPEC
BPF_SKB_TSTAMP_DELIVERY_MONO = C.BPF_SKB_TSTAMP_DELIVERY_MONO
BPF_OK = C.BPF_OK
BPF_DROP = C.BPF_DROP
BPF_REDIRECT = C.BPF_REDIRECT
BPF_LWT_REROUTE = C.BPF_LWT_REROUTE
BPF_FLOW_DISSECTOR_CONTINUE = C.BPF_FLOW_DISSECTOR_CONTINUE
// BPF_SOCK_OPS_* identifiers (the *_CB_FLAG names first, then the rest).
BPF_SOCK_OPS_RTO_CB_FLAG = C.BPF_SOCK_OPS_RTO_CB_FLAG
BPF_SOCK_OPS_RETRANS_CB_FLAG = C.BPF_SOCK_OPS_RETRANS_CB_FLAG
BPF_SOCK_OPS_STATE_CB_FLAG = C.BPF_SOCK_OPS_STATE_CB_FLAG
BPF_SOCK_OPS_RTT_CB_FLAG = C.BPF_SOCK_OPS_RTT_CB_FLAG
BPF_SOCK_OPS_PARSE_ALL_HDR_OPT_CB_FLAG = C.BPF_SOCK_OPS_PARSE_ALL_HDR_OPT_CB_FLAG
BPF_SOCK_OPS_PARSE_UNKNOWN_HDR_OPT_CB_FLAG = C.BPF_SOCK_OPS_PARSE_UNKNOWN_HDR_OPT_CB_FLAG
BPF_SOCK_OPS_WRITE_HDR_OPT_CB_FLAG = C.BPF_SOCK_OPS_WRITE_HDR_OPT_CB_FLAG
BPF_SOCK_OPS_ALL_CB_FLAGS = C.BPF_SOCK_OPS_ALL_CB_FLAGS
BPF_SOCK_OPS_VOID = C.BPF_SOCK_OPS_VOID
BPF_SOCK_OPS_TIMEOUT_INIT = C.BPF_SOCK_OPS_TIMEOUT_INIT
BPF_SOCK_OPS_RWND_INIT = C.BPF_SOCK_OPS_RWND_INIT
BPF_SOCK_OPS_TCP_CONNECT_CB = C.BPF_SOCK_OPS_TCP_CONNECT_CB
BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB = C.BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB
BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB = C.BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB
BPF_SOCK_OPS_NEEDS_ECN = C.BPF_SOCK_OPS_NEEDS_ECN
BPF_SOCK_OPS_BASE_RTT = C.BPF_SOCK_OPS_BASE_RTT
BPF_SOCK_OPS_RTO_CB = C.BPF_SOCK_OPS_RTO_CB
BPF_SOCK_OPS_RETRANS_CB = C.BPF_SOCK_OPS_RETRANS_CB
BPF_SOCK_OPS_STATE_CB = C.BPF_SOCK_OPS_STATE_CB
BPF_SOCK_OPS_TCP_LISTEN_CB = C.BPF_SOCK_OPS_TCP_LISTEN_CB
BPF_SOCK_OPS_RTT_CB = C.BPF_SOCK_OPS_RTT_CB
BPF_SOCK_OPS_PARSE_HDR_OPT_CB = C.BPF_SOCK_OPS_PARSE_HDR_OPT_CB
BPF_SOCK_OPS_HDR_OPT_LEN_CB = C.BPF_SOCK_OPS_HDR_OPT_LEN_CB
BPF_SOCK_OPS_WRITE_HDR_OPT_CB = C.BPF_SOCK_OPS_WRITE_HDR_OPT_CB
// BPF_TCP_* identifiers.
BPF_TCP_ESTABLISHED = C.BPF_TCP_ESTABLISHED
BPF_TCP_SYN_SENT = C.BPF_TCP_SYN_SENT
BPF_TCP_SYN_RECV = C.BPF_TCP_SYN_RECV
BPF_TCP_FIN_WAIT1 = C.BPF_TCP_FIN_WAIT1
BPF_TCP_FIN_WAIT2 = C.BPF_TCP_FIN_WAIT2
BPF_TCP_TIME_WAIT = C.BPF_TCP_TIME_WAIT
BPF_TCP_CLOSE = C.BPF_TCP_CLOSE
BPF_TCP_CLOSE_WAIT = C.BPF_TCP_CLOSE_WAIT
BPF_TCP_LAST_ACK = C.BPF_TCP_LAST_ACK
BPF_TCP_LISTEN = C.BPF_TCP_LISTEN
BPF_TCP_CLOSING = C.BPF_TCP_CLOSING
BPF_TCP_NEW_SYN_RECV = C.BPF_TCP_NEW_SYN_RECV
BPF_TCP_MAX_STATES = C.BPF_TCP_MAX_STATES
// TCP_BPF_* identifiers (matched by the optional TCP_ prefix in the
// generating expression).
TCP_BPF_IW = C.TCP_BPF_IW
TCP_BPF_SNDCWND_CLAMP = C.TCP_BPF_SNDCWND_CLAMP
TCP_BPF_DELACK_MAX = C.TCP_BPF_DELACK_MAX
TCP_BPF_RTO_MIN = C.TCP_BPF_RTO_MIN
TCP_BPF_SYN = C.TCP_BPF_SYN
TCP_BPF_SYN_IP = C.TCP_BPF_SYN_IP
TCP_BPF_SYN_MAC = C.TCP_BPF_SYN_MAC
// BPF_LOAD_HDR_OPT_* and BPF_WRITE_HDR_* identifiers.
BPF_LOAD_HDR_OPT_TCP_SYN = C.BPF_LOAD_HDR_OPT_TCP_SYN
BPF_WRITE_HDR_TCP_CURRENT_MSS = C.BPF_WRITE_HDR_TCP_CURRENT_MSS
BPF_WRITE_HDR_TCP_SYNACK_COOKIE = C.BPF_WRITE_HDR_TCP_SYNACK_COOKIE
// BPF_DEVCG_* identifiers.
BPF_DEVCG_ACC_MKNOD = C.BPF_DEVCG_ACC_MKNOD
BPF_DEVCG_ACC_READ = C.BPF_DEVCG_ACC_READ
BPF_DEVCG_ACC_WRITE = C.BPF_DEVCG_ACC_WRITE
BPF_DEVCG_DEV_BLOCK = C.BPF_DEVCG_DEV_BLOCK
BPF_DEVCG_DEV_CHAR = C.BPF_DEVCG_DEV_CHAR
// BPF_FIB_LOOKUP_* and BPF_FIB_LKUP_RET_* identifiers.
BPF_FIB_LOOKUP_DIRECT = C.BPF_FIB_LOOKUP_DIRECT
BPF_FIB_LOOKUP_OUTPUT = C.BPF_FIB_LOOKUP_OUTPUT
BPF_FIB_LOOKUP_SKIP_NEIGH = C.BPF_FIB_LOOKUP_SKIP_NEIGH
BPF_FIB_LOOKUP_TBID = C.BPF_FIB_LOOKUP_TBID
BPF_FIB_LKUP_RET_SUCCESS = C.BPF_FIB_LKUP_RET_SUCCESS
BPF_FIB_LKUP_RET_BLACKHOLE = C.BPF_FIB_LKUP_RET_BLACKHOLE
BPF_FIB_LKUP_RET_UNREACHABLE = C.BPF_FIB_LKUP_RET_UNREACHABLE
BPF_FIB_LKUP_RET_PROHIBIT = C.BPF_FIB_LKUP_RET_PROHIBIT
BPF_FIB_LKUP_RET_NOT_FWDED = C.BPF_FIB_LKUP_RET_NOT_FWDED
BPF_FIB_LKUP_RET_FWD_DISABLED = C.BPF_FIB_LKUP_RET_FWD_DISABLED
BPF_FIB_LKUP_RET_UNSUPP_LWT = C.BPF_FIB_LKUP_RET_UNSUPP_LWT
BPF_FIB_LKUP_RET_NO_NEIGH = C.BPF_FIB_LKUP_RET_NO_NEIGH
BPF_FIB_LKUP_RET_FRAG_NEEDED = C.BPF_FIB_LKUP_RET_FRAG_NEEDED
// BPF_MTU_CHK_* identifiers.
BPF_MTU_CHK_SEGS = C.BPF_MTU_CHK_SEGS
BPF_MTU_CHK_RET_SUCCESS = C.BPF_MTU_CHK_RET_SUCCESS
BPF_MTU_CHK_RET_FRAG_NEEDED = C.BPF_MTU_CHK_RET_FRAG_NEEDED
BPF_MTU_CHK_RET_SEGS_TOOBIG = C.BPF_MTU_CHK_RET_SEGS_TOOBIG
// BPF_FD_TYPE_* identifiers.
BPF_FD_TYPE_RAW_TRACEPOINT = C.BPF_FD_TYPE_RAW_TRACEPOINT
BPF_FD_TYPE_TRACEPOINT = C.BPF_FD_TYPE_TRACEPOINT
BPF_FD_TYPE_KPROBE = C.BPF_FD_TYPE_KPROBE
BPF_FD_TYPE_KRETPROBE = C.BPF_FD_TYPE_KRETPROBE
BPF_FD_TYPE_UPROBE = C.BPF_FD_TYPE_UPROBE
BPF_FD_TYPE_URETPROBE = C.BPF_FD_TYPE_URETPROBE
// BPF_FLOW_DISSECTOR_F_* identifiers.
BPF_FLOW_DISSECTOR_F_PARSE_1ST_FRAG = C.BPF_FLOW_DISSECTOR_F_PARSE_1ST_FRAG
BPF_FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL = C.BPF_FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL
BPF_FLOW_DISSECTOR_F_STOP_AT_ENCAP = C.BPF_FLOW_DISSECTOR_F_STOP_AT_ENCAP
// BPF_CORE_* identifiers.
BPF_CORE_FIELD_BYTE_OFFSET = C.BPF_CORE_FIELD_BYTE_OFFSET
BPF_CORE_FIELD_BYTE_SIZE = C.BPF_CORE_FIELD_BYTE_SIZE
BPF_CORE_FIELD_EXISTS = C.BPF_CORE_FIELD_EXISTS
BPF_CORE_FIELD_SIGNED = C.BPF_CORE_FIELD_SIGNED
BPF_CORE_FIELD_LSHIFT_U64 = C.BPF_CORE_FIELD_LSHIFT_U64
BPF_CORE_FIELD_RSHIFT_U64 = C.BPF_CORE_FIELD_RSHIFT_U64
BPF_CORE_TYPE_ID_LOCAL = C.BPF_CORE_TYPE_ID_LOCAL
BPF_CORE_TYPE_ID_TARGET = C.BPF_CORE_TYPE_ID_TARGET
BPF_CORE_TYPE_EXISTS = C.BPF_CORE_TYPE_EXISTS
BPF_CORE_TYPE_SIZE = C.BPF_CORE_TYPE_SIZE
BPF_CORE_ENUMVAL_EXISTS = C.BPF_CORE_ENUMVAL_EXISTS
BPF_CORE_ENUMVAL_VALUE = C.BPF_CORE_ENUMVAL_VALUE
BPF_CORE_TYPE_MATCHES = C.BPF_CORE_TYPE_MATCHES
BPF_F_TIMER_ABS = C.BPF_F_TIMER_ABS
)
// RTNLGRP_* identifiers. Each Go constant takes its value from the C
// identifier of the same name; the spelling, including the mixed-case DECnet
// entries, is kept identical to the C side.
//
// generated by:
// perl -nlE '/^\s*(RTNLGRP_\w+)/ && say "$1 = C.$1"' include/uapi/linux/rtnetlink.h
const (
RTNLGRP_NONE = C.RTNLGRP_NONE
RTNLGRP_LINK = C.RTNLGRP_LINK
RTNLGRP_NOTIFY = C.RTNLGRP_NOTIFY
RTNLGRP_NEIGH = C.RTNLGRP_NEIGH
RTNLGRP_TC = C.RTNLGRP_TC
RTNLGRP_IPV4_IFADDR = C.RTNLGRP_IPV4_IFADDR
RTNLGRP_IPV4_MROUTE = C.RTNLGRP_IPV4_MROUTE
RTNLGRP_IPV4_ROUTE = C.RTNLGRP_IPV4_ROUTE
RTNLGRP_IPV4_RULE = C.RTNLGRP_IPV4_RULE
RTNLGRP_IPV6_IFADDR = C.RTNLGRP_IPV6_IFADDR
RTNLGRP_IPV6_MROUTE = C.RTNLGRP_IPV6_MROUTE
RTNLGRP_IPV6_ROUTE = C.RTNLGRP_IPV6_ROUTE
RTNLGRP_IPV6_IFINFO = C.RTNLGRP_IPV6_IFINFO
RTNLGRP_DECnet_IFADDR = C.RTNLGRP_DECnet_IFADDR
RTNLGRP_NOP2 = C.RTNLGRP_NOP2
RTNLGRP_DECnet_ROUTE = C.RTNLGRP_DECnet_ROUTE
RTNLGRP_DECnet_RULE = C.RTNLGRP_DECnet_RULE
RTNLGRP_NOP4 = C.RTNLGRP_NOP4
RTNLGRP_IPV6_PREFIX = C.RTNLGRP_IPV6_PREFIX
RTNLGRP_IPV6_RULE = C.RTNLGRP_IPV6_RULE
RTNLGRP_ND_USEROPT = C.RTNLGRP_ND_USEROPT
RTNLGRP_PHONET_IFADDR = C.RTNLGRP_PHONET_IFADDR
RTNLGRP_PHONET_ROUTE = C.RTNLGRP_PHONET_ROUTE
RTNLGRP_DCB = C.RTNLGRP_DCB
RTNLGRP_IPV4_NETCONF = C.RTNLGRP_IPV4_NETCONF
RTNLGRP_IPV6_NETCONF = C.RTNLGRP_IPV6_NETCONF
RTNLGRP_MDB = C.RTNLGRP_MDB
RTNLGRP_MPLS_ROUTE = C.RTNLGRP_MPLS_ROUTE
RTNLGRP_NSID = C.RTNLGRP_NSID
RTNLGRP_MPLS_NETCONF = C.RTNLGRP_MPLS_NETCONF
RTNLGRP_IPV4_MROUTE_R = C.RTNLGRP_IPV4_MROUTE_R
RTNLGRP_IPV6_MROUTE_R = C.RTNLGRP_IPV6_MROUTE_R
RTNLGRP_NEXTHOP = C.RTNLGRP_NEXTHOP
RTNLGRP_BRVLAN = C.RTNLGRP_BRVLAN
)
// Capabilities

// CapUserHeader is declared with the layout of C struct __user_cap_header_struct.
type CapUserHeader C.struct___user_cap_header_struct
// CapUserData is declared with the layout of C struct __user_cap_data_struct.
type CapUserData C.struct___user_cap_data_struct
// Capability version constants. The C identifiers start with an underscore
// (_LINUX_CAPABILITY_VERSION_n); the Go names drop it.
const (
LINUX_CAPABILITY_VERSION_1 = C._LINUX_CAPABILITY_VERSION_1
LINUX_CAPABILITY_VERSION_2 = C._LINUX_CAPABILITY_VERSION_2
LINUX_CAPABILITY_VERSION_3 = C._LINUX_CAPABILITY_VERSION_3
)
// Loop devices

// Loop device flag constants (LO_FLAGS_*), each taking its value from the
// C identifier of the same name.
const (
LO_FLAGS_READ_ONLY = C.LO_FLAGS_READ_ONLY
LO_FLAGS_AUTOCLEAR = C.LO_FLAGS_AUTOCLEAR
LO_FLAGS_PARTSCAN = C.LO_FLAGS_PARTSCAN
LO_FLAGS_DIRECT_IO = C.LO_FLAGS_DIRECT_IO
)
// LoopInfo is declared with the layout of C struct loop_info.
type LoopInfo C.struct_loop_info
// LoopInfo64 is declared with the layout of C struct loop_info64.
type LoopInfo64 C.struct_loop_info64
// LoopConfig is declared with the layout of C struct loop_config.
type LoopConfig C.struct_loop_config
// AF_TIPC

// TIPCSocketAddr is declared with the layout of C struct tipc_socket_addr.
type TIPCSocketAddr C.struct_tipc_socket_addr
// TIPCServiceRange is declared with the layout of C struct tipc_service_range.
type TIPCServiceRange C.struct_tipc_service_range
// TIPCServiceName is declared with the layout of C struct tipc_service_name.
type TIPCServiceName C.struct_tipc_service_name
// TIPCSubscr is declared with the layout of C struct tipc_subscr.
type TIPCSubscr C.struct_tipc_subscr
// TIPCEvent is declared with the layout of C struct tipc_event.
type TIPCEvent C.struct_tipc_event
// TIPCGroupReq is declared with the layout of C struct tipc_group_req.
type TIPCGroupReq C.struct_tipc_group_req
// TIPCSIOCLNReq is declared with the layout of C struct tipc_sioc_ln_req.
type TIPCSIOCLNReq C.struct_tipc_sioc_ln_req
// TIPCSIOCNodeIDReq is declared with the layout of C struct tipc_sioc_nodeid_req.
type TIPCSIOCNodeIDReq C.struct_tipc_sioc_nodeid_req
// TIPC scope constants, each taking its value from the C identifier of the
// same name.
const (
TIPC_CLUSTER_SCOPE = C.TIPC_CLUSTER_SCOPE
TIPC_NODE_SCOPE = C.TIPC_NODE_SCOPE
)
// SYSLOG_ACTION_* action codes.
//
// Unlike the neighbouring groups these are not read from C identifiers; the
// values are consecutive integers starting at 0, so they are produced with
// iota. The order below therefore fixes each value (CLOSE = 0 through
// SIZE_BUFFER = 10) and must not be changed.
// NOTE(review): presumably the action arguments of syslog(2)/klogctl(3); confirm.
const (
	SYSLOG_ACTION_CLOSE = iota
	SYSLOG_ACTION_OPEN
	SYSLOG_ACTION_READ
	SYSLOG_ACTION_READ_ALL
	SYSLOG_ACTION_READ_CLEAR
	SYSLOG_ACTION_CLEAR
	SYSLOG_ACTION_CONSOLE_OFF
	SYSLOG_ACTION_CONSOLE_ON
	SYSLOG_ACTION_CONSOLE_LEVEL
	SYSLOG_ACTION_SIZE_UNREAD
	SYSLOG_ACTION_SIZE_BUFFER
)
// Devlink generic netlink API, generated using:
// perl -nlE '/^\s*(DEVLINK_\w+)/ && say "$1 = C.$1"' devlink.h
//
// Every constant takes its value from the C identifier of the same name.
// The group comments inside the block only reflect the shared name prefixes
// of the entries that follow them.
const (
// Commands (DEVLINK_CMD_*).
DEVLINK_CMD_UNSPEC = C.DEVLINK_CMD_UNSPEC
DEVLINK_CMD_GET = C.DEVLINK_CMD_GET
DEVLINK_CMD_SET = C.DEVLINK_CMD_SET
DEVLINK_CMD_NEW = C.DEVLINK_CMD_NEW
DEVLINK_CMD_DEL = C.DEVLINK_CMD_DEL
DEVLINK_CMD_PORT_GET = C.DEVLINK_CMD_PORT_GET
DEVLINK_CMD_PORT_SET = C.DEVLINK_CMD_PORT_SET
DEVLINK_CMD_PORT_NEW = C.DEVLINK_CMD_PORT_NEW
DEVLINK_CMD_PORT_DEL = C.DEVLINK_CMD_PORT_DEL
DEVLINK_CMD_PORT_SPLIT = C.DEVLINK_CMD_PORT_SPLIT
DEVLINK_CMD_PORT_UNSPLIT = C.DEVLINK_CMD_PORT_UNSPLIT
DEVLINK_CMD_SB_GET = C.DEVLINK_CMD_SB_GET
DEVLINK_CMD_SB_SET = C.DEVLINK_CMD_SB_SET
DEVLINK_CMD_SB_NEW = C.DEVLINK_CMD_SB_NEW
DEVLINK_CMD_SB_DEL = C.DEVLINK_CMD_SB_DEL
DEVLINK_CMD_SB_POOL_GET = C.DEVLINK_CMD_SB_POOL_GET
DEVLINK_CMD_SB_POOL_SET = C.DEVLINK_CMD_SB_POOL_SET
DEVLINK_CMD_SB_POOL_NEW = C.DEVLINK_CMD_SB_POOL_NEW
DEVLINK_CMD_SB_POOL_DEL = C.DEVLINK_CMD_SB_POOL_DEL
DEVLINK_CMD_SB_PORT_POOL_GET = C.DEVLINK_CMD_SB_PORT_POOL_GET
DEVLINK_CMD_SB_PORT_POOL_SET = C.DEVLINK_CMD_SB_PORT_POOL_SET
DEVLINK_CMD_SB_PORT_POOL_NEW = C.DEVLINK_CMD_SB_PORT_POOL_NEW
DEVLINK_CMD_SB_PORT_POOL_DEL = C.DEVLINK_CMD_SB_PORT_POOL_DEL
DEVLINK_CMD_SB_TC_POOL_BIND_GET = C.DEVLINK_CMD_SB_TC_POOL_BIND_GET
DEVLINK_CMD_SB_TC_POOL_BIND_SET = C.DEVLINK_CMD_SB_TC_POOL_BIND_SET
DEVLINK_CMD_SB_TC_POOL_BIND_NEW = C.DEVLINK_CMD_SB_TC_POOL_BIND_NEW
DEVLINK_CMD_SB_TC_POOL_BIND_DEL = C.DEVLINK_CMD_SB_TC_POOL_BIND_DEL
DEVLINK_CMD_SB_OCC_SNAPSHOT = C.DEVLINK_CMD_SB_OCC_SNAPSHOT
DEVLINK_CMD_SB_OCC_MAX_CLEAR = C.DEVLINK_CMD_SB_OCC_MAX_CLEAR
DEVLINK_CMD_ESWITCH_GET = C.DEVLINK_CMD_ESWITCH_GET
DEVLINK_CMD_ESWITCH_SET = C.DEVLINK_CMD_ESWITCH_SET
DEVLINK_CMD_DPIPE_TABLE_GET = C.DEVLINK_CMD_DPIPE_TABLE_GET
DEVLINK_CMD_DPIPE_ENTRIES_GET = C.DEVLINK_CMD_DPIPE_ENTRIES_GET
DEVLINK_CMD_DPIPE_HEADERS_GET = C.DEVLINK_CMD_DPIPE_HEADERS_GET
DEVLINK_CMD_DPIPE_TABLE_COUNTERS_SET = C.DEVLINK_CMD_DPIPE_TABLE_COUNTERS_SET
DEVLINK_CMD_RESOURCE_SET = C.DEVLINK_CMD_RESOURCE_SET
DEVLINK_CMD_RESOURCE_DUMP = C.DEVLINK_CMD_RESOURCE_DUMP
DEVLINK_CMD_RELOAD = C.DEVLINK_CMD_RELOAD
DEVLINK_CMD_PARAM_GET = C.DEVLINK_CMD_PARAM_GET
DEVLINK_CMD_PARAM_SET = C.DEVLINK_CMD_PARAM_SET
DEVLINK_CMD_PARAM_NEW = C.DEVLINK_CMD_PARAM_NEW
DEVLINK_CMD_PARAM_DEL = C.DEVLINK_CMD_PARAM_DEL
DEVLINK_CMD_REGION_GET = C.DEVLINK_CMD_REGION_GET
DEVLINK_CMD_REGION_SET = C.DEVLINK_CMD_REGION_SET
DEVLINK_CMD_REGION_NEW = C.DEVLINK_CMD_REGION_NEW
DEVLINK_CMD_REGION_DEL = C.DEVLINK_CMD_REGION_DEL
DEVLINK_CMD_REGION_READ = C.DEVLINK_CMD_REGION_READ
DEVLINK_CMD_PORT_PARAM_GET = C.DEVLINK_CMD_PORT_PARAM_GET
DEVLINK_CMD_PORT_PARAM_SET = C.DEVLINK_CMD_PORT_PARAM_SET
DEVLINK_CMD_PORT_PARAM_NEW = C.DEVLINK_CMD_PORT_PARAM_NEW
DEVLINK_CMD_PORT_PARAM_DEL = C.DEVLINK_CMD_PORT_PARAM_DEL
DEVLINK_CMD_INFO_GET = C.DEVLINK_CMD_INFO_GET
DEVLINK_CMD_HEALTH_REPORTER_GET = C.DEVLINK_CMD_HEALTH_REPORTER_GET
DEVLINK_CMD_HEALTH_REPORTER_SET = C.DEVLINK_CMD_HEALTH_REPORTER_SET
DEVLINK_CMD_HEALTH_REPORTER_RECOVER = C.DEVLINK_CMD_HEALTH_REPORTER_RECOVER
DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE = C.DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE
DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET = C.DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET
DEVLINK_CMD_HEALTH_REPORTER_DUMP_CLEAR = C.DEVLINK_CMD_HEALTH_REPORTER_DUMP_CLEAR
DEVLINK_CMD_FLASH_UPDATE = C.DEVLINK_CMD_FLASH_UPDATE
DEVLINK_CMD_FLASH_UPDATE_END = C.DEVLINK_CMD_FLASH_UPDATE_END
DEVLINK_CMD_FLASH_UPDATE_STATUS = C.DEVLINK_CMD_FLASH_UPDATE_STATUS
DEVLINK_CMD_TRAP_GET = C.DEVLINK_CMD_TRAP_GET
DEVLINK_CMD_TRAP_SET = C.DEVLINK_CMD_TRAP_SET
DEVLINK_CMD_TRAP_NEW = C.DEVLINK_CMD_TRAP_NEW
DEVLINK_CMD_TRAP_DEL = C.DEVLINK_CMD_TRAP_DEL
DEVLINK_CMD_TRAP_GROUP_GET = C.DEVLINK_CMD_TRAP_GROUP_GET
DEVLINK_CMD_TRAP_GROUP_SET = C.DEVLINK_CMD_TRAP_GROUP_SET
DEVLINK_CMD_TRAP_GROUP_NEW = C.DEVLINK_CMD_TRAP_GROUP_NEW
DEVLINK_CMD_TRAP_GROUP_DEL = C.DEVLINK_CMD_TRAP_GROUP_DEL
DEVLINK_CMD_TRAP_POLICER_GET = C.DEVLINK_CMD_TRAP_POLICER_GET
DEVLINK_CMD_TRAP_POLICER_SET = C.DEVLINK_CMD_TRAP_POLICER_SET
DEVLINK_CMD_TRAP_POLICER_NEW = C.DEVLINK_CMD_TRAP_POLICER_NEW
DEVLINK_CMD_TRAP_POLICER_DEL = C.DEVLINK_CMD_TRAP_POLICER_DEL
DEVLINK_CMD_HEALTH_REPORTER_TEST = C.DEVLINK_CMD_HEALTH_REPORTER_TEST
DEVLINK_CMD_RATE_GET = C.DEVLINK_CMD_RATE_GET
DEVLINK_CMD_RATE_SET = C.DEVLINK_CMD_RATE_SET
DEVLINK_CMD_RATE_NEW = C.DEVLINK_CMD_RATE_NEW
DEVLINK_CMD_RATE_DEL = C.DEVLINK_CMD_RATE_DEL
DEVLINK_CMD_LINECARD_GET = C.DEVLINK_CMD_LINECARD_GET
DEVLINK_CMD_LINECARD_SET = C.DEVLINK_CMD_LINECARD_SET
DEVLINK_CMD_LINECARD_NEW = C.DEVLINK_CMD_LINECARD_NEW
DEVLINK_CMD_LINECARD_DEL = C.DEVLINK_CMD_LINECARD_DEL
DEVLINK_CMD_SELFTESTS_GET = C.DEVLINK_CMD_SELFTESTS_GET
DEVLINK_CMD_MAX = C.DEVLINK_CMD_MAX
// Port types (DEVLINK_PORT_TYPE_*).
DEVLINK_PORT_TYPE_NOTSET = C.DEVLINK_PORT_TYPE_NOTSET
DEVLINK_PORT_TYPE_AUTO = C.DEVLINK_PORT_TYPE_AUTO
DEVLINK_PORT_TYPE_ETH = C.DEVLINK_PORT_TYPE_ETH
DEVLINK_PORT_TYPE_IB = C.DEVLINK_PORT_TYPE_IB
// SB pool and threshold types (DEVLINK_SB_*).
DEVLINK_SB_POOL_TYPE_INGRESS = C.DEVLINK_SB_POOL_TYPE_INGRESS
DEVLINK_SB_POOL_TYPE_EGRESS = C.DEVLINK_SB_POOL_TYPE_EGRESS
DEVLINK_SB_THRESHOLD_TYPE_STATIC = C.DEVLINK_SB_THRESHOLD_TYPE_STATIC
DEVLINK_SB_THRESHOLD_TYPE_DYNAMIC = C.DEVLINK_SB_THRESHOLD_TYPE_DYNAMIC
// Eswitch mode, inline mode and encap mode values (DEVLINK_ESWITCH_*).
DEVLINK_ESWITCH_MODE_LEGACY = C.DEVLINK_ESWITCH_MODE_LEGACY
DEVLINK_ESWITCH_MODE_SWITCHDEV = C.DEVLINK_ESWITCH_MODE_SWITCHDEV
DEVLINK_ESWITCH_INLINE_MODE_NONE = C.DEVLINK_ESWITCH_INLINE_MODE_NONE
DEVLINK_ESWITCH_INLINE_MODE_LINK = C.DEVLINK_ESWITCH_INLINE_MODE_LINK
DEVLINK_ESWITCH_INLINE_MODE_NETWORK = C.DEVLINK_ESWITCH_INLINE_MODE_NETWORK
DEVLINK_ESWITCH_INLINE_MODE_TRANSPORT = C.DEVLINK_ESWITCH_INLINE_MODE_TRANSPORT
DEVLINK_ESWITCH_ENCAP_MODE_NONE = C.DEVLINK_ESWITCH_ENCAP_MODE_NONE
DEVLINK_ESWITCH_ENCAP_MODE_BASIC = C.DEVLINK_ESWITCH_ENCAP_MODE_BASIC
// Port flavours (DEVLINK_PORT_FLAVOUR_*).
DEVLINK_PORT_FLAVOUR_PHYSICAL = C.DEVLINK_PORT_FLAVOUR_PHYSICAL
DEVLINK_PORT_FLAVOUR_CPU = C.DEVLINK_PORT_FLAVOUR_CPU
DEVLINK_PORT_FLAVOUR_DSA = C.DEVLINK_PORT_FLAVOUR_DSA
DEVLINK_PORT_FLAVOUR_PCI_PF = C.DEVLINK_PORT_FLAVOUR_PCI_PF
DEVLINK_PORT_FLAVOUR_PCI_VF = C.DEVLINK_PORT_FLAVOUR_PCI_VF
DEVLINK_PORT_FLAVOUR_VIRTUAL = C.DEVLINK_PORT_FLAVOUR_VIRTUAL
DEVLINK_PORT_FLAVOUR_UNUSED = C.DEVLINK_PORT_FLAVOUR_UNUSED
// Parameter cmodes and parameter values (DEVLINK_PARAM_*).
DEVLINK_PARAM_CMODE_RUNTIME = C.DEVLINK_PARAM_CMODE_RUNTIME
DEVLINK_PARAM_CMODE_DRIVERINIT = C.DEVLINK_PARAM_CMODE_DRIVERINIT
DEVLINK_PARAM_CMODE_PERMANENT = C.DEVLINK_PARAM_CMODE_PERMANENT
DEVLINK_PARAM_CMODE_MAX = C.DEVLINK_PARAM_CMODE_MAX
DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_DRIVER = C.DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_DRIVER
DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_FLASH = C.DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_FLASH
DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_DISK = C.DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_DISK
DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_UNKNOWN = C.DEVLINK_PARAM_FW_LOAD_POLICY_VALUE_UNKNOWN
DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_UNKNOWN = C.DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_UNKNOWN
DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_ALWAYS = C.DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_ALWAYS
DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_NEVER = C.DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_NEVER
DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_DISK = C.DEVLINK_PARAM_RESET_DEV_ON_DRV_PROBE_VALUE_DISK
// Stats attributes (DEVLINK_ATTR_STATS_*); listed apart from the main
// DEVLINK_ATTR_* run further down.
DEVLINK_ATTR_STATS_RX_PACKETS = C.DEVLINK_ATTR_STATS_RX_PACKETS
DEVLINK_ATTR_STATS_RX_BYTES = C.DEVLINK_ATTR_STATS_RX_BYTES
DEVLINK_ATTR_STATS_RX_DROPPED = C.DEVLINK_ATTR_STATS_RX_DROPPED
DEVLINK_ATTR_STATS_MAX = C.DEVLINK_ATTR_STATS_MAX
// Flash overwrite bits (DEVLINK_FLASH_OVERWRITE_*).
DEVLINK_FLASH_OVERWRITE_SETTINGS_BIT = C.DEVLINK_FLASH_OVERWRITE_SETTINGS_BIT
DEVLINK_FLASH_OVERWRITE_IDENTIFIERS_BIT = C.DEVLINK_FLASH_OVERWRITE_IDENTIFIERS_BIT
DEVLINK_FLASH_OVERWRITE_MAX_BIT = C.DEVLINK_FLASH_OVERWRITE_MAX_BIT
// Trap actions and trap types (DEVLINK_TRAP_*).
DEVLINK_TRAP_ACTION_DROP = C.DEVLINK_TRAP_ACTION_DROP
DEVLINK_TRAP_ACTION_TRAP = C.DEVLINK_TRAP_ACTION_TRAP
DEVLINK_TRAP_ACTION_MIRROR = C.DEVLINK_TRAP_ACTION_MIRROR
DEVLINK_TRAP_TYPE_DROP = C.DEVLINK_TRAP_TYPE_DROP
DEVLINK_TRAP_TYPE_EXCEPTION = C.DEVLINK_TRAP_TYPE_EXCEPTION
DEVLINK_TRAP_TYPE_CONTROL = C.DEVLINK_TRAP_TYPE_CONTROL
// Trap metadata types (DEVLINK_ATTR_TRAP_METADATA_TYPE_*).
DEVLINK_ATTR_TRAP_METADATA_TYPE_IN_PORT = C.DEVLINK_ATTR_TRAP_METADATA_TYPE_IN_PORT
DEVLINK_ATTR_TRAP_METADATA_TYPE_FA_COOKIE = C.DEVLINK_ATTR_TRAP_METADATA_TYPE_FA_COOKIE
// Reload actions and limits (DEVLINK_RELOAD_*).
DEVLINK_RELOAD_ACTION_UNSPEC = C.DEVLINK_RELOAD_ACTION_UNSPEC
DEVLINK_RELOAD_ACTION_DRIVER_REINIT = C.DEVLINK_RELOAD_ACTION_DRIVER_REINIT
DEVLINK_RELOAD_ACTION_FW_ACTIVATE = C.DEVLINK_RELOAD_ACTION_FW_ACTIVATE
DEVLINK_RELOAD_ACTION_MAX = C.DEVLINK_RELOAD_ACTION_MAX
DEVLINK_RELOAD_LIMIT_UNSPEC = C.DEVLINK_RELOAD_LIMIT_UNSPEC
DEVLINK_RELOAD_LIMIT_NO_RESET = C.DEVLINK_RELOAD_LIMIT_NO_RESET
DEVLINK_RELOAD_LIMIT_MAX = C.DEVLINK_RELOAD_LIMIT_MAX
// Attributes (DEVLINK_ATTR_*).
DEVLINK_ATTR_UNSPEC = C.DEVLINK_ATTR_UNSPEC
DEVLINK_ATTR_BUS_NAME = C.DEVLINK_ATTR_BUS_NAME
DEVLINK_ATTR_DEV_NAME = C.DEVLINK_ATTR_DEV_NAME
DEVLINK_ATTR_PORT_INDEX = C.DEVLINK_ATTR_PORT_INDEX
DEVLINK_ATTR_PORT_TYPE = C.DEVLINK_ATTR_PORT_TYPE
DEVLINK_ATTR_PORT_DESIRED_TYPE = C.DEVLINK_ATTR_PORT_DESIRED_TYPE
DEVLINK_ATTR_PORT_NETDEV_IFINDEX = C.DEVLINK_ATTR_PORT_NETDEV_IFINDEX
DEVLINK_ATTR_PORT_NETDEV_NAME = C.DEVLINK_ATTR_PORT_NETDEV_NAME
DEVLINK_ATTR_PORT_IBDEV_NAME = C.DEVLINK_ATTR_PORT_IBDEV_NAME
DEVLINK_ATTR_PORT_SPLIT_COUNT = C.DEVLINK_ATTR_PORT_SPLIT_COUNT
DEVLINK_ATTR_PORT_SPLIT_GROUP = C.DEVLINK_ATTR_PORT_SPLIT_GROUP
DEVLINK_ATTR_SB_INDEX = C.DEVLINK_ATTR_SB_INDEX
DEVLINK_ATTR_SB_SIZE = C.DEVLINK_ATTR_SB_SIZE
DEVLINK_ATTR_SB_INGRESS_POOL_COUNT = C.DEVLINK_ATTR_SB_INGRESS_POOL_COUNT
DEVLINK_ATTR_SB_EGRESS_POOL_COUNT = C.DEVLINK_ATTR_SB_EGRESS_POOL_COUNT
DEVLINK_ATTR_SB_INGRESS_TC_COUNT = C.DEVLINK_ATTR_SB_INGRESS_TC_COUNT
DEVLINK_ATTR_SB_EGRESS_TC_COUNT = C.DEVLINK_ATTR_SB_EGRESS_TC_COUNT
DEVLINK_ATTR_SB_POOL_INDEX = C.DEVLINK_ATTR_SB_POOL_INDEX
DEVLINK_ATTR_SB_POOL_TYPE = C.DEVLINK_ATTR_SB_POOL_TYPE
DEVLINK_ATTR_SB_POOL_SIZE = C.DEVLINK_ATTR_SB_POOL_SIZE
DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE = C.DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE
DEVLINK_ATTR_SB_THRESHOLD = C.DEVLINK_ATTR_SB_THRESHOLD
DEVLINK_ATTR_SB_TC_INDEX = C.DEVLINK_ATTR_SB_TC_INDEX
DEVLINK_ATTR_SB_OCC_CUR = C.DEVLINK_ATTR_SB_OCC_CUR
DEVLINK_ATTR_SB_OCC_MAX = C.DEVLINK_ATTR_SB_OCC_MAX
DEVLINK_ATTR_ESWITCH_MODE = C.DEVLINK_ATTR_ESWITCH_MODE
DEVLINK_ATTR_ESWITCH_INLINE_MODE = C.DEVLINK_ATTR_ESWITCH_INLINE_MODE
DEVLINK_ATTR_DPIPE_TABLES = C.DEVLINK_ATTR_DPIPE_TABLES
DEVLINK_ATTR_DPIPE_TABLE = C.DEVLINK_ATTR_DPIPE_TABLE
DEVLINK_ATTR_DPIPE_TABLE_NAME = C.DEVLINK_ATTR_DPIPE_TABLE_NAME
DEVLINK_ATTR_DPIPE_TABLE_SIZE = C.DEVLINK_ATTR_DPIPE_TABLE_SIZE
DEVLINK_ATTR_DPIPE_TABLE_MATCHES = C.DEVLINK_ATTR_DPIPE_TABLE_MATCHES
DEVLINK_ATTR_DPIPE_TABLE_ACTIONS = C.DEVLINK_ATTR_DPIPE_TABLE_ACTIONS
DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED = C.DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED
DEVLINK_ATTR_DPIPE_ENTRIES = C.DEVLINK_ATTR_DPIPE_ENTRIES
DEVLINK_ATTR_DPIPE_ENTRY = C.DEVLINK_ATTR_DPIPE_ENTRY
DEVLINK_ATTR_DPIPE_ENTRY_INDEX = C.DEVLINK_ATTR_DPIPE_ENTRY_INDEX
DEVLINK_ATTR_DPIPE_ENTRY_MATCH_VALUES = C.DEVLINK_ATTR_DPIPE_ENTRY_MATCH_VALUES
DEVLINK_ATTR_DPIPE_ENTRY_ACTION_VALUES = C.DEVLINK_ATTR_DPIPE_ENTRY_ACTION_VALUES
DEVLINK_ATTR_DPIPE_ENTRY_COUNTER = C.DEVLINK_ATTR_DPIPE_ENTRY_COUNTER
DEVLINK_ATTR_DPIPE_MATCH = C.DEVLINK_ATTR_DPIPE_MATCH
DEVLINK_ATTR_DPIPE_MATCH_VALUE = C.DEVLINK_ATTR_DPIPE_MATCH_VALUE
DEVLINK_ATTR_DPIPE_MATCH_TYPE = C.DEVLINK_ATTR_DPIPE_MATCH_TYPE
DEVLINK_ATTR_DPIPE_ACTION = C.DEVLINK_ATTR_DPIPE_ACTION
DEVLINK_ATTR_DPIPE_ACTION_VALUE = C.DEVLINK_ATTR_DPIPE_ACTION_VALUE
DEVLINK_ATTR_DPIPE_ACTION_TYPE = C.DEVLINK_ATTR_DPIPE_ACTION_TYPE
DEVLINK_ATTR_DPIPE_VALUE = C.DEVLINK_ATTR_DPIPE_VALUE
DEVLINK_ATTR_DPIPE_VALUE_MASK = C.DEVLINK_ATTR_DPIPE_VALUE_MASK
DEVLINK_ATTR_DPIPE_VALUE_MAPPING = C.DEVLINK_ATTR_DPIPE_VALUE_MAPPING
DEVLINK_ATTR_DPIPE_HEADERS = C.DEVLINK_ATTR_DPIPE_HEADERS
DEVLINK_ATTR_DPIPE_HEADER = C.DEVLINK_ATTR_DPIPE_HEADER
DEVLINK_ATTR_DPIPE_HEADER_NAME = C.DEVLINK_ATTR_DPIPE_HEADER_NAME
DEVLINK_ATTR_DPIPE_HEADER_ID = C.DEVLINK_ATTR_DPIPE_HEADER_ID
DEVLINK_ATTR_DPIPE_HEADER_FIELDS = C.DEVLINK_ATTR_DPIPE_HEADER_FIELDS
DEVLINK_ATTR_DPIPE_HEADER_GLOBAL = C.DEVLINK_ATTR_DPIPE_HEADER_GLOBAL
DEVLINK_ATTR_DPIPE_HEADER_INDEX = C.DEVLINK_ATTR_DPIPE_HEADER_INDEX
DEVLINK_ATTR_DPIPE_FIELD = C.DEVLINK_ATTR_DPIPE_FIELD
DEVLINK_ATTR_DPIPE_FIELD_NAME = C.DEVLINK_ATTR_DPIPE_FIELD_NAME
DEVLINK_ATTR_DPIPE_FIELD_ID = C.DEVLINK_ATTR_DPIPE_FIELD_ID
DEVLINK_ATTR_DPIPE_FIELD_BITWIDTH = C.DEVLINK_ATTR_DPIPE_FIELD_BITWIDTH
DEVLINK_ATTR_DPIPE_FIELD_MAPPING_TYPE = C.DEVLINK_ATTR_DPIPE_FIELD_MAPPING_TYPE
DEVLINK_ATTR_PAD = C.DEVLINK_ATTR_PAD
DEVLINK_ATTR_ESWITCH_ENCAP_MODE = C.DEVLINK_ATTR_ESWITCH_ENCAP_MODE
DEVLINK_ATTR_RESOURCE_LIST = C.DEVLINK_ATTR_RESOURCE_LIST
DEVLINK_ATTR_RESOURCE = C.DEVLINK_ATTR_RESOURCE
DEVLINK_ATTR_RESOURCE_NAME = C.DEVLINK_ATTR_RESOURCE_NAME
DEVLINK_ATTR_RESOURCE_ID = C.DEVLINK_ATTR_RESOURCE_ID
DEVLINK_ATTR_RESOURCE_SIZE = C.DEVLINK_ATTR_RESOURCE_SIZE
DEVLINK_ATTR_RESOURCE_SIZE_NEW = C.DEVLINK_ATTR_RESOURCE_SIZE_NEW
DEVLINK_ATTR_RESOURCE_SIZE_VALID = C.DEVLINK_ATTR_RESOURCE_SIZE_VALID
DEVLINK_ATTR_RESOURCE_SIZE_MIN = C.DEVLINK_ATTR_RESOURCE_SIZE_MIN
DEVLINK_ATTR_RESOURCE_SIZE_MAX = C.DEVLINK_ATTR_RESOURCE_SIZE_MAX
DEVLINK_ATTR_RESOURCE_SIZE_GRAN = C.DEVLINK_ATTR_RESOURCE_SIZE_GRAN
DEVLINK_ATTR_RESOURCE_UNIT = C.DEVLINK_ATTR_RESOURCE_UNIT
DEVLINK_ATTR_RESOURCE_OCC = C.DEVLINK_ATTR_RESOURCE_OCC
DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_ID = C.DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_ID
DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_UNITS = C.DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_UNITS
DEVLINK_ATTR_PORT_FLAVOUR = C.DEVLINK_ATTR_PORT_FLAVOUR
DEVLINK_ATTR_PORT_NUMBER = C.DEVLINK_ATTR_PORT_NUMBER
DEVLINK_ATTR_PORT_SPLIT_SUBPORT_NUMBER = C.DEVLINK_ATTR_PORT_SPLIT_SUBPORT_NUMBER
DEVLINK_ATTR_PARAM = C.DEVLINK_ATTR_PARAM
DEVLINK_ATTR_PARAM_NAME = C.DEVLINK_ATTR_PARAM_NAME
DEVLINK_ATTR_PARAM_GENERIC = C.DEVLINK_ATTR_PARAM_GENERIC
DEVLINK_ATTR_PARAM_TYPE = C.DEVLINK_ATTR_PARAM_TYPE
DEVLINK_ATTR_PARAM_VALUES_LIST = C.DEVLINK_ATTR_PARAM_VALUES_LIST
DEVLINK_ATTR_PARAM_VALUE = C.DEVLINK_ATTR_PARAM_VALUE
DEVLINK_ATTR_PARAM_VALUE_DATA = C.DEVLINK_ATTR_PARAM_VALUE_DATA
DEVLINK_ATTR_PARAM_VALUE_CMODE = C.DEVLINK_ATTR_PARAM_VALUE_CMODE
DEVLINK_ATTR_REGION_NAME = C.DEVLINK_ATTR_REGION_NAME
DEVLINK_ATTR_REGION_SIZE = C.DEVLINK_ATTR_REGION_SIZE
DEVLINK_ATTR_REGION_SNAPSHOTS = C.DEVLINK_ATTR_REGION_SNAPSHOTS
DEVLINK_ATTR_REGION_SNAPSHOT = C.DEVLINK_ATTR_REGION_SNAPSHOT
DEVLINK_ATTR_REGION_SNAPSHOT_ID = C.DEVLINK_ATTR_REGION_SNAPSHOT_ID
DEVLINK_ATTR_REGION_CHUNKS = C.DEVLINK_ATTR_REGION_CHUNKS
DEVLINK_ATTR_REGION_CHUNK = C.DEVLINK_ATTR_REGION_CHUNK
DEVLINK_ATTR_REGION_CHUNK_DATA = C.DEVLINK_ATTR_REGION_CHUNK_DATA
DEVLINK_ATTR_REGION_CHUNK_ADDR = C.DEVLINK_ATTR_REGION_CHUNK_ADDR
DEVLINK_ATTR_REGION_CHUNK_LEN = C.DEVLINK_ATTR_REGION_CHUNK_LEN
DEVLINK_ATTR_INFO_DRIVER_NAME = C.DEVLINK_ATTR_INFO_DRIVER_NAME
DEVLINK_ATTR_INFO_SERIAL_NUMBER = C.DEVLINK_ATTR_INFO_SERIAL_NUMBER
DEVLINK_ATTR_INFO_VERSION_FIXED = C.DEVLINK_ATTR_INFO_VERSION_FIXED
DEVLINK_ATTR_INFO_VERSION_RUNNING = C.DEVLINK_ATTR_INFO_VERSION_RUNNING
DEVLINK_ATTR_INFO_VERSION_STORED = C.DEVLINK_ATTR_INFO_VERSION_STORED
DEVLINK_ATTR_INFO_VERSION_NAME = C.DEVLINK_ATTR_INFO_VERSION_NAME
DEVLINK_ATTR_INFO_VERSION_VALUE = C.DEVLINK_ATTR_INFO_VERSION_VALUE
DEVLINK_ATTR_SB_POOL_CELL_SIZE = C.DEVLINK_ATTR_SB_POOL_CELL_SIZE
DEVLINK_ATTR_FMSG = C.DEVLINK_ATTR_FMSG
DEVLINK_ATTR_FMSG_OBJ_NEST_START = C.DEVLINK_ATTR_FMSG_OBJ_NEST_START
DEVLINK_ATTR_FMSG_PAIR_NEST_START = C.DEVLINK_ATTR_FMSG_PAIR_NEST_START
DEVLINK_ATTR_FMSG_ARR_NEST_START = C.DEVLINK_ATTR_FMSG_ARR_NEST_START
DEVLINK_ATTR_FMSG_NEST_END = C.DEVLINK_ATTR_FMSG_NEST_END
DEVLINK_ATTR_FMSG_OBJ_NAME = C.DEVLINK_ATTR_FMSG_OBJ_NAME
DEVLINK_ATTR_FMSG_OBJ_VALUE_TYPE = C.DEVLINK_ATTR_FMSG_OBJ_VALUE_TYPE
DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA = C.DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA
DEVLINK_ATTR_HEALTH_REPORTER = C.DEVLINK_ATTR_HEALTH_REPORTER
DEVLINK_ATTR_HEALTH_REPORTER_NAME = C.DEVLINK_ATTR_HEALTH_REPORTER_NAME
DEVLINK_ATTR_HEALTH_REPORTER_STATE = C.DEVLINK_ATTR_HEALTH_REPORTER_STATE
DEVLINK_ATTR_HEALTH_REPORTER_ERR_COUNT = C.DEVLINK_ATTR_HEALTH_REPORTER_ERR_COUNT
DEVLINK_ATTR_HEALTH_REPORTER_RECOVER_COUNT = C.DEVLINK_ATTR_HEALTH_REPORTER_RECOVER_COUNT
DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS = C.DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS
DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD = C.DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD
DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER = C.DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER
DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME = C.DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME
DEVLINK_ATTR_FLASH_UPDATE_COMPONENT = C.DEVLINK_ATTR_FLASH_UPDATE_COMPONENT
DEVLINK_ATTR_FLASH_UPDATE_STATUS_MSG = C.DEVLINK_ATTR_FLASH_UPDATE_STATUS_MSG
DEVLINK_ATTR_FLASH_UPDATE_STATUS_DONE = C.DEVLINK_ATTR_FLASH_UPDATE_STATUS_DONE
DEVLINK_ATTR_FLASH_UPDATE_STATUS_TOTAL = C.DEVLINK_ATTR_FLASH_UPDATE_STATUS_TOTAL
DEVLINK_ATTR_PORT_PCI_PF_NUMBER = C.DEVLINK_ATTR_PORT_PCI_PF_NUMBER
DEVLINK_ATTR_PORT_PCI_VF_NUMBER = C.DEVLINK_ATTR_PORT_PCI_VF_NUMBER
DEVLINK_ATTR_STATS = C.DEVLINK_ATTR_STATS
DEVLINK_ATTR_TRAP_NAME = C.DEVLINK_ATTR_TRAP_NAME
DEVLINK_ATTR_TRAP_ACTION = C.DEVLINK_ATTR_TRAP_ACTION
DEVLINK_ATTR_TRAP_TYPE = C.DEVLINK_ATTR_TRAP_TYPE
DEVLINK_ATTR_TRAP_GENERIC = C.DEVLINK_ATTR_TRAP_GENERIC
DEVLINK_ATTR_TRAP_METADATA = C.DEVLINK_ATTR_TRAP_METADATA
DEVLINK_ATTR_TRAP_GROUP_NAME = C.DEVLINK_ATTR_TRAP_GROUP_NAME
DEVLINK_ATTR_RELOAD_FAILED = C.DEVLINK_ATTR_RELOAD_FAILED
DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS_NS = C.DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS_NS
DEVLINK_ATTR_NETNS_FD = C.DEVLINK_ATTR_NETNS_FD
DEVLINK_ATTR_NETNS_PID = C.DEVLINK_ATTR_NETNS_PID
DEVLINK_ATTR_NETNS_ID = C.DEVLINK_ATTR_NETNS_ID
DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP = C.DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP
DEVLINK_ATTR_TRAP_POLICER_ID = C.DEVLINK_ATTR_TRAP_POLICER_ID
DEVLINK_ATTR_TRAP_POLICER_RATE = C.DEVLINK_ATTR_TRAP_POLICER_RATE
DEVLINK_ATTR_TRAP_POLICER_BURST = C.DEVLINK_ATTR_TRAP_POLICER_BURST
DEVLINK_ATTR_PORT_FUNCTION = C.DEVLINK_ATTR_PORT_FUNCTION
DEVLINK_ATTR_INFO_BOARD_SERIAL_NUMBER = C.DEVLINK_ATTR_INFO_BOARD_SERIAL_NUMBER
DEVLINK_ATTR_PORT_LANES = C.DEVLINK_ATTR_PORT_LANES
DEVLINK_ATTR_PORT_SPLITTABLE = C.DEVLINK_ATTR_PORT_SPLITTABLE
DEVLINK_ATTR_PORT_EXTERNAL = C.DEVLINK_ATTR_PORT_EXTERNAL
DEVLINK_ATTR_PORT_CONTROLLER_NUMBER = C.DEVLINK_ATTR_PORT_CONTROLLER_NUMBER
DEVLINK_ATTR_FLASH_UPDATE_STATUS_TIMEOUT = C.DEVLINK_ATTR_FLASH_UPDATE_STATUS_TIMEOUT
DEVLINK_ATTR_FLASH_UPDATE_OVERWRITE_MASK = C.DEVLINK_ATTR_FLASH_UPDATE_OVERWRITE_MASK
DEVLINK_ATTR_RELOAD_ACTION = C.DEVLINK_ATTR_RELOAD_ACTION
DEVLINK_ATTR_RELOAD_ACTIONS_PERFORMED = C.DEVLINK_ATTR_RELOAD_ACTIONS_PERFORMED
DEVLINK_ATTR_RELOAD_LIMITS = C.DEVLINK_ATTR_RELOAD_LIMITS
DEVLINK_ATTR_DEV_STATS = C.DEVLINK_ATTR_DEV_STATS
DEVLINK_ATTR_RELOAD_STATS = C.DEVLINK_ATTR_RELOAD_STATS
DEVLINK_ATTR_RELOAD_STATS_ENTRY = C.DEVLINK_ATTR_RELOAD_STATS_ENTRY
DEVLINK_ATTR_RELOAD_STATS_LIMIT = C.DEVLINK_ATTR_RELOAD_STATS_LIMIT
DEVLINK_ATTR_RELOAD_STATS_VALUE = C.DEVLINK_ATTR_RELOAD_STATS_VALUE
DEVLINK_ATTR_REMOTE_RELOAD_STATS = C.DEVLINK_ATTR_REMOTE_RELOAD_STATS
DEVLINK_ATTR_RELOAD_ACTION_INFO = C.DEVLINK_ATTR_RELOAD_ACTION_INFO
DEVLINK_ATTR_RELOAD_ACTION_STATS = C.DEVLINK_ATTR_RELOAD_ACTION_STATS
DEVLINK_ATTR_PORT_PCI_SF_NUMBER = C.DEVLINK_ATTR_PORT_PCI_SF_NUMBER
DEVLINK_ATTR_RATE_TYPE = C.DEVLINK_ATTR_RATE_TYPE
DEVLINK_ATTR_RATE_TX_SHARE = C.DEVLINK_ATTR_RATE_TX_SHARE
DEVLINK_ATTR_RATE_TX_MAX = C.DEVLINK_ATTR_RATE_TX_MAX
DEVLINK_ATTR_RATE_NODE_NAME = C.DEVLINK_ATTR_RATE_NODE_NAME
DEVLINK_ATTR_RATE_PARENT_NODE_NAME = C.DEVLINK_ATTR_RATE_PARENT_NODE_NAME
DEVLINK_ATTR_REGION_MAX_SNAPSHOTS = C.DEVLINK_ATTR_REGION_MAX_SNAPSHOTS
DEVLINK_ATTR_LINECARD_INDEX = C.DEVLINK_ATTR_LINECARD_INDEX
DEVLINK_ATTR_LINECARD_STATE = C.DEVLINK_ATTR_LINECARD_STATE
DEVLINK_ATTR_LINECARD_TYPE = C.DEVLINK_ATTR_LINECARD_TYPE
DEVLINK_ATTR_LINECARD_SUPPORTED_TYPES = C.DEVLINK_ATTR_LINECARD_SUPPORTED_TYPES
DEVLINK_ATTR_NESTED_DEVLINK = C.DEVLINK_ATTR_NESTED_DEVLINK
DEVLINK_ATTR_SELFTESTS = C.DEVLINK_ATTR_SELFTESTS
DEVLINK_ATTR_MAX = C.DEVLINK_ATTR_MAX
// Dpipe mapping, match, action, field and header values (DEVLINK_DPIPE_*).
DEVLINK_DPIPE_FIELD_MAPPING_TYPE_NONE = C.DEVLINK_DPIPE_FIELD_MAPPING_TYPE_NONE
DEVLINK_DPIPE_FIELD_MAPPING_TYPE_IFINDEX = C.DEVLINK_DPIPE_FIELD_MAPPING_TYPE_IFINDEX
DEVLINK_DPIPE_MATCH_TYPE_FIELD_EXACT = C.DEVLINK_DPIPE_MATCH_TYPE_FIELD_EXACT
DEVLINK_DPIPE_ACTION_TYPE_FIELD_MODIFY = C.DEVLINK_DPIPE_ACTION_TYPE_FIELD_MODIFY
DEVLINK_DPIPE_FIELD_ETHERNET_DST_MAC = C.DEVLINK_DPIPE_FIELD_ETHERNET_DST_MAC
DEVLINK_DPIPE_FIELD_IPV4_DST_IP = C.DEVLINK_DPIPE_FIELD_IPV4_DST_IP
DEVLINK_DPIPE_FIELD_IPV6_DST_IP = C.DEVLINK_DPIPE_FIELD_IPV6_DST_IP
DEVLINK_DPIPE_HEADER_ETHERNET = C.DEVLINK_DPIPE_HEADER_ETHERNET
DEVLINK_DPIPE_HEADER_IPV4 = C.DEVLINK_DPIPE_HEADER_IPV4
DEVLINK_DPIPE_HEADER_IPV6 = C.DEVLINK_DPIPE_HEADER_IPV6
// Resource units (DEVLINK_RESOURCE_UNIT_*).
DEVLINK_RESOURCE_UNIT_ENTRY = C.DEVLINK_RESOURCE_UNIT_ENTRY
// Port function attributes (DEVLINK_PORT_FUNCTION_ATTR_* and DEVLINK_PORT_FN_ATTR_*).
DEVLINK_PORT_FUNCTION_ATTR_UNSPEC = C.DEVLINK_PORT_FUNCTION_ATTR_UNSPEC
DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR = C.DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR
DEVLINK_PORT_FN_ATTR_STATE = C.DEVLINK_PORT_FN_ATTR_STATE
DEVLINK_PORT_FN_ATTR_OPSTATE = C.DEVLINK_PORT_FN_ATTR_OPSTATE
DEVLINK_PORT_FN_ATTR_CAPS = C.DEVLINK_PORT_FN_ATTR_CAPS
DEVLINK_PORT_FUNCTION_ATTR_MAX = C.DEVLINK_PORT_FUNCTION_ATTR_MAX
)
// fs-verity

// FsverityDigest is declared with the layout of C struct fsverity_digest.
type FsverityDigest C.struct_fsverity_digest
// FsverityEnableArg is declared with the layout of C struct fsverity_enable_arg.
type FsverityEnableArg C.struct_fsverity_enable_arg
// nexthop

// Nhmsg is declared with the layout of C struct nhmsg.
type Nhmsg C.struct_nhmsg
// NexthopGrp is declared with the layout of C struct nexthop_grp.
type NexthopGrp C.struct_nexthop_grp
// NHA_* constants, each taking its value from the C identifier of the same
// name.
const (
NHA_UNSPEC = C.NHA_UNSPEC
NHA_ID = C.NHA_ID
NHA_GROUP = C.NHA_GROUP
NHA_GROUP_TYPE = C.NHA_GROUP_TYPE
NHA_BLACKHOLE = C.NHA_BLACKHOLE
NHA_OIF = C.NHA_OIF
NHA_GATEWAY = C.NHA_GATEWAY
NHA_ENCAP_TYPE = C.NHA_ENCAP_TYPE
NHA_ENCAP = C.NHA_ENCAP
NHA_GROUPS = C.NHA_GROUPS
NHA_MASTER = C.NHA_MASTER
)
// raw CAN sockets

// CAN_RAW_* constants, each taking its value from the C identifier of the
// same name.
const (
CAN_RAW_FILTER = C.CAN_RAW_FILTER
CAN_RAW_ERR_FILTER = C.CAN_RAW_ERR_FILTER
CAN_RAW_LOOPBACK = C.CAN_RAW_LOOPBACK
CAN_RAW_RECV_OWN_MSGS = C.CAN_RAW_RECV_OWN_MSGS
CAN_RAW_FD_FRAMES = C.CAN_RAW_FD_FRAMES
CAN_RAW_JOIN_FILTERS = C.CAN_RAW_JOIN_FILTERS
)
// Watchdog API

// WatchdogInfo is declared with the layout of C struct watchdog_info.
type WatchdogInfo C.struct_watchdog_info
// PPS API

// PPSFData is declared with the layout of C struct pps_fdata.
type PPSFData C.struct_pps_fdata
// PPSKParams is declared with the layout of C struct pps_kparams.
type PPSKParams C.struct_pps_kparams
// PPSKInfo is declared with the layout of C struct pps_kinfo.
type PPSKInfo C.struct_pps_kinfo
// PPSKTime is declared with the layout of C struct pps_ktime.
type PPSKTime C.struct_pps_ktime
// PPS_* constants, each taking its value from the C identifier of the same
// name.
const (
PPS_GETPARAMS = C.PPS_GETPARAMS
PPS_SETPARAMS = C.PPS_SETPARAMS
PPS_GETCAP = C.PPS_GETCAP
PPS_FETCH = C.PPS_FETCH
)
// lwtunnel and related APIs

// LWTUNNEL_ENCAP_* and MPLS_IPTUNNEL_* constants, each taking its value from
// the C identifier of the same name.
const (
LWTUNNEL_ENCAP_NONE = C.LWTUNNEL_ENCAP_NONE
LWTUNNEL_ENCAP_MPLS = C.LWTUNNEL_ENCAP_MPLS
LWTUNNEL_ENCAP_IP = C.LWTUNNEL_ENCAP_IP
LWTUNNEL_ENCAP_ILA = C.LWTUNNEL_ENCAP_ILA
LWTUNNEL_ENCAP_IP6 = C.LWTUNNEL_ENCAP_IP6
LWTUNNEL_ENCAP_SEG6 = C.LWTUNNEL_ENCAP_SEG6
LWTUNNEL_ENCAP_BPF = C.LWTUNNEL_ENCAP_BPF
LWTUNNEL_ENCAP_SEG6_LOCAL = C.LWTUNNEL_ENCAP_SEG6_LOCAL
LWTUNNEL_ENCAP_RPL = C.LWTUNNEL_ENCAP_RPL
LWTUNNEL_ENCAP_IOAM6 = C.LWTUNNEL_ENCAP_IOAM6
LWTUNNEL_ENCAP_XFRM = C.LWTUNNEL_ENCAP_XFRM
LWTUNNEL_ENCAP_MAX = C.LWTUNNEL_ENCAP_MAX
MPLS_IPTUNNEL_UNSPEC = C.MPLS_IPTUNNEL_UNSPEC
MPLS_IPTUNNEL_DST = C.MPLS_IPTUNNEL_DST
MPLS_IPTUNNEL_TTL = C.MPLS_IPTUNNEL_TTL
MPLS_IPTUNNEL_MAX = C.MPLS_IPTUNNEL_MAX
)
// ethtool and its netlink interface, generated using:
//
// perl -nlE '/^\s*(ETHTOOL_\w+)/ && say "$1 = C.$1"' ethtool.h
// perl -nlE '/^\s*(ETHTOOL_\w+)/ && say "$1 = C.$1"' ethtool_netlink.h
//
// Note that a couple of constants produced by this command will be duplicated
// by mkerrors.sh, so some manual pruning was necessary.
//
// Every entry binds a Go name to the C identifier of the same name. The
// comments inside the block only label runs of entries by their shared name
// prefix to make the list easier to navigate.
const (
// ETHTOOL_ID_UNSPEC and the entries listed directly after it.
ETHTOOL_ID_UNSPEC = C.ETHTOOL_ID_UNSPEC
ETHTOOL_RX_COPYBREAK = C.ETHTOOL_RX_COPYBREAK
ETHTOOL_TX_COPYBREAK = C.ETHTOOL_TX_COPYBREAK
ETHTOOL_PFC_PREVENTION_TOUT = C.ETHTOOL_PFC_PREVENTION_TOUT
// ETHTOOL_TUNABLE_* entries.
ETHTOOL_TUNABLE_UNSPEC = C.ETHTOOL_TUNABLE_UNSPEC
ETHTOOL_TUNABLE_U8 = C.ETHTOOL_TUNABLE_U8
ETHTOOL_TUNABLE_U16 = C.ETHTOOL_TUNABLE_U16
ETHTOOL_TUNABLE_U32 = C.ETHTOOL_TUNABLE_U32
ETHTOOL_TUNABLE_U64 = C.ETHTOOL_TUNABLE_U64
ETHTOOL_TUNABLE_STRING = C.ETHTOOL_TUNABLE_STRING
ETHTOOL_TUNABLE_S8 = C.ETHTOOL_TUNABLE_S8
ETHTOOL_TUNABLE_S16 = C.ETHTOOL_TUNABLE_S16
ETHTOOL_TUNABLE_S32 = C.ETHTOOL_TUNABLE_S32
ETHTOOL_TUNABLE_S64 = C.ETHTOOL_TUNABLE_S64
// ETHTOOL_PHY_* entries.
ETHTOOL_PHY_ID_UNSPEC = C.ETHTOOL_PHY_ID_UNSPEC
ETHTOOL_PHY_DOWNSHIFT = C.ETHTOOL_PHY_DOWNSHIFT
ETHTOOL_PHY_FAST_LINK_DOWN = C.ETHTOOL_PHY_FAST_LINK_DOWN
ETHTOOL_PHY_EDPD = C.ETHTOOL_PHY_EDPD
// ETHTOOL_LINK_EXT_STATE_* entries.
ETHTOOL_LINK_EXT_STATE_AUTONEG = C.ETHTOOL_LINK_EXT_STATE_AUTONEG
ETHTOOL_LINK_EXT_STATE_LINK_TRAINING_FAILURE = C.ETHTOOL_LINK_EXT_STATE_LINK_TRAINING_FAILURE
ETHTOOL_LINK_EXT_STATE_LINK_LOGICAL_MISMATCH = C.ETHTOOL_LINK_EXT_STATE_LINK_LOGICAL_MISMATCH
ETHTOOL_LINK_EXT_STATE_BAD_SIGNAL_INTEGRITY = C.ETHTOOL_LINK_EXT_STATE_BAD_SIGNAL_INTEGRITY
ETHTOOL_LINK_EXT_STATE_NO_CABLE = C.ETHTOOL_LINK_EXT_STATE_NO_CABLE
ETHTOOL_LINK_EXT_STATE_CABLE_ISSUE = C.ETHTOOL_LINK_EXT_STATE_CABLE_ISSUE
ETHTOOL_LINK_EXT_STATE_EEPROM_ISSUE = C.ETHTOOL_LINK_EXT_STATE_EEPROM_ISSUE
ETHTOOL_LINK_EXT_STATE_CALIBRATION_FAILURE = C.ETHTOOL_LINK_EXT_STATE_CALIBRATION_FAILURE
ETHTOOL_LINK_EXT_STATE_POWER_BUDGET_EXCEEDED = C.ETHTOOL_LINK_EXT_STATE_POWER_BUDGET_EXCEEDED
ETHTOOL_LINK_EXT_STATE_OVERHEAT = C.ETHTOOL_LINK_EXT_STATE_OVERHEAT
// ETHTOOL_LINK_EXT_SUBSTATE_* entries.
ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_PARTNER_DETECTED = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_PARTNER_DETECTED
ETHTOOL_LINK_EXT_SUBSTATE_AN_ACK_NOT_RECEIVED = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_ACK_NOT_RECEIVED
ETHTOOL_LINK_EXT_SUBSTATE_AN_NEXT_PAGE_EXCHANGE_FAILED = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_NEXT_PAGE_EXCHANGE_FAILED
ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_PARTNER_DETECTED_FORCE_MODE = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_PARTNER_DETECTED_FORCE_MODE
ETHTOOL_LINK_EXT_SUBSTATE_AN_FEC_MISMATCH_DURING_OVERRIDE = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_FEC_MISMATCH_DURING_OVERRIDE
ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_HCD = C.ETHTOOL_LINK_EXT_SUBSTATE_AN_NO_HCD
ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_FRAME_LOCK_NOT_ACQUIRED = C.ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_FRAME_LOCK_NOT_ACQUIRED
ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_LINK_INHIBIT_TIMEOUT = C.ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_LINK_INHIBIT_TIMEOUT
ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_LINK_PARTNER_DID_NOT_SET_RECEIVER_READY = C.ETHTOOL_LINK_EXT_SUBSTATE_LT_KR_LINK_PARTNER_DID_NOT_SET_RECEIVER_READY
ETHTOOL_LINK_EXT_SUBSTATE_LT_REMOTE_FAULT = C.ETHTOOL_LINK_EXT_SUBSTATE_LT_REMOTE_FAULT
ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_ACQUIRE_BLOCK_LOCK = C.ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_ACQUIRE_BLOCK_LOCK
ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_ACQUIRE_AM_LOCK = C.ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_ACQUIRE_AM_LOCK
ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_GET_ALIGN_STATUS = C.ETHTOOL_LINK_EXT_SUBSTATE_LLM_PCS_DID_NOT_GET_ALIGN_STATUS
ETHTOOL_LINK_EXT_SUBSTATE_LLM_FC_FEC_IS_NOT_LOCKED = C.ETHTOOL_LINK_EXT_SUBSTATE_LLM_FC_FEC_IS_NOT_LOCKED
ETHTOOL_LINK_EXT_SUBSTATE_LLM_RS_FEC_IS_NOT_LOCKED = C.ETHTOOL_LINK_EXT_SUBSTATE_LLM_RS_FEC_IS_NOT_LOCKED
ETHTOOL_LINK_EXT_SUBSTATE_BSI_LARGE_NUMBER_OF_PHYSICAL_ERRORS = C.ETHTOOL_LINK_EXT_SUBSTATE_BSI_LARGE_NUMBER_OF_PHYSICAL_ERRORS
ETHTOOL_LINK_EXT_SUBSTATE_BSI_UNSUPPORTED_RATE = C.ETHTOOL_LINK_EXT_SUBSTATE_BSI_UNSUPPORTED_RATE
ETHTOOL_LINK_EXT_SUBSTATE_CI_UNSUPPORTED_CABLE = C.ETHTOOL_LINK_EXT_SUBSTATE_CI_UNSUPPORTED_CABLE
ETHTOOL_LINK_EXT_SUBSTATE_CI_CABLE_TEST_FAILURE = C.ETHTOOL_LINK_EXT_SUBSTATE_CI_CABLE_TEST_FAILURE
// ETHTOOL_FLASH_ALL_REGIONS and the ETHTOOL_F_*__BIT entries.
ETHTOOL_FLASH_ALL_REGIONS = C.ETHTOOL_FLASH_ALL_REGIONS
ETHTOOL_F_UNSUPPORTED__BIT = C.ETHTOOL_F_UNSUPPORTED__BIT
ETHTOOL_F_WISH__BIT = C.ETHTOOL_F_WISH__BIT
ETHTOOL_F_COMPAT__BIT = C.ETHTOOL_F_COMPAT__BIT
// ETHTOOL_FEC_*_BIT entries.
ETHTOOL_FEC_NONE_BIT = C.ETHTOOL_FEC_NONE_BIT
ETHTOOL_FEC_AUTO_BIT = C.ETHTOOL_FEC_AUTO_BIT
ETHTOOL_FEC_OFF_BIT = C.ETHTOOL_FEC_OFF_BIT
ETHTOOL_FEC_RS_BIT = C.ETHTOOL_FEC_RS_BIT
ETHTOOL_FEC_BASER_BIT = C.ETHTOOL_FEC_BASER_BIT
ETHTOOL_FEC_LLRS_BIT = C.ETHTOOL_FEC_LLRS_BIT
// ETHTOOL_LINK_MODE_*_BIT entries.
ETHTOOL_LINK_MODE_10baseT_Half_BIT = C.ETHTOOL_LINK_MODE_10baseT_Half_BIT
ETHTOOL_LINK_MODE_10baseT_Full_BIT = C.ETHTOOL_LINK_MODE_10baseT_Full_BIT
ETHTOOL_LINK_MODE_100baseT_Half_BIT = C.ETHTOOL_LINK_MODE_100baseT_Half_BIT
ETHTOOL_LINK_MODE_100baseT_Full_BIT = C.ETHTOOL_LINK_MODE_100baseT_Full_BIT
ETHTOOL_LINK_MODE_1000baseT_Half_BIT = C.ETHTOOL_LINK_MODE_1000baseT_Half_BIT
ETHTOOL_LINK_MODE_1000baseT_Full_BIT = C.ETHTOOL_LINK_MODE_1000baseT_Full_BIT
ETHTOOL_LINK_MODE_Autoneg_BIT = C.ETHTOOL_LINK_MODE_Autoneg_BIT
ETHTOOL_LINK_MODE_TP_BIT = C.ETHTOOL_LINK_MODE_TP_BIT
ETHTOOL_LINK_MODE_AUI_BIT = C.ETHTOOL_LINK_MODE_AUI_BIT
ETHTOOL_LINK_MODE_MII_BIT = C.ETHTOOL_LINK_MODE_MII_BIT
ETHTOOL_LINK_MODE_FIBRE_BIT = C.ETHTOOL_LINK_MODE_FIBRE_BIT
ETHTOOL_LINK_MODE_BNC_BIT = C.ETHTOOL_LINK_MODE_BNC_BIT
ETHTOOL_LINK_MODE_10000baseT_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseT_Full_BIT
ETHTOOL_LINK_MODE_Pause_BIT = C.ETHTOOL_LINK_MODE_Pause_BIT
ETHTOOL_LINK_MODE_Asym_Pause_BIT = C.ETHTOOL_LINK_MODE_Asym_Pause_BIT
ETHTOOL_LINK_MODE_2500baseX_Full_BIT = C.ETHTOOL_LINK_MODE_2500baseX_Full_BIT
ETHTOOL_LINK_MODE_Backplane_BIT = C.ETHTOOL_LINK_MODE_Backplane_BIT
ETHTOOL_LINK_MODE_1000baseKX_Full_BIT = C.ETHTOOL_LINK_MODE_1000baseKX_Full_BIT
ETHTOOL_LINK_MODE_10000baseKX4_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseKX4_Full_BIT
ETHTOOL_LINK_MODE_10000baseKR_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseKR_Full_BIT
ETHTOOL_LINK_MODE_10000baseR_FEC_BIT = C.ETHTOOL_LINK_MODE_10000baseR_FEC_BIT
ETHTOOL_LINK_MODE_20000baseMLD2_Full_BIT = C.ETHTOOL_LINK_MODE_20000baseMLD2_Full_BIT
ETHTOOL_LINK_MODE_20000baseKR2_Full_BIT = C.ETHTOOL_LINK_MODE_20000baseKR2_Full_BIT
ETHTOOL_LINK_MODE_40000baseKR4_Full_BIT = C.ETHTOOL_LINK_MODE_40000baseKR4_Full_BIT
ETHTOOL_LINK_MODE_40000baseCR4_Full_BIT = C.ETHTOOL_LINK_MODE_40000baseCR4_Full_BIT
ETHTOOL_LINK_MODE_40000baseSR4_Full_BIT = C.ETHTOOL_LINK_MODE_40000baseSR4_Full_BIT
ETHTOOL_LINK_MODE_40000baseLR4_Full_BIT = C.ETHTOOL_LINK_MODE_40000baseLR4_Full_BIT
ETHTOOL_LINK_MODE_56000baseKR4_Full_BIT = C.ETHTOOL_LINK_MODE_56000baseKR4_Full_BIT
ETHTOOL_LINK_MODE_56000baseCR4_Full_BIT = C.ETHTOOL_LINK_MODE_56000baseCR4_Full_BIT
ETHTOOL_LINK_MODE_56000baseSR4_Full_BIT = C.ETHTOOL_LINK_MODE_56000baseSR4_Full_BIT
ETHTOOL_LINK_MODE_56000baseLR4_Full_BIT = C.ETHTOOL_LINK_MODE_56000baseLR4_Full_BIT
ETHTOOL_LINK_MODE_25000baseCR_Full_BIT = C.ETHTOOL_LINK_MODE_25000baseCR_Full_BIT
ETHTOOL_LINK_MODE_25000baseKR_Full_BIT = C.ETHTOOL_LINK_MODE_25000baseKR_Full_BIT
ETHTOOL_LINK_MODE_25000baseSR_Full_BIT = C.ETHTOOL_LINK_MODE_25000baseSR_Full_BIT
ETHTOOL_LINK_MODE_50000baseCR2_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseCR2_Full_BIT
ETHTOOL_LINK_MODE_50000baseKR2_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseKR2_Full_BIT
ETHTOOL_LINK_MODE_100000baseKR4_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseKR4_Full_BIT
ETHTOOL_LINK_MODE_100000baseSR4_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseSR4_Full_BIT
ETHTOOL_LINK_MODE_100000baseCR4_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseCR4_Full_BIT
ETHTOOL_LINK_MODE_100000baseLR4_ER4_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseLR4_ER4_Full_BIT
ETHTOOL_LINK_MODE_50000baseSR2_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseSR2_Full_BIT
ETHTOOL_LINK_MODE_1000baseX_Full_BIT = C.ETHTOOL_LINK_MODE_1000baseX_Full_BIT
ETHTOOL_LINK_MODE_10000baseCR_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseCR_Full_BIT
ETHTOOL_LINK_MODE_10000baseSR_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseSR_Full_BIT
ETHTOOL_LINK_MODE_10000baseLR_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseLR_Full_BIT
ETHTOOL_LINK_MODE_10000baseLRM_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseLRM_Full_BIT
ETHTOOL_LINK_MODE_10000baseER_Full_BIT = C.ETHTOOL_LINK_MODE_10000baseER_Full_BIT
ETHTOOL_LINK_MODE_2500baseT_Full_BIT = C.ETHTOOL_LINK_MODE_2500baseT_Full_BIT
ETHTOOL_LINK_MODE_5000baseT_Full_BIT = C.ETHTOOL_LINK_MODE_5000baseT_Full_BIT
ETHTOOL_LINK_MODE_FEC_NONE_BIT = C.ETHTOOL_LINK_MODE_FEC_NONE_BIT
ETHTOOL_LINK_MODE_FEC_RS_BIT = C.ETHTOOL_LINK_MODE_FEC_RS_BIT
ETHTOOL_LINK_MODE_FEC_BASER_BIT = C.ETHTOOL_LINK_MODE_FEC_BASER_BIT
ETHTOOL_LINK_MODE_50000baseKR_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseKR_Full_BIT
ETHTOOL_LINK_MODE_50000baseSR_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseSR_Full_BIT
ETHTOOL_LINK_MODE_50000baseCR_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseCR_Full_BIT
ETHTOOL_LINK_MODE_50000baseLR_ER_FR_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseLR_ER_FR_Full_BIT
ETHTOOL_LINK_MODE_50000baseDR_Full_BIT = C.ETHTOOL_LINK_MODE_50000baseDR_Full_BIT
ETHTOOL_LINK_MODE_100000baseKR2_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseKR2_Full_BIT
ETHTOOL_LINK_MODE_100000baseSR2_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseSR2_Full_BIT
ETHTOOL_LINK_MODE_100000baseCR2_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseCR2_Full_BIT
ETHTOOL_LINK_MODE_100000baseLR2_ER2_FR2_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseLR2_ER2_FR2_Full_BIT
ETHTOOL_LINK_MODE_100000baseDR2_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseDR2_Full_BIT
ETHTOOL_LINK_MODE_200000baseKR4_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseKR4_Full_BIT
ETHTOOL_LINK_MODE_200000baseSR4_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseSR4_Full_BIT
ETHTOOL_LINK_MODE_200000baseLR4_ER4_FR4_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseLR4_ER4_FR4_Full_BIT
ETHTOOL_LINK_MODE_200000baseDR4_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseDR4_Full_BIT
ETHTOOL_LINK_MODE_200000baseCR4_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseCR4_Full_BIT
ETHTOOL_LINK_MODE_100baseT1_Full_BIT = C.ETHTOOL_LINK_MODE_100baseT1_Full_BIT
ETHTOOL_LINK_MODE_1000baseT1_Full_BIT = C.ETHTOOL_LINK_MODE_1000baseT1_Full_BIT
ETHTOOL_LINK_MODE_400000baseKR8_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseKR8_Full_BIT
ETHTOOL_LINK_MODE_400000baseSR8_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseSR8_Full_BIT
ETHTOOL_LINK_MODE_400000baseLR8_ER8_FR8_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseLR8_ER8_FR8_Full_BIT
ETHTOOL_LINK_MODE_400000baseDR8_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseDR8_Full_BIT
ETHTOOL_LINK_MODE_400000baseCR8_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseCR8_Full_BIT
ETHTOOL_LINK_MODE_FEC_LLRS_BIT = C.ETHTOOL_LINK_MODE_FEC_LLRS_BIT
ETHTOOL_LINK_MODE_100000baseKR_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseKR_Full_BIT
ETHTOOL_LINK_MODE_100000baseSR_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseSR_Full_BIT
ETHTOOL_LINK_MODE_100000baseLR_ER_FR_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseLR_ER_FR_Full_BIT
ETHTOOL_LINK_MODE_100000baseCR_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseCR_Full_BIT
ETHTOOL_LINK_MODE_100000baseDR_Full_BIT = C.ETHTOOL_LINK_MODE_100000baseDR_Full_BIT
ETHTOOL_LINK_MODE_200000baseKR2_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseKR2_Full_BIT
ETHTOOL_LINK_MODE_200000baseSR2_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseSR2_Full_BIT
ETHTOOL_LINK_MODE_200000baseLR2_ER2_FR2_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseLR2_ER2_FR2_Full_BIT
ETHTOOL_LINK_MODE_200000baseDR2_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseDR2_Full_BIT
ETHTOOL_LINK_MODE_200000baseCR2_Full_BIT = C.ETHTOOL_LINK_MODE_200000baseCR2_Full_BIT
ETHTOOL_LINK_MODE_400000baseKR4_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseKR4_Full_BIT
ETHTOOL_LINK_MODE_400000baseSR4_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseSR4_Full_BIT
ETHTOOL_LINK_MODE_400000baseLR4_ER4_FR4_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseLR4_ER4_FR4_Full_BIT
ETHTOOL_LINK_MODE_400000baseDR4_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseDR4_Full_BIT
ETHTOOL_LINK_MODE_400000baseCR4_Full_BIT = C.ETHTOOL_LINK_MODE_400000baseCR4_Full_BIT
ETHTOOL_LINK_MODE_100baseFX_Half_BIT = C.ETHTOOL_LINK_MODE_100baseFX_Half_BIT
ETHTOOL_LINK_MODE_100baseFX_Full_BIT = C.ETHTOOL_LINK_MODE_100baseFX_Full_BIT
// ETHTOOL_MSG_* entries from ETHTOOL_MSG_USER_NONE through ETHTOOL_MSG_USER_MAX.
ETHTOOL_MSG_USER_NONE = C.ETHTOOL_MSG_USER_NONE
ETHTOOL_MSG_STRSET_GET = C.ETHTOOL_MSG_STRSET_GET
ETHTOOL_MSG_LINKINFO_GET = C.ETHTOOL_MSG_LINKINFO_GET
ETHTOOL_MSG_LINKINFO_SET = C.ETHTOOL_MSG_LINKINFO_SET
ETHTOOL_MSG_LINKMODES_GET = C.ETHTOOL_MSG_LINKMODES_GET
ETHTOOL_MSG_LINKMODES_SET = C.ETHTOOL_MSG_LINKMODES_SET
ETHTOOL_MSG_LINKSTATE_GET = C.ETHTOOL_MSG_LINKSTATE_GET
ETHTOOL_MSG_DEBUG_GET = C.ETHTOOL_MSG_DEBUG_GET
ETHTOOL_MSG_DEBUG_SET = C.ETHTOOL_MSG_DEBUG_SET
ETHTOOL_MSG_WOL_GET = C.ETHTOOL_MSG_WOL_GET
ETHTOOL_MSG_WOL_SET = C.ETHTOOL_MSG_WOL_SET
ETHTOOL_MSG_FEATURES_GET = C.ETHTOOL_MSG_FEATURES_GET
ETHTOOL_MSG_FEATURES_SET = C.ETHTOOL_MSG_FEATURES_SET
ETHTOOL_MSG_PRIVFLAGS_GET = C.ETHTOOL_MSG_PRIVFLAGS_GET
ETHTOOL_MSG_PRIVFLAGS_SET = C.ETHTOOL_MSG_PRIVFLAGS_SET
ETHTOOL_MSG_RINGS_GET = C.ETHTOOL_MSG_RINGS_GET
ETHTOOL_MSG_RINGS_SET = C.ETHTOOL_MSG_RINGS_SET
ETHTOOL_MSG_CHANNELS_GET = C.ETHTOOL_MSG_CHANNELS_GET
ETHTOOL_MSG_CHANNELS_SET = C.ETHTOOL_MSG_CHANNELS_SET
ETHTOOL_MSG_COALESCE_GET = C.ETHTOOL_MSG_COALESCE_GET
ETHTOOL_MSG_COALESCE_SET = C.ETHTOOL_MSG_COALESCE_SET
ETHTOOL_MSG_PAUSE_GET = C.ETHTOOL_MSG_PAUSE_GET
ETHTOOL_MSG_PAUSE_SET = C.ETHTOOL_MSG_PAUSE_SET
ETHTOOL_MSG_EEE_GET = C.ETHTOOL_MSG_EEE_GET
ETHTOOL_MSG_EEE_SET = C.ETHTOOL_MSG_EEE_SET
ETHTOOL_MSG_TSINFO_GET = C.ETHTOOL_MSG_TSINFO_GET
ETHTOOL_MSG_CABLE_TEST_ACT = C.ETHTOOL_MSG_CABLE_TEST_ACT
ETHTOOL_MSG_CABLE_TEST_TDR_ACT = C.ETHTOOL_MSG_CABLE_TEST_TDR_ACT
ETHTOOL_MSG_TUNNEL_INFO_GET = C.ETHTOOL_MSG_TUNNEL_INFO_GET
ETHTOOL_MSG_FEC_GET = C.ETHTOOL_MSG_FEC_GET
ETHTOOL_MSG_FEC_SET = C.ETHTOOL_MSG_FEC_SET
ETHTOOL_MSG_MODULE_EEPROM_GET = C.ETHTOOL_MSG_MODULE_EEPROM_GET
ETHTOOL_MSG_STATS_GET = C.ETHTOOL_MSG_STATS_GET
ETHTOOL_MSG_PHC_VCLOCKS_GET = C.ETHTOOL_MSG_PHC_VCLOCKS_GET
ETHTOOL_MSG_MODULE_GET = C.ETHTOOL_MSG_MODULE_GET
ETHTOOL_MSG_MODULE_SET = C.ETHTOOL_MSG_MODULE_SET
ETHTOOL_MSG_PSE_GET = C.ETHTOOL_MSG_PSE_GET
ETHTOOL_MSG_PSE_SET = C.ETHTOOL_MSG_PSE_SET
ETHTOOL_MSG_RSS_GET = C.ETHTOOL_MSG_RSS_GET
ETHTOOL_MSG_USER_MAX = C.ETHTOOL_MSG_USER_MAX
// ETHTOOL_MSG_* entries from ETHTOOL_MSG_KERNEL_NONE through ETHTOOL_MSG_KERNEL_MAX.
ETHTOOL_MSG_KERNEL_NONE = C.ETHTOOL_MSG_KERNEL_NONE
ETHTOOL_MSG_STRSET_GET_REPLY = C.ETHTOOL_MSG_STRSET_GET_REPLY
ETHTOOL_MSG_LINKINFO_GET_REPLY = C.ETHTOOL_MSG_LINKINFO_GET_REPLY
ETHTOOL_MSG_LINKINFO_NTF = C.ETHTOOL_MSG_LINKINFO_NTF
ETHTOOL_MSG_LINKMODES_GET_REPLY = C.ETHTOOL_MSG_LINKMODES_GET_REPLY
ETHTOOL_MSG_LINKMODES_NTF = C.ETHTOOL_MSG_LINKMODES_NTF
ETHTOOL_MSG_LINKSTATE_GET_REPLY = C.ETHTOOL_MSG_LINKSTATE_GET_REPLY
ETHTOOL_MSG_DEBUG_GET_REPLY = C.ETHTOOL_MSG_DEBUG_GET_REPLY
ETHTOOL_MSG_DEBUG_NTF = C.ETHTOOL_MSG_DEBUG_NTF
ETHTOOL_MSG_WOL_GET_REPLY = C.ETHTOOL_MSG_WOL_GET_REPLY
ETHTOOL_MSG_WOL_NTF = C.ETHTOOL_MSG_WOL_NTF
ETHTOOL_MSG_FEATURES_GET_REPLY = C.ETHTOOL_MSG_FEATURES_GET_REPLY
ETHTOOL_MSG_FEATURES_SET_REPLY = C.ETHTOOL_MSG_FEATURES_SET_REPLY
ETHTOOL_MSG_FEATURES_NTF = C.ETHTOOL_MSG_FEATURES_NTF
ETHTOOL_MSG_PRIVFLAGS_GET_REPLY = C.ETHTOOL_MSG_PRIVFLAGS_GET_REPLY
ETHTOOL_MSG_PRIVFLAGS_NTF = C.ETHTOOL_MSG_PRIVFLAGS_NTF
ETHTOOL_MSG_RINGS_GET_REPLY = C.ETHTOOL_MSG_RINGS_GET_REPLY
ETHTOOL_MSG_RINGS_NTF = C.ETHTOOL_MSG_RINGS_NTF
ETHTOOL_MSG_CHANNELS_GET_REPLY = C.ETHTOOL_MSG_CHANNELS_GET_REPLY
ETHTOOL_MSG_CHANNELS_NTF = C.ETHTOOL_MSG_CHANNELS_NTF
ETHTOOL_MSG_COALESCE_GET_REPLY = C.ETHTOOL_MSG_COALESCE_GET_REPLY
ETHTOOL_MSG_COALESCE_NTF = C.ETHTOOL_MSG_COALESCE_NTF
ETHTOOL_MSG_PAUSE_GET_REPLY = C.ETHTOOL_MSG_PAUSE_GET_REPLY
ETHTOOL_MSG_PAUSE_NTF = C.ETHTOOL_MSG_PAUSE_NTF
ETHTOOL_MSG_EEE_GET_REPLY = C.ETHTOOL_MSG_EEE_GET_REPLY
ETHTOOL_MSG_EEE_NTF = C.ETHTOOL_MSG_EEE_NTF
ETHTOOL_MSG_TSINFO_GET_REPLY = C.ETHTOOL_MSG_TSINFO_GET_REPLY
ETHTOOL_MSG_CABLE_TEST_NTF = C.ETHTOOL_MSG_CABLE_TEST_NTF
ETHTOOL_MSG_CABLE_TEST_TDR_NTF = C.ETHTOOL_MSG_CABLE_TEST_TDR_NTF
ETHTOOL_MSG_TUNNEL_INFO_GET_REPLY = C.ETHTOOL_MSG_TUNNEL_INFO_GET_REPLY
ETHTOOL_MSG_FEC_GET_REPLY = C.ETHTOOL_MSG_FEC_GET_REPLY
ETHTOOL_MSG_FEC_NTF = C.ETHTOOL_MSG_FEC_NTF
ETHTOOL_MSG_MODULE_EEPROM_GET_REPLY = C.ETHTOOL_MSG_MODULE_EEPROM_GET_REPLY
ETHTOOL_MSG_STATS_GET_REPLY = C.ETHTOOL_MSG_STATS_GET_REPLY
ETHTOOL_MSG_PHC_VCLOCKS_GET_REPLY = C.ETHTOOL_MSG_PHC_VCLOCKS_GET_REPLY
ETHTOOL_MSG_MODULE_GET_REPLY = C.ETHTOOL_MSG_MODULE_GET_REPLY
ETHTOOL_MSG_MODULE_NTF = C.ETHTOOL_MSG_MODULE_NTF
ETHTOOL_MSG_PSE_GET_REPLY = C.ETHTOOL_MSG_PSE_GET_REPLY
ETHTOOL_MSG_RSS_GET_REPLY = C.ETHTOOL_MSG_RSS_GET_REPLY
ETHTOOL_MSG_KERNEL_MAX = C.ETHTOOL_MSG_KERNEL_MAX
// ETHTOOL_A_HEADER_* entries.
ETHTOOL_A_HEADER_UNSPEC = C.ETHTOOL_A_HEADER_UNSPEC
ETHTOOL_A_HEADER_DEV_INDEX = C.ETHTOOL_A_HEADER_DEV_INDEX
ETHTOOL_A_HEADER_DEV_NAME = C.ETHTOOL_A_HEADER_DEV_NAME
ETHTOOL_A_HEADER_FLAGS = C.ETHTOOL_A_HEADER_FLAGS
ETHTOOL_A_HEADER_MAX = C.ETHTOOL_A_HEADER_MAX
// ETHTOOL_A_BITSET_* entries (BIT_, BITS_ and plain BITSET_ groups).
ETHTOOL_A_BITSET_BIT_UNSPEC = C.ETHTOOL_A_BITSET_BIT_UNSPEC
ETHTOOL_A_BITSET_BIT_INDEX = C.ETHTOOL_A_BITSET_BIT_INDEX
ETHTOOL_A_BITSET_BIT_NAME = C.ETHTOOL_A_BITSET_BIT_NAME
ETHTOOL_A_BITSET_BIT_VALUE = C.ETHTOOL_A_BITSET_BIT_VALUE
ETHTOOL_A_BITSET_BIT_MAX = C.ETHTOOL_A_BITSET_BIT_MAX
ETHTOOL_A_BITSET_BITS_UNSPEC = C.ETHTOOL_A_BITSET_BITS_UNSPEC
ETHTOOL_A_BITSET_BITS_BIT = C.ETHTOOL_A_BITSET_BITS_BIT
ETHTOOL_A_BITSET_BITS_MAX = C.ETHTOOL_A_BITSET_BITS_MAX
ETHTOOL_A_BITSET_UNSPEC = C.ETHTOOL_A_BITSET_UNSPEC
ETHTOOL_A_BITSET_NOMASK = C.ETHTOOL_A_BITSET_NOMASK
ETHTOOL_A_BITSET_SIZE = C.ETHTOOL_A_BITSET_SIZE
ETHTOOL_A_BITSET_BITS = C.ETHTOOL_A_BITSET_BITS
ETHTOOL_A_BITSET_VALUE = C.ETHTOOL_A_BITSET_VALUE
ETHTOOL_A_BITSET_MASK = C.ETHTOOL_A_BITSET_MASK
ETHTOOL_A_BITSET_MAX = C.ETHTOOL_A_BITSET_MAX
// ETHTOOL_A_STRING_*, _STRINGS_*, _STRINGSET_*, _STRINGSETS_* and _STRSET_* entries.
ETHTOOL_A_STRING_UNSPEC = C.ETHTOOL_A_STRING_UNSPEC
ETHTOOL_A_STRING_INDEX = C.ETHTOOL_A_STRING_INDEX
ETHTOOL_A_STRING_VALUE = C.ETHTOOL_A_STRING_VALUE
ETHTOOL_A_STRING_MAX = C.ETHTOOL_A_STRING_MAX
ETHTOOL_A_STRINGS_UNSPEC = C.ETHTOOL_A_STRINGS_UNSPEC
ETHTOOL_A_STRINGS_STRING = C.ETHTOOL_A_STRINGS_STRING
ETHTOOL_A_STRINGS_MAX = C.ETHTOOL_A_STRINGS_MAX
ETHTOOL_A_STRINGSET_UNSPEC = C.ETHTOOL_A_STRINGSET_UNSPEC
ETHTOOL_A_STRINGSET_ID = C.ETHTOOL_A_STRINGSET_ID
ETHTOOL_A_STRINGSET_COUNT = C.ETHTOOL_A_STRINGSET_COUNT
ETHTOOL_A_STRINGSET_STRINGS = C.ETHTOOL_A_STRINGSET_STRINGS
ETHTOOL_A_STRINGSET_MAX = C.ETHTOOL_A_STRINGSET_MAX
ETHTOOL_A_STRINGSETS_UNSPEC = C.ETHTOOL_A_STRINGSETS_UNSPEC
ETHTOOL_A_STRINGSETS_STRINGSET = C.ETHTOOL_A_STRINGSETS_STRINGSET
ETHTOOL_A_STRINGSETS_MAX = C.ETHTOOL_A_STRINGSETS_MAX
ETHTOOL_A_STRSET_UNSPEC = C.ETHTOOL_A_STRSET_UNSPEC
ETHTOOL_A_STRSET_HEADER = C.ETHTOOL_A_STRSET_HEADER
ETHTOOL_A_STRSET_STRINGSETS = C.ETHTOOL_A_STRSET_STRINGSETS
ETHTOOL_A_STRSET_COUNTS_ONLY = C.ETHTOOL_A_STRSET_COUNTS_ONLY
ETHTOOL_A_STRSET_MAX = C.ETHTOOL_A_STRSET_MAX
// ETHTOOL_A_LINKINFO_* entries.
ETHTOOL_A_LINKINFO_UNSPEC = C.ETHTOOL_A_LINKINFO_UNSPEC
ETHTOOL_A_LINKINFO_HEADER = C.ETHTOOL_A_LINKINFO_HEADER
ETHTOOL_A_LINKINFO_PORT = C.ETHTOOL_A_LINKINFO_PORT
ETHTOOL_A_LINKINFO_PHYADDR = C.ETHTOOL_A_LINKINFO_PHYADDR
ETHTOOL_A_LINKINFO_TP_MDIX = C.ETHTOOL_A_LINKINFO_TP_MDIX
ETHTOOL_A_LINKINFO_TP_MDIX_CTRL = C.ETHTOOL_A_LINKINFO_TP_MDIX_CTRL
ETHTOOL_A_LINKINFO_TRANSCEIVER = C.ETHTOOL_A_LINKINFO_TRANSCEIVER
ETHTOOL_A_LINKINFO_MAX = C.ETHTOOL_A_LINKINFO_MAX
// ETHTOOL_A_LINKMODES_* entries.
ETHTOOL_A_LINKMODES_UNSPEC = C.ETHTOOL_A_LINKMODES_UNSPEC
ETHTOOL_A_LINKMODES_HEADER = C.ETHTOOL_A_LINKMODES_HEADER
ETHTOOL_A_LINKMODES_AUTONEG = C.ETHTOOL_A_LINKMODES_AUTONEG
ETHTOOL_A_LINKMODES_OURS = C.ETHTOOL_A_LINKMODES_OURS
ETHTOOL_A_LINKMODES_PEER = C.ETHTOOL_A_LINKMODES_PEER
ETHTOOL_A_LINKMODES_SPEED = C.ETHTOOL_A_LINKMODES_SPEED
ETHTOOL_A_LINKMODES_DUPLEX = C.ETHTOOL_A_LINKMODES_DUPLEX
ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG = C.ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG
ETHTOOL_A_LINKMODES_MASTER_SLAVE_STATE = C.ETHTOOL_A_LINKMODES_MASTER_SLAVE_STATE
ETHTOOL_A_LINKMODES_LANES = C.ETHTOOL_A_LINKMODES_LANES
ETHTOOL_A_LINKMODES_RATE_MATCHING = C.ETHTOOL_A_LINKMODES_RATE_MATCHING
ETHTOOL_A_LINKMODES_MAX = C.ETHTOOL_A_LINKMODES_MAX
// ETHTOOL_A_LINKSTATE_* entries.
ETHTOOL_A_LINKSTATE_UNSPEC = C.ETHTOOL_A_LINKSTATE_UNSPEC
ETHTOOL_A_LINKSTATE_HEADER = C.ETHTOOL_A_LINKSTATE_HEADER
ETHTOOL_A_LINKSTATE_LINK = C.ETHTOOL_A_LINKSTATE_LINK
ETHTOOL_A_LINKSTATE_SQI = C.ETHTOOL_A_LINKSTATE_SQI
ETHTOOL_A_LINKSTATE_SQI_MAX = C.ETHTOOL_A_LINKSTATE_SQI_MAX
ETHTOOL_A_LINKSTATE_EXT_STATE = C.ETHTOOL_A_LINKSTATE_EXT_STATE
ETHTOOL_A_LINKSTATE_EXT_SUBSTATE = C.ETHTOOL_A_LINKSTATE_EXT_SUBSTATE
ETHTOOL_A_LINKSTATE_EXT_DOWN_CNT = C.ETHTOOL_A_LINKSTATE_EXT_DOWN_CNT
ETHTOOL_A_LINKSTATE_MAX = C.ETHTOOL_A_LINKSTATE_MAX
// ETHTOOL_A_DEBUG_* entries.
ETHTOOL_A_DEBUG_UNSPEC = C.ETHTOOL_A_DEBUG_UNSPEC
ETHTOOL_A_DEBUG_HEADER = C.ETHTOOL_A_DEBUG_HEADER
ETHTOOL_A_DEBUG_MSGMASK = C.ETHTOOL_A_DEBUG_MSGMASK
ETHTOOL_A_DEBUG_MAX = C.ETHTOOL_A_DEBUG_MAX
// ETHTOOL_A_WOL_* entries.
ETHTOOL_A_WOL_UNSPEC = C.ETHTOOL_A_WOL_UNSPEC
ETHTOOL_A_WOL_HEADER = C.ETHTOOL_A_WOL_HEADER
ETHTOOL_A_WOL_MODES = C.ETHTOOL_A_WOL_MODES
ETHTOOL_A_WOL_SOPASS = C.ETHTOOL_A_WOL_SOPASS
ETHTOOL_A_WOL_MAX = C.ETHTOOL_A_WOL_MAX
// ETHTOOL_A_FEATURES_* entries.
ETHTOOL_A_FEATURES_UNSPEC = C.ETHTOOL_A_FEATURES_UNSPEC
ETHTOOL_A_FEATURES_HEADER = C.ETHTOOL_A_FEATURES_HEADER
ETHTOOL_A_FEATURES_HW = C.ETHTOOL_A_FEATURES_HW
ETHTOOL_A_FEATURES_WANTED = C.ETHTOOL_A_FEATURES_WANTED
ETHTOOL_A_FEATURES_ACTIVE = C.ETHTOOL_A_FEATURES_ACTIVE
ETHTOOL_A_FEATURES_NOCHANGE = C.ETHTOOL_A_FEATURES_NOCHANGE
ETHTOOL_A_FEATURES_MAX = C.ETHTOOL_A_FEATURES_MAX
// ETHTOOL_A_PRIVFLAGS_* entries.
ETHTOOL_A_PRIVFLAGS_UNSPEC = C.ETHTOOL_A_PRIVFLAGS_UNSPEC
ETHTOOL_A_PRIVFLAGS_HEADER = C.ETHTOOL_A_PRIVFLAGS_HEADER
ETHTOOL_A_PRIVFLAGS_FLAGS = C.ETHTOOL_A_PRIVFLAGS_FLAGS
ETHTOOL_A_PRIVFLAGS_MAX = C.ETHTOOL_A_PRIVFLAGS_MAX
// ETHTOOL_A_RINGS_* entries.
ETHTOOL_A_RINGS_UNSPEC = C.ETHTOOL_A_RINGS_UNSPEC
ETHTOOL_A_RINGS_HEADER = C.ETHTOOL_A_RINGS_HEADER
ETHTOOL_A_RINGS_RX_MAX = C.ETHTOOL_A_RINGS_RX_MAX
ETHTOOL_A_RINGS_RX_MINI_MAX = C.ETHTOOL_A_RINGS_RX_MINI_MAX
ETHTOOL_A_RINGS_RX_JUMBO_MAX = C.ETHTOOL_A_RINGS_RX_JUMBO_MAX
ETHTOOL_A_RINGS_TX_MAX = C.ETHTOOL_A_RINGS_TX_MAX
ETHTOOL_A_RINGS_RX = C.ETHTOOL_A_RINGS_RX
ETHTOOL_A_RINGS_RX_MINI = C.ETHTOOL_A_RINGS_RX_MINI
ETHTOOL_A_RINGS_RX_JUMBO = C.ETHTOOL_A_RINGS_RX_JUMBO
ETHTOOL_A_RINGS_TX = C.ETHTOOL_A_RINGS_TX
ETHTOOL_A_RINGS_RX_BUF_LEN = C.ETHTOOL_A_RINGS_RX_BUF_LEN
ETHTOOL_A_RINGS_TCP_DATA_SPLIT = C.ETHTOOL_A_RINGS_TCP_DATA_SPLIT
ETHTOOL_A_RINGS_CQE_SIZE = C.ETHTOOL_A_RINGS_CQE_SIZE
ETHTOOL_A_RINGS_TX_PUSH = C.ETHTOOL_A_RINGS_TX_PUSH
ETHTOOL_A_RINGS_MAX = C.ETHTOOL_A_RINGS_MAX
// ETHTOOL_A_CHANNELS_* entries.
ETHTOOL_A_CHANNELS_UNSPEC = C.ETHTOOL_A_CHANNELS_UNSPEC
ETHTOOL_A_CHANNELS_HEADER = C.ETHTOOL_A_CHANNELS_HEADER
ETHTOOL_A_CHANNELS_RX_MAX = C.ETHTOOL_A_CHANNELS_RX_MAX
ETHTOOL_A_CHANNELS_TX_MAX = C.ETHTOOL_A_CHANNELS_TX_MAX
ETHTOOL_A_CHANNELS_OTHER_MAX = C.ETHTOOL_A_CHANNELS_OTHER_MAX
ETHTOOL_A_CHANNELS_COMBINED_MAX = C.ETHTOOL_A_CHANNELS_COMBINED_MAX
ETHTOOL_A_CHANNELS_RX_COUNT = C.ETHTOOL_A_CHANNELS_RX_COUNT
ETHTOOL_A_CHANNELS_TX_COUNT = C.ETHTOOL_A_CHANNELS_TX_COUNT
ETHTOOL_A_CHANNELS_OTHER_COUNT = C.ETHTOOL_A_CHANNELS_OTHER_COUNT
ETHTOOL_A_CHANNELS_COMBINED_COUNT = C.ETHTOOL_A_CHANNELS_COMBINED_COUNT
ETHTOOL_A_CHANNELS_MAX = C.ETHTOOL_A_CHANNELS_MAX
// ETHTOOL_A_COALESCE_* entries.
ETHTOOL_A_COALESCE_UNSPEC = C.ETHTOOL_A_COALESCE_UNSPEC
ETHTOOL_A_COALESCE_HEADER = C.ETHTOOL_A_COALESCE_HEADER
ETHTOOL_A_COALESCE_RX_USECS = C.ETHTOOL_A_COALESCE_RX_USECS
ETHTOOL_A_COALESCE_RX_MAX_FRAMES = C.ETHTOOL_A_COALESCE_RX_MAX_FRAMES
ETHTOOL_A_COALESCE_RX_USECS_IRQ = C.ETHTOOL_A_COALESCE_RX_USECS_IRQ
ETHTOOL_A_COALESCE_RX_MAX_FRAMES_IRQ = C.ETHTOOL_A_COALESCE_RX_MAX_FRAMES_IRQ
ETHTOOL_A_COALESCE_TX_USECS = C.ETHTOOL_A_COALESCE_TX_USECS
ETHTOOL_A_COALESCE_TX_MAX_FRAMES = C.ETHTOOL_A_COALESCE_TX_MAX_FRAMES
ETHTOOL_A_COALESCE_TX_USECS_IRQ = C.ETHTOOL_A_COALESCE_TX_USECS_IRQ
ETHTOOL_A_COALESCE_TX_MAX_FRAMES_IRQ = C.ETHTOOL_A_COALESCE_TX_MAX_FRAMES_IRQ
ETHTOOL_A_COALESCE_STATS_BLOCK_USECS = C.ETHTOOL_A_COALESCE_STATS_BLOCK_USECS
ETHTOOL_A_COALESCE_USE_ADAPTIVE_RX = C.ETHTOOL_A_COALESCE_USE_ADAPTIVE_RX
ETHTOOL_A_COALESCE_USE_ADAPTIVE_TX = C.ETHTOOL_A_COALESCE_USE_ADAPTIVE_TX
ETHTOOL_A_COALESCE_PKT_RATE_LOW = C.ETHTOOL_A_COALESCE_PKT_RATE_LOW
ETHTOOL_A_COALESCE_RX_USECS_LOW = C.ETHTOOL_A_COALESCE_RX_USECS_LOW
ETHTOOL_A_COALESCE_RX_MAX_FRAMES_LOW = C.ETHTOOL_A_COALESCE_RX_MAX_FRAMES_LOW
ETHTOOL_A_COALESCE_TX_USECS_LOW = C.ETHTOOL_A_COALESCE_TX_USECS_LOW
ETHTOOL_A_COALESCE_TX_MAX_FRAMES_LOW = C.ETHTOOL_A_COALESCE_TX_MAX_FRAMES_LOW
ETHTOOL_A_COALESCE_PKT_RATE_HIGH = C.ETHTOOL_A_COALESCE_PKT_RATE_HIGH
ETHTOOL_A_COALESCE_RX_USECS_HIGH = C.ETHTOOL_A_COALESCE_RX_USECS_HIGH
ETHTOOL_A_COALESCE_RX_MAX_FRAMES_HIGH = C.ETHTOOL_A_COALESCE_RX_MAX_FRAMES_HIGH
ETHTOOL_A_COALESCE_TX_USECS_HIGH = C.ETHTOOL_A_COALESCE_TX_USECS_HIGH
ETHTOOL_A_COALESCE_TX_MAX_FRAMES_HIGH = C.ETHTOOL_A_COALESCE_TX_MAX_FRAMES_HIGH
ETHTOOL_A_COALESCE_RATE_SAMPLE_INTERVAL = C.ETHTOOL_A_COALESCE_RATE_SAMPLE_INTERVAL
ETHTOOL_A_COALESCE_USE_CQE_MODE_TX = C.ETHTOOL_A_COALESCE_USE_CQE_MODE_TX
ETHTOOL_A_COALESCE_USE_CQE_MODE_RX = C.ETHTOOL_A_COALESCE_USE_CQE_MODE_RX
ETHTOOL_A_COALESCE_MAX = C.ETHTOOL_A_COALESCE_MAX
// ETHTOOL_A_PAUSE_* and ETHTOOL_A_PAUSE_STAT_* entries.
ETHTOOL_A_PAUSE_UNSPEC = C.ETHTOOL_A_PAUSE_UNSPEC
ETHTOOL_A_PAUSE_HEADER = C.ETHTOOL_A_PAUSE_HEADER
ETHTOOL_A_PAUSE_AUTONEG = C.ETHTOOL_A_PAUSE_AUTONEG
ETHTOOL_A_PAUSE_RX = C.ETHTOOL_A_PAUSE_RX
ETHTOOL_A_PAUSE_TX = C.ETHTOOL_A_PAUSE_TX
ETHTOOL_A_PAUSE_STATS = C.ETHTOOL_A_PAUSE_STATS
ETHTOOL_A_PAUSE_MAX = C.ETHTOOL_A_PAUSE_MAX
ETHTOOL_A_PAUSE_STAT_UNSPEC = C.ETHTOOL_A_PAUSE_STAT_UNSPEC
ETHTOOL_A_PAUSE_STAT_PAD = C.ETHTOOL_A_PAUSE_STAT_PAD
ETHTOOL_A_PAUSE_STAT_TX_FRAMES = C.ETHTOOL_A_PAUSE_STAT_TX_FRAMES
ETHTOOL_A_PAUSE_STAT_RX_FRAMES = C.ETHTOOL_A_PAUSE_STAT_RX_FRAMES
ETHTOOL_A_PAUSE_STAT_MAX = C.ETHTOOL_A_PAUSE_STAT_MAX
// ETHTOOL_A_EEE_* entries.
ETHTOOL_A_EEE_UNSPEC = C.ETHTOOL_A_EEE_UNSPEC
ETHTOOL_A_EEE_HEADER = C.ETHTOOL_A_EEE_HEADER
ETHTOOL_A_EEE_MODES_OURS = C.ETHTOOL_A_EEE_MODES_OURS
ETHTOOL_A_EEE_MODES_PEER = C.ETHTOOL_A_EEE_MODES_PEER
ETHTOOL_A_EEE_ACTIVE = C.ETHTOOL_A_EEE_ACTIVE
ETHTOOL_A_EEE_ENABLED = C.ETHTOOL_A_EEE_ENABLED
ETHTOOL_A_EEE_TX_LPI_ENABLED = C.ETHTOOL_A_EEE_TX_LPI_ENABLED
ETHTOOL_A_EEE_TX_LPI_TIMER = C.ETHTOOL_A_EEE_TX_LPI_TIMER
ETHTOOL_A_EEE_MAX = C.ETHTOOL_A_EEE_MAX
// ETHTOOL_A_TSINFO_* entries.
ETHTOOL_A_TSINFO_UNSPEC = C.ETHTOOL_A_TSINFO_UNSPEC
ETHTOOL_A_TSINFO_HEADER = C.ETHTOOL_A_TSINFO_HEADER
ETHTOOL_A_TSINFO_TIMESTAMPING = C.ETHTOOL_A_TSINFO_TIMESTAMPING
ETHTOOL_A_TSINFO_TX_TYPES = C.ETHTOOL_A_TSINFO_TX_TYPES
ETHTOOL_A_TSINFO_RX_FILTERS = C.ETHTOOL_A_TSINFO_RX_FILTERS
ETHTOOL_A_TSINFO_PHC_INDEX = C.ETHTOOL_A_TSINFO_PHC_INDEX
ETHTOOL_A_TSINFO_MAX = C.ETHTOOL_A_TSINFO_MAX
// ETHTOOL_A_CABLE_* entries (test, result, fault length, notification, TDR groups).
ETHTOOL_A_CABLE_TEST_UNSPEC = C.ETHTOOL_A_CABLE_TEST_UNSPEC
ETHTOOL_A_CABLE_TEST_HEADER = C.ETHTOOL_A_CABLE_TEST_HEADER
ETHTOOL_A_CABLE_TEST_MAX = C.ETHTOOL_A_CABLE_TEST_MAX
ETHTOOL_A_CABLE_RESULT_CODE_UNSPEC = C.ETHTOOL_A_CABLE_RESULT_CODE_UNSPEC
ETHTOOL_A_CABLE_RESULT_CODE_OK = C.ETHTOOL_A_CABLE_RESULT_CODE_OK
ETHTOOL_A_CABLE_RESULT_CODE_OPEN = C.ETHTOOL_A_CABLE_RESULT_CODE_OPEN
ETHTOOL_A_CABLE_RESULT_CODE_SAME_SHORT = C.ETHTOOL_A_CABLE_RESULT_CODE_SAME_SHORT
ETHTOOL_A_CABLE_RESULT_CODE_CROSS_SHORT = C.ETHTOOL_A_CABLE_RESULT_CODE_CROSS_SHORT
ETHTOOL_A_CABLE_PAIR_A = C.ETHTOOL_A_CABLE_PAIR_A
ETHTOOL_A_CABLE_PAIR_B = C.ETHTOOL_A_CABLE_PAIR_B
ETHTOOL_A_CABLE_PAIR_C = C.ETHTOOL_A_CABLE_PAIR_C
ETHTOOL_A_CABLE_PAIR_D = C.ETHTOOL_A_CABLE_PAIR_D
ETHTOOL_A_CABLE_RESULT_UNSPEC = C.ETHTOOL_A_CABLE_RESULT_UNSPEC
ETHTOOL_A_CABLE_RESULT_PAIR = C.ETHTOOL_A_CABLE_RESULT_PAIR
ETHTOOL_A_CABLE_RESULT_CODE = C.ETHTOOL_A_CABLE_RESULT_CODE
ETHTOOL_A_CABLE_RESULT_MAX = C.ETHTOOL_A_CABLE_RESULT_MAX
ETHTOOL_A_CABLE_FAULT_LENGTH_UNSPEC = C.ETHTOOL_A_CABLE_FAULT_LENGTH_UNSPEC
ETHTOOL_A_CABLE_FAULT_LENGTH_PAIR = C.ETHTOOL_A_CABLE_FAULT_LENGTH_PAIR
ETHTOOL_A_CABLE_FAULT_LENGTH_CM = C.ETHTOOL_A_CABLE_FAULT_LENGTH_CM
ETHTOOL_A_CABLE_FAULT_LENGTH_MAX = C.ETHTOOL_A_CABLE_FAULT_LENGTH_MAX
ETHTOOL_A_CABLE_TEST_NTF_STATUS_UNSPEC = C.ETHTOOL_A_CABLE_TEST_NTF_STATUS_UNSPEC
ETHTOOL_A_CABLE_TEST_NTF_STATUS_STARTED = C.ETHTOOL_A_CABLE_TEST_NTF_STATUS_STARTED
ETHTOOL_A_CABLE_TEST_NTF_STATUS_COMPLETED = C.ETHTOOL_A_CABLE_TEST_NTF_STATUS_COMPLETED
ETHTOOL_A_CABLE_NEST_UNSPEC = C.ETHTOOL_A_CABLE_NEST_UNSPEC
ETHTOOL_A_CABLE_NEST_RESULT = C.ETHTOOL_A_CABLE_NEST_RESULT
ETHTOOL_A_CABLE_NEST_FAULT_LENGTH = C.ETHTOOL_A_CABLE_NEST_FAULT_LENGTH
ETHTOOL_A_CABLE_NEST_MAX = C.ETHTOOL_A_CABLE_NEST_MAX
ETHTOOL_A_CABLE_TEST_NTF_UNSPEC = C.ETHTOOL_A_CABLE_TEST_NTF_UNSPEC
ETHTOOL_A_CABLE_TEST_NTF_HEADER = C.ETHTOOL_A_CABLE_TEST_NTF_HEADER
ETHTOOL_A_CABLE_TEST_NTF_STATUS = C.ETHTOOL_A_CABLE_TEST_NTF_STATUS
ETHTOOL_A_CABLE_TEST_NTF_NEST = C.ETHTOOL_A_CABLE_TEST_NTF_NEST
ETHTOOL_A_CABLE_TEST_NTF_MAX = C.ETHTOOL_A_CABLE_TEST_NTF_MAX
ETHTOOL_A_CABLE_TEST_TDR_CFG_UNSPEC = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_UNSPEC
ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST
ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST
ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP
ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR
ETHTOOL_A_CABLE_TEST_TDR_CFG_MAX = C.ETHTOOL_A_CABLE_TEST_TDR_CFG_MAX
ETHTOOL_A_CABLE_TEST_TDR_UNSPEC = C.ETHTOOL_A_CABLE_TEST_TDR_UNSPEC
ETHTOOL_A_CABLE_TEST_TDR_HEADER = C.ETHTOOL_A_CABLE_TEST_TDR_HEADER
ETHTOOL_A_CABLE_TEST_TDR_CFG = C.ETHTOOL_A_CABLE_TEST_TDR_CFG
ETHTOOL_A_CABLE_TEST_TDR_MAX = C.ETHTOOL_A_CABLE_TEST_TDR_MAX
ETHTOOL_A_CABLE_AMPLITUDE_UNSPEC = C.ETHTOOL_A_CABLE_AMPLITUDE_UNSPEC
ETHTOOL_A_CABLE_AMPLITUDE_PAIR = C.ETHTOOL_A_CABLE_AMPLITUDE_PAIR
ETHTOOL_A_CABLE_AMPLITUDE_mV = C.ETHTOOL_A_CABLE_AMPLITUDE_mV
ETHTOOL_A_CABLE_AMPLITUDE_MAX = C.ETHTOOL_A_CABLE_AMPLITUDE_MAX
ETHTOOL_A_CABLE_PULSE_UNSPEC = C.ETHTOOL_A_CABLE_PULSE_UNSPEC
ETHTOOL_A_CABLE_PULSE_mV = C.ETHTOOL_A_CABLE_PULSE_mV
ETHTOOL_A_CABLE_PULSE_MAX = C.ETHTOOL_A_CABLE_PULSE_MAX
ETHTOOL_A_CABLE_STEP_UNSPEC = C.ETHTOOL_A_CABLE_STEP_UNSPEC
ETHTOOL_A_CABLE_STEP_FIRST_DISTANCE = C.ETHTOOL_A_CABLE_STEP_FIRST_DISTANCE
ETHTOOL_A_CABLE_STEP_LAST_DISTANCE = C.ETHTOOL_A_CABLE_STEP_LAST_DISTANCE
ETHTOOL_A_CABLE_STEP_STEP_DISTANCE = C.ETHTOOL_A_CABLE_STEP_STEP_DISTANCE
ETHTOOL_A_CABLE_STEP_MAX = C.ETHTOOL_A_CABLE_STEP_MAX
ETHTOOL_A_CABLE_TDR_NEST_UNSPEC = C.ETHTOOL_A_CABLE_TDR_NEST_UNSPEC
ETHTOOL_A_CABLE_TDR_NEST_STEP = C.ETHTOOL_A_CABLE_TDR_NEST_STEP
ETHTOOL_A_CABLE_TDR_NEST_AMPLITUDE = C.ETHTOOL_A_CABLE_TDR_NEST_AMPLITUDE
ETHTOOL_A_CABLE_TDR_NEST_PULSE = C.ETHTOOL_A_CABLE_TDR_NEST_PULSE
ETHTOOL_A_CABLE_TDR_NEST_MAX = C.ETHTOOL_A_CABLE_TDR_NEST_MAX
ETHTOOL_A_CABLE_TEST_TDR_NTF_UNSPEC = C.ETHTOOL_A_CABLE_TEST_TDR_NTF_UNSPEC
ETHTOOL_A_CABLE_TEST_TDR_NTF_HEADER = C.ETHTOOL_A_CABLE_TEST_TDR_NTF_HEADER
ETHTOOL_A_CABLE_TEST_TDR_NTF_STATUS = C.ETHTOOL_A_CABLE_TEST_TDR_NTF_STATUS
ETHTOOL_A_CABLE_TEST_TDR_NTF_NEST = C.ETHTOOL_A_CABLE_TEST_TDR_NTF_NEST
ETHTOOL_A_CABLE_TEST_TDR_NTF_MAX = C.ETHTOOL_A_CABLE_TEST_TDR_NTF_MAX
// ETHTOOL_UDP_TUNNEL_TYPE_* and ETHTOOL_A_TUNNEL_* entries.
ETHTOOL_UDP_TUNNEL_TYPE_VXLAN = C.ETHTOOL_UDP_TUNNEL_TYPE_VXLAN
ETHTOOL_UDP_TUNNEL_TYPE_GENEVE = C.ETHTOOL_UDP_TUNNEL_TYPE_GENEVE
ETHTOOL_UDP_TUNNEL_TYPE_VXLAN_GPE = C.ETHTOOL_UDP_TUNNEL_TYPE_VXLAN_GPE
ETHTOOL_A_TUNNEL_UDP_ENTRY_UNSPEC = C.ETHTOOL_A_TUNNEL_UDP_ENTRY_UNSPEC
ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT = C.ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT
ETHTOOL_A_TUNNEL_UDP_ENTRY_TYPE = C.ETHTOOL_A_TUNNEL_UDP_ENTRY_TYPE
ETHTOOL_A_TUNNEL_UDP_ENTRY_MAX = C.ETHTOOL_A_TUNNEL_UDP_ENTRY_MAX
ETHTOOL_A_TUNNEL_UDP_TABLE_UNSPEC = C.ETHTOOL_A_TUNNEL_UDP_TABLE_UNSPEC
ETHTOOL_A_TUNNEL_UDP_TABLE_SIZE = C.ETHTOOL_A_TUNNEL_UDP_TABLE_SIZE
ETHTOOL_A_TUNNEL_UDP_TABLE_TYPES = C.ETHTOOL_A_TUNNEL_UDP_TABLE_TYPES
ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY = C.ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY
ETHTOOL_A_TUNNEL_UDP_TABLE_MAX = C.ETHTOOL_A_TUNNEL_UDP_TABLE_MAX
ETHTOOL_A_TUNNEL_UDP_UNSPEC = C.ETHTOOL_A_TUNNEL_UDP_UNSPEC
ETHTOOL_A_TUNNEL_UDP_TABLE = C.ETHTOOL_A_TUNNEL_UDP_TABLE
ETHTOOL_A_TUNNEL_UDP_MAX = C.ETHTOOL_A_TUNNEL_UDP_MAX
ETHTOOL_A_TUNNEL_INFO_UNSPEC = C.ETHTOOL_A_TUNNEL_INFO_UNSPEC
ETHTOOL_A_TUNNEL_INFO_HEADER = C.ETHTOOL_A_TUNNEL_INFO_HEADER
ETHTOOL_A_TUNNEL_INFO_UDP_PORTS = C.ETHTOOL_A_TUNNEL_INFO_UDP_PORTS
ETHTOOL_A_TUNNEL_INFO_MAX = C.ETHTOOL_A_TUNNEL_INFO_MAX
)
// SPEED_UNKNOWN takes its value from the C identifier SPEED_UNKNOWN.
const SPEED_UNKNOWN = C.SPEED_UNKNOWN
// EthtoolDrvinfo is declared in terms of the C struct ethtool_drvinfo.
type EthtoolDrvinfo C.struct_ethtool_drvinfo
// HIDRaw* types, each declared in terms of the C hidraw_* struct whose name it
// mirrors.
type (
// HIDRawReportDescriptor corresponds to the C struct hidraw_report_descriptor.
HIDRawReportDescriptor C.struct_hidraw_report_descriptor
// HIDRawDevInfo corresponds to the C struct hidraw_devinfo.
HIDRawDevInfo C.struct_hidraw_devinfo
)
// close_range
//
// CLOSE_RANGE_* constants. Each Go name takes its value from the C identifier
// of the same name.
// NOTE(review): presumably the flag values accepted by close_range — confirm
// against the kernel header.
const (
CLOSE_RANGE_UNSHARE = C.CLOSE_RANGE_UNSHARE
CLOSE_RANGE_CLOEXEC = C.CLOSE_RANGE_CLOEXEC
)
// Netlink extended acknowledgement TLVs.
//
// NLMSGERR_ATTR_* constants. Each Go name takes its value from the C
// identifier of the same name.
const (
NLMSGERR_ATTR_MSG = C.NLMSGERR_ATTR_MSG
NLMSGERR_ATTR_OFFS = C.NLMSGERR_ATTR_OFFS
NLMSGERR_ATTR_COOKIE = C.NLMSGERR_ATTR_COOKIE
)
// MTD
//
// Go types declared in terms of C structs. Note that several Go names drop the
// "_user" suffix carried by the C struct they are bound to.
type (
// EraseInfo corresponds to the C struct erase_info_user.
EraseInfo C.struct_erase_info_user
// EraseInfo64 corresponds to the C struct erase_info_user64.
EraseInfo64 C.struct_erase_info_user64
// MtdOobBuf corresponds to the C struct mtd_oob_buf.
MtdOobBuf C.struct_mtd_oob_buf
// MtdOobBuf64 corresponds to the C struct mtd_oob_buf64.
MtdOobBuf64 C.struct_mtd_oob_buf64
// MtdWriteReq corresponds to the C struct mtd_write_req.
MtdWriteReq C.struct_mtd_write_req
// MtdInfo corresponds to the C struct mtd_info_user.
MtdInfo C.struct_mtd_info_user
// RegionInfo corresponds to the C struct region_info_user.
RegionInfo C.struct_region_info_user
// OtpInfo corresponds to the C struct otp_info.
OtpInfo C.struct_otp_info
// NandOobinfo corresponds to the C struct nand_oobinfo.
NandOobinfo C.struct_nand_oobinfo
// NandOobfree corresponds to the C struct nand_oobfree.
NandOobfree C.struct_nand_oobfree
// NandEcclayout corresponds to the C struct nand_ecclayout_user.
NandEcclayout C.struct_nand_ecclayout_user
// MtdEccStats corresponds to the C struct mtd_ecc_stats.
MtdEccStats C.struct_mtd_ecc_stats
)
// MTD_OPS_* constants, each bound to the identically named C constant.
const (
MTD_OPS_PLACE_OOB = C.MTD_OPS_PLACE_OOB
MTD_OPS_AUTO_OOB = C.MTD_OPS_AUTO_OOB
MTD_OPS_RAW = C.MTD_OPS_RAW
)
// MTD_FILE_MODE_* constants, each bound to the identically named C constant.
const (
MTD_FILE_MODE_NORMAL = C.MTD_FILE_MODE_NORMAL
MTD_FILE_MODE_OTP_FACTORY = C.MTD_FILE_MODE_OTP_FACTORY
MTD_FILE_MODE_OTP_USER = C.MTD_FILE_MODE_OTP_USER
MTD_FILE_MODE_RAW = C.MTD_FILE_MODE_RAW
)
// NFC Subsystem enums.
//
// Each name below is bound to the identically named C constant. The list holds
// three families, distinguishable by prefix: NFC_CMD_*/NFC_EVENT_* (which are
// interleaved), NFC_ATTR_*, and NFC_SDP_ATTR_*.
const (
// NFC_CMD_* and NFC_EVENT_* names.
NFC_CMD_UNSPEC = C.NFC_CMD_UNSPEC
NFC_CMD_GET_DEVICE = C.NFC_CMD_GET_DEVICE
NFC_CMD_DEV_UP = C.NFC_CMD_DEV_UP
NFC_CMD_DEV_DOWN = C.NFC_CMD_DEV_DOWN
NFC_CMD_DEP_LINK_UP = C.NFC_CMD_DEP_LINK_UP
NFC_CMD_DEP_LINK_DOWN = C.NFC_CMD_DEP_LINK_DOWN
NFC_CMD_START_POLL = C.NFC_CMD_START_POLL
NFC_CMD_STOP_POLL = C.NFC_CMD_STOP_POLL
NFC_CMD_GET_TARGET = C.NFC_CMD_GET_TARGET
NFC_EVENT_TARGETS_FOUND = C.NFC_EVENT_TARGETS_FOUND
NFC_EVENT_DEVICE_ADDED = C.NFC_EVENT_DEVICE_ADDED
NFC_EVENT_DEVICE_REMOVED = C.NFC_EVENT_DEVICE_REMOVED
NFC_EVENT_TARGET_LOST = C.NFC_EVENT_TARGET_LOST
NFC_EVENT_TM_ACTIVATED = C.NFC_EVENT_TM_ACTIVATED
NFC_EVENT_TM_DEACTIVATED = C.NFC_EVENT_TM_DEACTIVATED
NFC_CMD_LLC_GET_PARAMS = C.NFC_CMD_LLC_GET_PARAMS
NFC_CMD_LLC_SET_PARAMS = C.NFC_CMD_LLC_SET_PARAMS
NFC_CMD_ENABLE_SE = C.NFC_CMD_ENABLE_SE
NFC_CMD_DISABLE_SE = C.NFC_CMD_DISABLE_SE
NFC_CMD_LLC_SDREQ = C.NFC_CMD_LLC_SDREQ
NFC_EVENT_LLC_SDRES = C.NFC_EVENT_LLC_SDRES
NFC_CMD_FW_DOWNLOAD = C.NFC_CMD_FW_DOWNLOAD
NFC_EVENT_SE_ADDED = C.NFC_EVENT_SE_ADDED
NFC_EVENT_SE_REMOVED = C.NFC_EVENT_SE_REMOVED
NFC_EVENT_SE_CONNECTIVITY = C.NFC_EVENT_SE_CONNECTIVITY
NFC_EVENT_SE_TRANSACTION = C.NFC_EVENT_SE_TRANSACTION
NFC_CMD_GET_SE = C.NFC_CMD_GET_SE
NFC_CMD_SE_IO = C.NFC_CMD_SE_IO
NFC_CMD_ACTIVATE_TARGET = C.NFC_CMD_ACTIVATE_TARGET
NFC_CMD_VENDOR = C.NFC_CMD_VENDOR
NFC_CMD_DEACTIVATE_TARGET = C.NFC_CMD_DEACTIVATE_TARGET
// NFC_ATTR_* names.
NFC_ATTR_UNSPEC = C.NFC_ATTR_UNSPEC
NFC_ATTR_DEVICE_INDEX = C.NFC_ATTR_DEVICE_INDEX
NFC_ATTR_DEVICE_NAME = C.NFC_ATTR_DEVICE_NAME
NFC_ATTR_PROTOCOLS = C.NFC_ATTR_PROTOCOLS
NFC_ATTR_TARGET_INDEX = C.NFC_ATTR_TARGET_INDEX
NFC_ATTR_TARGET_SENS_RES = C.NFC_ATTR_TARGET_SENS_RES
NFC_ATTR_TARGET_SEL_RES = C.NFC_ATTR_TARGET_SEL_RES
NFC_ATTR_TARGET_NFCID1 = C.NFC_ATTR_TARGET_NFCID1
NFC_ATTR_TARGET_SENSB_RES = C.NFC_ATTR_TARGET_SENSB_RES
NFC_ATTR_TARGET_SENSF_RES = C.NFC_ATTR_TARGET_SENSF_RES
NFC_ATTR_COMM_MODE = C.NFC_ATTR_COMM_MODE
NFC_ATTR_RF_MODE = C.NFC_ATTR_RF_MODE
NFC_ATTR_DEVICE_POWERED = C.NFC_ATTR_DEVICE_POWERED
NFC_ATTR_IM_PROTOCOLS = C.NFC_ATTR_IM_PROTOCOLS
NFC_ATTR_TM_PROTOCOLS = C.NFC_ATTR_TM_PROTOCOLS
NFC_ATTR_LLC_PARAM_LTO = C.NFC_ATTR_LLC_PARAM_LTO
NFC_ATTR_LLC_PARAM_RW = C.NFC_ATTR_LLC_PARAM_RW
NFC_ATTR_LLC_PARAM_MIUX = C.NFC_ATTR_LLC_PARAM_MIUX
NFC_ATTR_SE = C.NFC_ATTR_SE
NFC_ATTR_LLC_SDP = C.NFC_ATTR_LLC_SDP
NFC_ATTR_FIRMWARE_NAME = C.NFC_ATTR_FIRMWARE_NAME
NFC_ATTR_SE_INDEX = C.NFC_ATTR_SE_INDEX
NFC_ATTR_SE_TYPE = C.NFC_ATTR_SE_TYPE
NFC_ATTR_SE_AID = C.NFC_ATTR_SE_AID
NFC_ATTR_FIRMWARE_DOWNLOAD_STATUS = C.NFC_ATTR_FIRMWARE_DOWNLOAD_STATUS
NFC_ATTR_SE_APDU = C.NFC_ATTR_SE_APDU
NFC_ATTR_TARGET_ISO15693_DSFID = C.NFC_ATTR_TARGET_ISO15693_DSFID
NFC_ATTR_TARGET_ISO15693_UID = C.NFC_ATTR_TARGET_ISO15693_UID
NFC_ATTR_SE_PARAMS = C.NFC_ATTR_SE_PARAMS
NFC_ATTR_VENDOR_ID = C.NFC_ATTR_VENDOR_ID
NFC_ATTR_VENDOR_SUBCMD = C.NFC_ATTR_VENDOR_SUBCMD
NFC_ATTR_VENDOR_DATA = C.NFC_ATTR_VENDOR_DATA
// NFC_SDP_ATTR_* names.
NFC_SDP_ATTR_UNSPEC = C.NFC_SDP_ATTR_UNSPEC
NFC_SDP_ATTR_URI = C.NFC_SDP_ATTR_URI
NFC_SDP_ATTR_SAP = C.NFC_SDP_ATTR_SAP
)
// Landlock
//
// LandlockRulesetAttr is the Go declaration of C struct landlock_ruleset_attr.
type LandlockRulesetAttr C.struct_landlock_ruleset_attr
// LandlockPathBeneathAttr is the Go declaration of C struct landlock_path_beneath_attr.
type LandlockPathBeneathAttr C.struct_landlock_path_beneath_attr
// LANDLOCK_RULE_PATH_BENEATH is bound to the identically named C constant.
const (
LANDLOCK_RULE_PATH_BENEATH = C.LANDLOCK_RULE_PATH_BENEATH
)
// pidfd flags.
const (
// PIDFD_NONBLOCK takes its value from C.O_NONBLOCK rather than from a C
// symbol of the same name, so it is always equal to O_NONBLOCK.
PIDFD_NONBLOCK = C.O_NONBLOCK
)
// shm
//
// SysvIpcPerm is the Go declaration of C struct ipc64_perm.
type SysvIpcPerm C.struct_ipc64_perm
// SysvShmDesc is the Go declaration of C struct shmid64_ds.
type SysvShmDesc C.struct_shmid64_ds
// IPC_* constants, each bound to a C constant of the same (upper-cased) name.
const (
IPC_CREAT = C.IPC_CREAT
IPC_EXCL = C.IPC_EXCL
IPC_NOWAIT = C.IPC_NOWAIT
IPC_PRIVATE = C.IPC_PRIVATE
// ipc_64 is deliberately lower-case: it carries the value of C.IPC_64 but is
// not exported from the package.
ipc_64 = C.IPC_64
)
// IPC_RMID, IPC_SET and IPC_STAT, each bound to the identically named C constant.
const (
IPC_RMID = C.IPC_RMID
IPC_SET = C.IPC_SET
IPC_STAT = C.IPC_STAT
)
// SHM_* constants, each bound to the identically named C constant.
const (
SHM_RDONLY = C.SHM_RDONLY
SHM_RND = C.SHM_RND
)
// mount_setattr
//
// MountAttr is the Go declaration of C struct mount_attr.
type MountAttr C.struct_mount_attr
// WireGuard generic netlink interface
// Generated by:
// perl -nlE '/^\s*(WG\w+)/ && say "$1 = C.$1"' /usr/include/linux/wireguard.h
//
// Each name below is bound to the identically named C constant. The entries
// fall into prefix families: WG_CMD_*, WGDEVICE_F_*/WGDEVICE_A_*,
// WGPEER_F_*/WGPEER_A_*, and WGALLOWEDIP_A_*.
const (
// WG_CMD_* names.
WG_CMD_GET_DEVICE = C.WG_CMD_GET_DEVICE
WG_CMD_SET_DEVICE = C.WG_CMD_SET_DEVICE
// WGDEVICE_F_* and WGDEVICE_A_* names.
WGDEVICE_F_REPLACE_PEERS = C.WGDEVICE_F_REPLACE_PEERS
WGDEVICE_A_UNSPEC = C.WGDEVICE_A_UNSPEC
WGDEVICE_A_IFINDEX = C.WGDEVICE_A_IFINDEX
WGDEVICE_A_IFNAME = C.WGDEVICE_A_IFNAME
WGDEVICE_A_PRIVATE_KEY = C.WGDEVICE_A_PRIVATE_KEY
WGDEVICE_A_PUBLIC_KEY = C.WGDEVICE_A_PUBLIC_KEY
WGDEVICE_A_FLAGS = C.WGDEVICE_A_FLAGS
WGDEVICE_A_LISTEN_PORT = C.WGDEVICE_A_LISTEN_PORT
WGDEVICE_A_FWMARK = C.WGDEVICE_A_FWMARK
WGDEVICE_A_PEERS = C.WGDEVICE_A_PEERS
// WGPEER_F_* and WGPEER_A_* names.
WGPEER_F_REMOVE_ME = C.WGPEER_F_REMOVE_ME
WGPEER_F_REPLACE_ALLOWEDIPS = C.WGPEER_F_REPLACE_ALLOWEDIPS
WGPEER_F_UPDATE_ONLY = C.WGPEER_F_UPDATE_ONLY
WGPEER_A_UNSPEC = C.WGPEER_A_UNSPEC
WGPEER_A_PUBLIC_KEY = C.WGPEER_A_PUBLIC_KEY
WGPEER_A_PRESHARED_KEY = C.WGPEER_A_PRESHARED_KEY
WGPEER_A_FLAGS = C.WGPEER_A_FLAGS
WGPEER_A_ENDPOINT = C.WGPEER_A_ENDPOINT
WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL = C.WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
WGPEER_A_LAST_HANDSHAKE_TIME = C.WGPEER_A_LAST_HANDSHAKE_TIME
WGPEER_A_RX_BYTES = C.WGPEER_A_RX_BYTES
WGPEER_A_TX_BYTES = C.WGPEER_A_TX_BYTES
WGPEER_A_ALLOWEDIPS = C.WGPEER_A_ALLOWEDIPS
WGPEER_A_PROTOCOL_VERSION = C.WGPEER_A_PROTOCOL_VERSION
// WGALLOWEDIP_A_* names.
WGALLOWEDIP_A_UNSPEC = C.WGALLOWEDIP_A_UNSPEC
WGALLOWEDIP_A_FAMILY = C.WGALLOWEDIP_A_FAMILY
WGALLOWEDIP_A_IPADDR = C.WGALLOWEDIP_A_IPADDR
WGALLOWEDIP_A_CIDR_MASK = C.WGALLOWEDIP_A_CIDR_MASK
)
// netlink attribute types and policies
// Generated by:
// perl -nlE '/^\s*(NL_ATTR\w+)/ && say "$1 = C.$1"' /usr/include/linux/netlink.h
// perl -nlE '/^\s*(NL_POLICY\w+)/ && say "$1 = C.$1"' /usr/include/linux/netlink.h
//
// Each name below is bound to the identically named C constant; the output of
// the first command (NL_ATTR_TYPE_*) is followed by the output of the second
// (NL_POLICY_TYPE_ATTR_*).
const (
// NL_ATTR_TYPE_* names.
NL_ATTR_TYPE_INVALID = C.NL_ATTR_TYPE_INVALID
NL_ATTR_TYPE_FLAG = C.NL_ATTR_TYPE_FLAG
NL_ATTR_TYPE_U8 = C.NL_ATTR_TYPE_U8
NL_ATTR_TYPE_U16 = C.NL_ATTR_TYPE_U16
NL_ATTR_TYPE_U32 = C.NL_ATTR_TYPE_U32
NL_ATTR_TYPE_U64 = C.NL_ATTR_TYPE_U64
NL_ATTR_TYPE_S8 = C.NL_ATTR_TYPE_S8
NL_ATTR_TYPE_S16 = C.NL_ATTR_TYPE_S16
NL_ATTR_TYPE_S32 = C.NL_ATTR_TYPE_S32
NL_ATTR_TYPE_S64 = C.NL_ATTR_TYPE_S64
NL_ATTR_TYPE_BINARY = C.NL_ATTR_TYPE_BINARY
NL_ATTR_TYPE_STRING = C.NL_ATTR_TYPE_STRING
NL_ATTR_TYPE_NUL_STRING = C.NL_ATTR_TYPE_NUL_STRING
NL_ATTR_TYPE_NESTED = C.NL_ATTR_TYPE_NESTED
NL_ATTR_TYPE_NESTED_ARRAY = C.NL_ATTR_TYPE_NESTED_ARRAY
NL_ATTR_TYPE_BITFIELD32 = C.NL_ATTR_TYPE_BITFIELD32
// NL_POLICY_TYPE_ATTR_* names.
NL_POLICY_TYPE_ATTR_UNSPEC = C.NL_POLICY_TYPE_ATTR_UNSPEC
NL_POLICY_TYPE_ATTR_TYPE = C.NL_POLICY_TYPE_ATTR_TYPE
NL_POLICY_TYPE_ATTR_MIN_VALUE_S = C.NL_POLICY_TYPE_ATTR_MIN_VALUE_S
NL_POLICY_TYPE_ATTR_MAX_VALUE_S = C.NL_POLICY_TYPE_ATTR_MAX_VALUE_S
NL_POLICY_TYPE_ATTR_MIN_VALUE_U = C.NL_POLICY_TYPE_ATTR_MIN_VALUE_U
NL_POLICY_TYPE_ATTR_MAX_VALUE_U = C.NL_POLICY_TYPE_ATTR_MAX_VALUE_U
NL_POLICY_TYPE_ATTR_MIN_LENGTH = C.NL_POLICY_TYPE_ATTR_MIN_LENGTH
NL_POLICY_TYPE_ATTR_MAX_LENGTH = C.NL_POLICY_TYPE_ATTR_MAX_LENGTH
NL_POLICY_TYPE_ATTR_POLICY_IDX = C.NL_POLICY_TYPE_ATTR_POLICY_IDX
NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = C.NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE
NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = C.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK
NL_POLICY_TYPE_ATTR_PAD = C.NL_POLICY_TYPE_ATTR_PAD
NL_POLICY_TYPE_ATTR_MASK = C.NL_POLICY_TYPE_ATTR_MASK
NL_POLICY_TYPE_ATTR_MAX = C.NL_POLICY_TYPE_ATTR_MAX
)
// CAN netlink types
//
// CANBitTiming is the Go declaration of C struct can_bittiming.
type CANBitTiming C.struct_can_bittiming
// CANBitTimingConst is the Go declaration of C struct my_can_bittiming_const.
// NOTE(review): the "my_" prefix suggests this is a locally defined stand-in
// for the kernel's struct can_bittiming_const — confirm against the cgo
// preamble of this file.
type CANBitTimingConst C.struct_my_can_bittiming_const
// CANClock is the Go declaration of C struct can_clock.
type CANClock C.struct_can_clock
// CANBusErrorCounters is the Go declaration of C struct can_berr_counter.
type CANBusErrorCounters C.struct_can_berr_counter
// CANCtrlMode is the Go declaration of C struct can_ctrlmode.
type CANCtrlMode C.struct_can_ctrlmode
// CANDeviceStats is the Go declaration of C struct can_device_stats.
type CANDeviceStats C.struct_can_device_stats
// CAN_STATE_* constants, each bound to the identically named C constant.
// Generated by:
// perl -nlE '/^\s*(CAN_STATE\w+)/ && say "$1 = C.$1"' /usr/include/linux/can/netlink.h
const (
CAN_STATE_ERROR_ACTIVE = C.CAN_STATE_ERROR_ACTIVE
CAN_STATE_ERROR_WARNING = C.CAN_STATE_ERROR_WARNING
CAN_STATE_ERROR_PASSIVE = C.CAN_STATE_ERROR_PASSIVE
CAN_STATE_BUS_OFF = C.CAN_STATE_BUS_OFF
CAN_STATE_STOPPED = C.CAN_STATE_STOPPED
CAN_STATE_SLEEPING = C.CAN_STATE_SLEEPING
CAN_STATE_MAX = C.CAN_STATE_MAX
)
// IFLA_CAN_* constants, each bound to the identically named C constant.
// Generated by:
// perl -nlE '/^\s*(IFLA_CAN\w+)/ && say "$1 = C.$1"' /usr/include/linux/can/netlink.h
const (
IFLA_CAN_UNSPEC = C.IFLA_CAN_UNSPEC
IFLA_CAN_BITTIMING = C.IFLA_CAN_BITTIMING
IFLA_CAN_BITTIMING_CONST = C.IFLA_CAN_BITTIMING_CONST
IFLA_CAN_CLOCK = C.IFLA_CAN_CLOCK
IFLA_CAN_STATE = C.IFLA_CAN_STATE
IFLA_CAN_CTRLMODE = C.IFLA_CAN_CTRLMODE
IFLA_CAN_RESTART_MS = C.IFLA_CAN_RESTART_MS
IFLA_CAN_RESTART = C.IFLA_CAN_RESTART
IFLA_CAN_BERR_COUNTER = C.IFLA_CAN_BERR_COUNTER
IFLA_CAN_DATA_BITTIMING = C.IFLA_CAN_DATA_BITTIMING
IFLA_CAN_DATA_BITTIMING_CONST = C.IFLA_CAN_DATA_BITTIMING_CONST
IFLA_CAN_TERMINATION = C.IFLA_CAN_TERMINATION
IFLA_CAN_TERMINATION_CONST = C.IFLA_CAN_TERMINATION_CONST
IFLA_CAN_BITRATE_CONST = C.IFLA_CAN_BITRATE_CONST
IFLA_CAN_DATA_BITRATE_CONST = C.IFLA_CAN_DATA_BITRATE_CONST
IFLA_CAN_BITRATE_MAX = C.IFLA_CAN_BITRATE_MAX
)
// Kernel connection multiplexor
//
// KCMAttach is the Go declaration of C struct kcm_attach.
type KCMAttach C.struct_kcm_attach
// KCMUnattach is the Go declaration of C struct kcm_unattach.
type KCMUnattach C.struct_kcm_unattach
// KCMClone is the Go declaration of C struct kcm_clone.
type KCMClone C.struct_kcm_clone
// nl80211 generic netlink
// Generated by the following steps:
//
// perl -nlE '/^#define (NL80211_\w+)/ && say "$1 = C.$1"' /usr/include/linux/nl80211.h > define.txt
// perl -nlE '/^\s*(NL80211_\w+)/ && say "$1 = C.$1"' /usr/include/linux/nl80211.h > enum.txt
// cat define.txt enum.txt | sort --unique
//
// nl80211 has some duplicated values between enums and C preprocessor macros,
// so this list must be sorted and made unique.
const (
NL80211_AC_BE = C.NL80211_AC_BE
NL80211_AC_BK = C.NL80211_AC_BK
NL80211_ACL_POLICY_ACCEPT_UNLESS_LISTED = C.NL80211_ACL_POLICY_ACCEPT_UNLESS_LISTED
NL80211_ACL_POLICY_DENY_UNLESS_LISTED = C.NL80211_ACL_POLICY_DENY_UNLESS_LISTED
NL80211_AC_VI = C.NL80211_AC_VI
NL80211_AC_VO = C.NL80211_AC_VO
NL80211_AP_SETTINGS_EXTERNAL_AUTH_SUPPORT = C.NL80211_AP_SETTINGS_EXTERNAL_AUTH_SUPPORT
NL80211_AP_SETTINGS_SA_QUERY_OFFLOAD_SUPPORT = C.NL80211_AP_SETTINGS_SA_QUERY_OFFLOAD_SUPPORT
NL80211_AP_SME_SA_QUERY_OFFLOAD = C.NL80211_AP_SME_SA_QUERY_OFFLOAD
NL80211_ATTR_4ADDR = C.NL80211_ATTR_4ADDR
NL80211_ATTR_ACK = C.NL80211_ATTR_ACK
NL80211_ATTR_ACK_SIGNAL = C.NL80211_ATTR_ACK_SIGNAL
NL80211_ATTR_ACL_POLICY = C.NL80211_ATTR_ACL_POLICY
NL80211_ATTR_ADMITTED_TIME = C.NL80211_ATTR_ADMITTED_TIME
NL80211_ATTR_AIRTIME_WEIGHT = C.NL80211_ATTR_AIRTIME_WEIGHT
NL80211_ATTR_AKM_SUITES = C.NL80211_ATTR_AKM_SUITES
NL80211_ATTR_AP_ISOLATE = C.NL80211_ATTR_AP_ISOLATE
NL80211_ATTR_AP_SETTINGS_FLAGS = C.NL80211_ATTR_AP_SETTINGS_FLAGS
NL80211_ATTR_AUTH_DATA = C.NL80211_ATTR_AUTH_DATA
NL80211_ATTR_AUTH_TYPE = C.NL80211_ATTR_AUTH_TYPE
NL80211_ATTR_BANDS = C.NL80211_ATTR_BANDS
NL80211_ATTR_BEACON_HEAD = C.NL80211_ATTR_BEACON_HEAD
NL80211_ATTR_BEACON_INTERVAL = C.NL80211_ATTR_BEACON_INTERVAL
NL80211_ATTR_BEACON_TAIL = C.NL80211_ATTR_BEACON_TAIL
NL80211_ATTR_BG_SCAN_PERIOD = C.NL80211_ATTR_BG_SCAN_PERIOD
NL80211_ATTR_BSS_BASIC_RATES = C.NL80211_ATTR_BSS_BASIC_RATES
NL80211_ATTR_BSS = C.NL80211_ATTR_BSS
NL80211_ATTR_BSS_CTS_PROT = C.NL80211_ATTR_BSS_CTS_PROT
NL80211_ATTR_BSS_HT_OPMODE = C.NL80211_ATTR_BSS_HT_OPMODE
NL80211_ATTR_BSSID = C.NL80211_ATTR_BSSID
NL80211_ATTR_BSS_SELECT = C.NL80211_ATTR_BSS_SELECT
NL80211_ATTR_BSS_SHORT_PREAMBLE = C.NL80211_ATTR_BSS_SHORT_PREAMBLE
NL80211_ATTR_BSS_SHORT_SLOT_TIME = C.NL80211_ATTR_BSS_SHORT_SLOT_TIME
NL80211_ATTR_CENTER_FREQ1 = C.NL80211_ATTR_CENTER_FREQ1
NL80211_ATTR_CENTER_FREQ1_OFFSET = C.NL80211_ATTR_CENTER_FREQ1_OFFSET
NL80211_ATTR_CENTER_FREQ2 = C.NL80211_ATTR_CENTER_FREQ2
NL80211_ATTR_CHANNEL_WIDTH = C.NL80211_ATTR_CHANNEL_WIDTH
NL80211_ATTR_CH_SWITCH_BLOCK_TX = C.NL80211_ATTR_CH_SWITCH_BLOCK_TX
NL80211_ATTR_CH_SWITCH_COUNT = C.NL80211_ATTR_CH_SWITCH_COUNT
NL80211_ATTR_CIPHER_SUITE_GROUP = C.NL80211_ATTR_CIPHER_SUITE_GROUP
NL80211_ATTR_CIPHER_SUITES = C.NL80211_ATTR_CIPHER_SUITES
NL80211_ATTR_CIPHER_SUITES_PAIRWISE = C.NL80211_ATTR_CIPHER_SUITES_PAIRWISE
NL80211_ATTR_CNTDWN_OFFS_BEACON = C.NL80211_ATTR_CNTDWN_OFFS_BEACON
NL80211_ATTR_CNTDWN_OFFS_PRESP = C.NL80211_ATTR_CNTDWN_OFFS_PRESP
NL80211_ATTR_COALESCE_RULE = C.NL80211_ATTR_COALESCE_RULE
NL80211_ATTR_COALESCE_RULE_CONDITION = C.NL80211_ATTR_COALESCE_RULE_CONDITION
NL80211_ATTR_COALESCE_RULE_DELAY = C.NL80211_ATTR_COALESCE_RULE_DELAY
NL80211_ATTR_COALESCE_RULE_MAX = C.NL80211_ATTR_COALESCE_RULE_MAX
NL80211_ATTR_COALESCE_RULE_PKT_PATTERN = C.NL80211_ATTR_COALESCE_RULE_PKT_PATTERN
NL80211_ATTR_COLOR_CHANGE_COLOR = C.NL80211_ATTR_COLOR_CHANGE_COLOR
NL80211_ATTR_COLOR_CHANGE_COUNT = C.NL80211_ATTR_COLOR_CHANGE_COUNT
NL80211_ATTR_COLOR_CHANGE_ELEMS = C.NL80211_ATTR_COLOR_CHANGE_ELEMS
NL80211_ATTR_CONN_FAILED_REASON = C.NL80211_ATTR_CONN_FAILED_REASON
NL80211_ATTR_CONTROL_PORT = C.NL80211_ATTR_CONTROL_PORT
NL80211_ATTR_CONTROL_PORT_ETHERTYPE = C.NL80211_ATTR_CONTROL_PORT_ETHERTYPE
NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT = C.NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT
NL80211_ATTR_CONTROL_PORT_NO_PREAUTH = C.NL80211_ATTR_CONTROL_PORT_NO_PREAUTH
NL80211_ATTR_CONTROL_PORT_OVER_NL80211 = C.NL80211_ATTR_CONTROL_PORT_OVER_NL80211
NL80211_ATTR_COOKIE = C.NL80211_ATTR_COOKIE
NL80211_ATTR_CQM_BEACON_LOSS_EVENT = C.NL80211_ATTR_CQM_BEACON_LOSS_EVENT
NL80211_ATTR_CQM = C.NL80211_ATTR_CQM
NL80211_ATTR_CQM_MAX = C.NL80211_ATTR_CQM_MAX
NL80211_ATTR_CQM_PKT_LOSS_EVENT = C.NL80211_ATTR_CQM_PKT_LOSS_EVENT
NL80211_ATTR_CQM_RSSI_HYST = C.NL80211_ATTR_CQM_RSSI_HYST
NL80211_ATTR_CQM_RSSI_LEVEL = C.NL80211_ATTR_CQM_RSSI_LEVEL
NL80211_ATTR_CQM_RSSI_THOLD = C.NL80211_ATTR_CQM_RSSI_THOLD
NL80211_ATTR_CQM_RSSI_THRESHOLD_EVENT = C.NL80211_ATTR_CQM_RSSI_THRESHOLD_EVENT
NL80211_ATTR_CQM_TXE_INTVL = C.NL80211_ATTR_CQM_TXE_INTVL
NL80211_ATTR_CQM_TXE_PKTS = C.NL80211_ATTR_CQM_TXE_PKTS
NL80211_ATTR_CQM_TXE_RATE = C.NL80211_ATTR_CQM_TXE_RATE
NL80211_ATTR_CRIT_PROT_ID = C.NL80211_ATTR_CRIT_PROT_ID
NL80211_ATTR_CSA_C_OFF_BEACON = C.NL80211_ATTR_CSA_C_OFF_BEACON
NL80211_ATTR_CSA_C_OFF_PRESP = C.NL80211_ATTR_CSA_C_OFF_PRESP
NL80211_ATTR_CSA_C_OFFSETS_TX = C.NL80211_ATTR_CSA_C_OFFSETS_TX
NL80211_ATTR_CSA_IES = C.NL80211_ATTR_CSA_IES
NL80211_ATTR_DEVICE_AP_SME = C.NL80211_ATTR_DEVICE_AP_SME
NL80211_ATTR_DFS_CAC_TIME = C.NL80211_ATTR_DFS_CAC_TIME
NL80211_ATTR_DFS_REGION = C.NL80211_ATTR_DFS_REGION
NL80211_ATTR_DISABLE_EHT = C.NL80211_ATTR_DISABLE_EHT
NL80211_ATTR_DISABLE_HE = C.NL80211_ATTR_DISABLE_HE
NL80211_ATTR_DISABLE_HT = C.NL80211_ATTR_DISABLE_HT
NL80211_ATTR_DISABLE_VHT = C.NL80211_ATTR_DISABLE_VHT
NL80211_ATTR_DISCONNECTED_BY_AP = C.NL80211_ATTR_DISCONNECTED_BY_AP
NL80211_ATTR_DONT_WAIT_FOR_ACK = C.NL80211_ATTR_DONT_WAIT_FOR_ACK
NL80211_ATTR_DTIM_PERIOD = C.NL80211_ATTR_DTIM_PERIOD
NL80211_ATTR_DURATION = C.NL80211_ATTR_DURATION
NL80211_ATTR_EHT_CAPABILITY = C.NL80211_ATTR_EHT_CAPABILITY
NL80211_ATTR_EML_CAPABILITY = C.NL80211_ATTR_EML_CAPABILITY
NL80211_ATTR_EXT_CAPA = C.NL80211_ATTR_EXT_CAPA
NL80211_ATTR_EXT_CAPA_MASK = C.NL80211_ATTR_EXT_CAPA_MASK
NL80211_ATTR_EXTERNAL_AUTH_ACTION = C.NL80211_ATTR_EXTERNAL_AUTH_ACTION
NL80211_ATTR_EXTERNAL_AUTH_SUPPORT = C.NL80211_ATTR_EXTERNAL_AUTH_SUPPORT
NL80211_ATTR_EXT_FEATURES = C.NL80211_ATTR_EXT_FEATURES
NL80211_ATTR_FEATURE_FLAGS = C.NL80211_ATTR_FEATURE_FLAGS
NL80211_ATTR_FILS_CACHE_ID = C.NL80211_ATTR_FILS_CACHE_ID
NL80211_ATTR_FILS_DISCOVERY = C.NL80211_ATTR_FILS_DISCOVERY
NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM = C.NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM
NL80211_ATTR_FILS_ERP_REALM = C.NL80211_ATTR_FILS_ERP_REALM
NL80211_ATTR_FILS_ERP_RRK = C.NL80211_ATTR_FILS_ERP_RRK
NL80211_ATTR_FILS_ERP_USERNAME = C.NL80211_ATTR_FILS_ERP_USERNAME
NL80211_ATTR_FILS_KEK = C.NL80211_ATTR_FILS_KEK
NL80211_ATTR_FILS_NONCES = C.NL80211_ATTR_FILS_NONCES
NL80211_ATTR_FRAME = C.NL80211_ATTR_FRAME
NL80211_ATTR_FRAME_MATCH = C.NL80211_ATTR_FRAME_MATCH
NL80211_ATTR_FRAME_TYPE = C.NL80211_ATTR_FRAME_TYPE
NL80211_ATTR_FREQ_AFTER = C.NL80211_ATTR_FREQ_AFTER
NL80211_ATTR_FREQ_BEFORE = C.NL80211_ATTR_FREQ_BEFORE
NL80211_ATTR_FREQ_FIXED = C.NL80211_ATTR_FREQ_FIXED
NL80211_ATTR_FREQ_RANGE_END = C.NL80211_ATTR_FREQ_RANGE_END
NL80211_ATTR_FREQ_RANGE_MAX_BW = C.NL80211_ATTR_FREQ_RANGE_MAX_BW
NL80211_ATTR_FREQ_RANGE_START = C.NL80211_ATTR_FREQ_RANGE_START
NL80211_ATTR_FTM_RESPONDER = C.NL80211_ATTR_FTM_RESPONDER
NL80211_ATTR_FTM_RESPONDER_STATS = C.NL80211_ATTR_FTM_RESPONDER_STATS
NL80211_ATTR_GENERATION = C.NL80211_ATTR_GENERATION
NL80211_ATTR_HANDLE_DFS = C.NL80211_ATTR_HANDLE_DFS
NL80211_ATTR_HE_6GHZ_CAPABILITY = C.NL80211_ATTR_HE_6GHZ_CAPABILITY
NL80211_ATTR_HE_BSS_COLOR = C.NL80211_ATTR_HE_BSS_COLOR
NL80211_ATTR_HE_CAPABILITY = C.NL80211_ATTR_HE_CAPABILITY
NL80211_ATTR_HE_OBSS_PD = C.NL80211_ATTR_HE_OBSS_PD
NL80211_ATTR_HIDDEN_SSID = C.NL80211_ATTR_HIDDEN_SSID
NL80211_ATTR_HT_CAPABILITY = C.NL80211_ATTR_HT_CAPABILITY
NL80211_ATTR_HT_CAPABILITY_MASK = C.NL80211_ATTR_HT_CAPABILITY_MASK
NL80211_ATTR_IE_ASSOC_RESP = C.NL80211_ATTR_IE_ASSOC_RESP
NL80211_ATTR_IE = C.NL80211_ATTR_IE
NL80211_ATTR_IE_PROBE_RESP = C.NL80211_ATTR_IE_PROBE_RESP
NL80211_ATTR_IE_RIC = C.NL80211_ATTR_IE_RIC
NL80211_ATTR_IFACE_SOCKET_OWNER = C.NL80211_ATTR_IFACE_SOCKET_OWNER
NL80211_ATTR_IFINDEX = C.NL80211_ATTR_IFINDEX
NL80211_ATTR_IFNAME = C.NL80211_ATTR_IFNAME
NL80211_ATTR_IFTYPE_AKM_SUITES = C.NL80211_ATTR_IFTYPE_AKM_SUITES
NL80211_ATTR_IFTYPE = C.NL80211_ATTR_IFTYPE
NL80211_ATTR_IFTYPE_EXT_CAPA = C.NL80211_ATTR_IFTYPE_EXT_CAPA
NL80211_ATTR_INACTIVITY_TIMEOUT = C.NL80211_ATTR_INACTIVITY_TIMEOUT
NL80211_ATTR_INTERFACE_COMBINATIONS = C.NL80211_ATTR_INTERFACE_COMBINATIONS
NL80211_ATTR_KEY_CIPHER = C.NL80211_ATTR_KEY_CIPHER
NL80211_ATTR_KEY = C.NL80211_ATTR_KEY
NL80211_ATTR_KEY_DATA = C.NL80211_ATTR_KEY_DATA
NL80211_ATTR_KEY_DEFAULT = C.NL80211_ATTR_KEY_DEFAULT
NL80211_ATTR_KEY_DEFAULT_MGMT = C.NL80211_ATTR_KEY_DEFAULT_MGMT
NL80211_ATTR_KEY_DEFAULT_TYPES = C.NL80211_ATTR_KEY_DEFAULT_TYPES
NL80211_ATTR_KEY_IDX = C.NL80211_ATTR_KEY_IDX
NL80211_ATTR_KEYS = C.NL80211_ATTR_KEYS
NL80211_ATTR_KEY_SEQ = C.NL80211_ATTR_KEY_SEQ
NL80211_ATTR_KEY_TYPE = C.NL80211_ATTR_KEY_TYPE
NL80211_ATTR_LOCAL_MESH_POWER_MODE = C.NL80211_ATTR_LOCAL_MESH_POWER_MODE
NL80211_ATTR_LOCAL_STATE_CHANGE = C.NL80211_ATTR_LOCAL_STATE_CHANGE
NL80211_ATTR_MAC_ACL_MAX = C.NL80211_ATTR_MAC_ACL_MAX
NL80211_ATTR_MAC_ADDRS = C.NL80211_ATTR_MAC_ADDRS
NL80211_ATTR_MAC = C.NL80211_ATTR_MAC
NL80211_ATTR_MAC_HINT = C.NL80211_ATTR_MAC_HINT
NL80211_ATTR_MAC_MASK = C.NL80211_ATTR_MAC_MASK
NL80211_ATTR_MAX_AP_ASSOC_STA = C.NL80211_ATTR_MAX_AP_ASSOC_STA
NL80211_ATTR_MAX = C.NL80211_ATTR_MAX
NL80211_ATTR_MAX_CRIT_PROT_DURATION = C.NL80211_ATTR_MAX_CRIT_PROT_DURATION
NL80211_ATTR_MAX_CSA_COUNTERS = C.NL80211_ATTR_MAX_CSA_COUNTERS
NL80211_ATTR_MAX_MATCH_SETS = C.NL80211_ATTR_MAX_MATCH_SETS
NL80211_ATTR_MAX_NUM_AKM_SUITES = C.NL80211_ATTR_MAX_NUM_AKM_SUITES
NL80211_ATTR_MAX_NUM_PMKIDS = C.NL80211_ATTR_MAX_NUM_PMKIDS
NL80211_ATTR_MAX_NUM_SCAN_SSIDS = C.NL80211_ATTR_MAX_NUM_SCAN_SSIDS
NL80211_ATTR_MAX_NUM_SCHED_SCAN_PLANS = C.NL80211_ATTR_MAX_NUM_SCHED_SCAN_PLANS
NL80211_ATTR_MAX_NUM_SCHED_SCAN_SSIDS = C.NL80211_ATTR_MAX_NUM_SCHED_SCAN_SSIDS
NL80211_ATTR_MAX_REMAIN_ON_CHANNEL_DURATION = C.NL80211_ATTR_MAX_REMAIN_ON_CHANNEL_DURATION
NL80211_ATTR_MAX_SCAN_IE_LEN = C.NL80211_ATTR_MAX_SCAN_IE_LEN
NL80211_ATTR_MAX_SCAN_PLAN_INTERVAL = C.NL80211_ATTR_MAX_SCAN_PLAN_INTERVAL
NL80211_ATTR_MAX_SCAN_PLAN_ITERATIONS = C.NL80211_ATTR_MAX_SCAN_PLAN_ITERATIONS
NL80211_ATTR_MAX_SCHED_SCAN_IE_LEN = C.NL80211_ATTR_MAX_SCHED_SCAN_IE_LEN
NL80211_ATTR_MBSSID_CONFIG = C.NL80211_ATTR_MBSSID_CONFIG
NL80211_ATTR_MBSSID_ELEMS = C.NL80211_ATTR_MBSSID_ELEMS
NL80211_ATTR_MCAST_RATE = C.NL80211_ATTR_MCAST_RATE
NL80211_ATTR_MDID = C.NL80211_ATTR_MDID
NL80211_ATTR_MEASUREMENT_DURATION = C.NL80211_ATTR_MEASUREMENT_DURATION
NL80211_ATTR_MEASUREMENT_DURATION_MANDATORY = C.NL80211_ATTR_MEASUREMENT_DURATION_MANDATORY
NL80211_ATTR_MESH_CONFIG = C.NL80211_ATTR_MESH_CONFIG
NL80211_ATTR_MESH_ID = C.NL80211_ATTR_MESH_ID
NL80211_ATTR_MESH_PEER_AID = C.NL80211_ATTR_MESH_PEER_AID
NL80211_ATTR_MESH_SETUP = C.NL80211_ATTR_MESH_SETUP
NL80211_ATTR_MGMT_SUBTYPE = C.NL80211_ATTR_MGMT_SUBTYPE
NL80211_ATTR_MLD_ADDR = C.NL80211_ATTR_MLD_ADDR
NL80211_ATTR_MLD_CAPA_AND_OPS = C.NL80211_ATTR_MLD_CAPA_AND_OPS
NL80211_ATTR_MLO_LINK_ID = C.NL80211_ATTR_MLO_LINK_ID
NL80211_ATTR_MLO_LINKS = C.NL80211_ATTR_MLO_LINKS
NL80211_ATTR_MLO_SUPPORT = C.NL80211_ATTR_MLO_SUPPORT
NL80211_ATTR_MNTR_FLAGS = C.NL80211_ATTR_MNTR_FLAGS
NL80211_ATTR_MPATH_INFO = C.NL80211_ATTR_MPATH_INFO
NL80211_ATTR_MPATH_NEXT_HOP = C.NL80211_ATTR_MPATH_NEXT_HOP
NL80211_ATTR_MULTICAST_TO_UNICAST_ENABLED = C.NL80211_ATTR_MULTICAST_TO_UNICAST_ENABLED
NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR = C.NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR
NL80211_ATTR_MU_MIMO_GROUP_DATA = C.NL80211_ATTR_MU_MIMO_GROUP_DATA
NL80211_ATTR_NAN_FUNC = C.NL80211_ATTR_NAN_FUNC
NL80211_ATTR_NAN_MASTER_PREF = C.NL80211_ATTR_NAN_MASTER_PREF
NL80211_ATTR_NAN_MATCH = C.NL80211_ATTR_NAN_MATCH
NL80211_ATTR_NETNS_FD = C.NL80211_ATTR_NETNS_FD
NL80211_ATTR_NOACK_MAP = C.NL80211_ATTR_NOACK_MAP
NL80211_ATTR_NSS = C.NL80211_ATTR_NSS
NL80211_ATTR_OBSS_COLOR_BITMAP = C.NL80211_ATTR_OBSS_COLOR_BITMAP
NL80211_ATTR_OFFCHANNEL_TX_OK = C.NL80211_ATTR_OFFCHANNEL_TX_OK
NL80211_ATTR_OPER_CLASS = C.NL80211_ATTR_OPER_CLASS
NL80211_ATTR_OPMODE_NOTIF = C.NL80211_ATTR_OPMODE_NOTIF
NL80211_ATTR_P2P_CTWINDOW = C.NL80211_ATTR_P2P_CTWINDOW
NL80211_ATTR_P2P_OPPPS = C.NL80211_ATTR_P2P_OPPPS
NL80211_ATTR_PAD = C.NL80211_ATTR_PAD
NL80211_ATTR_PBSS = C.NL80211_ATTR_PBSS
NL80211_ATTR_PEER_AID = C.NL80211_ATTR_PEER_AID
NL80211_ATTR_PEER_MEASUREMENTS = C.NL80211_ATTR_PEER_MEASUREMENTS
NL80211_ATTR_PID = C.NL80211_ATTR_PID
NL80211_ATTR_PMK = C.NL80211_ATTR_PMK
NL80211_ATTR_PMKID = C.NL80211_ATTR_PMKID
NL80211_ATTR_PMK_LIFETIME = C.NL80211_ATTR_PMK_LIFETIME
NL80211_ATTR_PMKR0_NAME = C.NL80211_ATTR_PMKR0_NAME
NL80211_ATTR_PMK_REAUTH_THRESHOLD = C.NL80211_ATTR_PMK_REAUTH_THRESHOLD
NL80211_ATTR_PMKSA_CANDIDATE = C.NL80211_ATTR_PMKSA_CANDIDATE
NL80211_ATTR_PORT_AUTHORIZED = C.NL80211_ATTR_PORT_AUTHORIZED
NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN = C.NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN
NL80211_ATTR_POWER_RULE_MAX_EIRP = C.NL80211_ATTR_POWER_RULE_MAX_EIRP
NL80211_ATTR_PREV_BSSID = C.NL80211_ATTR_PREV_BSSID
NL80211_ATTR_PRIVACY = C.NL80211_ATTR_PRIVACY
NL80211_ATTR_PROBE_RESP = C.NL80211_ATTR_PROBE_RESP
NL80211_ATTR_PROBE_RESP_OFFLOAD = C.NL80211_ATTR_PROBE_RESP_OFFLOAD
NL80211_ATTR_PROTOCOL_FEATURES = C.NL80211_ATTR_PROTOCOL_FEATURES
NL80211_ATTR_PS_STATE = C.NL80211_ATTR_PS_STATE
NL80211_ATTR_QOS_MAP = C.NL80211_ATTR_QOS_MAP
NL80211_ATTR_RADAR_BACKGROUND = C.NL80211_ATTR_RADAR_BACKGROUND
NL80211_ATTR_RADAR_EVENT = C.NL80211_ATTR_RADAR_EVENT
NL80211_ATTR_REASON_CODE = C.NL80211_ATTR_REASON_CODE
NL80211_ATTR_RECEIVE_MULTICAST = C.NL80211_ATTR_RECEIVE_MULTICAST
NL80211_ATTR_RECONNECT_REQUESTED = C.NL80211_ATTR_RECONNECT_REQUESTED
NL80211_ATTR_REG_ALPHA2 = C.NL80211_ATTR_REG_ALPHA2
NL80211_ATTR_REG_INDOOR = C.NL80211_ATTR_REG_INDOOR
NL80211_ATTR_REG_INITIATOR = C.NL80211_ATTR_REG_INITIATOR
NL80211_ATTR_REG_RULE_FLAGS = C.NL80211_ATTR_REG_RULE_FLAGS
NL80211_ATTR_REG_RULES = C.NL80211_ATTR_REG_RULES
NL80211_ATTR_REG_TYPE = C.NL80211_ATTR_REG_TYPE
NL80211_ATTR_REKEY_DATA = C.NL80211_ATTR_REKEY_DATA
NL80211_ATTR_REQ_IE = C.NL80211_ATTR_REQ_IE
NL80211_ATTR_RESP_IE = C.NL80211_ATTR_RESP_IE
NL80211_ATTR_ROAM_SUPPORT = C.NL80211_ATTR_ROAM_SUPPORT
NL80211_ATTR_RX_FRAME_TYPES = C.NL80211_ATTR_RX_FRAME_TYPES
NL80211_ATTR_RX_HW_TIMESTAMP = C.NL80211_ATTR_RX_HW_TIMESTAMP
NL80211_ATTR_RXMGMT_FLAGS = C.NL80211_ATTR_RXMGMT_FLAGS
NL80211_ATTR_RX_SIGNAL_DBM = C.NL80211_ATTR_RX_SIGNAL_DBM
NL80211_ATTR_S1G_CAPABILITY = C.NL80211_ATTR_S1G_CAPABILITY
NL80211_ATTR_S1G_CAPABILITY_MASK = C.NL80211_ATTR_S1G_CAPABILITY_MASK
NL80211_ATTR_SAE_DATA = C.NL80211_ATTR_SAE_DATA
NL80211_ATTR_SAE_PASSWORD = C.NL80211_ATTR_SAE_PASSWORD
NL80211_ATTR_SAE_PWE = C.NL80211_ATTR_SAE_PWE
NL80211_ATTR_SAR_SPEC = C.NL80211_ATTR_SAR_SPEC
NL80211_ATTR_SCAN_FLAGS = C.NL80211_ATTR_SCAN_FLAGS
NL80211_ATTR_SCAN_FREQ_KHZ = C.NL80211_ATTR_SCAN_FREQ_KHZ
NL80211_ATTR_SCAN_FREQUENCIES = C.NL80211_ATTR_SCAN_FREQUENCIES
NL80211_ATTR_SCAN_GENERATION = C.NL80211_ATTR_SCAN_GENERATION
NL80211_ATTR_SCAN_SSIDS = C.NL80211_ATTR_SCAN_SSIDS
NL80211_ATTR_SCAN_START_TIME_TSF_BSSID = C.NL80211_ATTR_SCAN_START_TIME_TSF_BSSID
NL80211_ATTR_SCAN_START_TIME_TSF = C.NL80211_ATTR_SCAN_START_TIME_TSF
NL80211_ATTR_SCAN_SUPP_RATES = C.NL80211_ATTR_SCAN_SUPP_RATES
NL80211_ATTR_SCHED_SCAN_DELAY = C.NL80211_ATTR_SCHED_SCAN_DELAY
NL80211_ATTR_SCHED_SCAN_INTERVAL = C.NL80211_ATTR_SCHED_SCAN_INTERVAL
NL80211_ATTR_SCHED_SCAN_MATCH = C.NL80211_ATTR_SCHED_SCAN_MATCH
NL80211_ATTR_SCHED_SCAN_MATCH_SSID = C.NL80211_ATTR_SCHED_SCAN_MATCH_SSID
NL80211_ATTR_SCHED_SCAN_MAX_REQS = C.NL80211_ATTR_SCHED_SCAN_MAX_REQS
NL80211_ATTR_SCHED_SCAN_MULTI = C.NL80211_ATTR_SCHED_SCAN_MULTI
NL80211_ATTR_SCHED_SCAN_PLANS = C.NL80211_ATTR_SCHED_SCAN_PLANS
NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI = C.NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI
NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST = C.NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST
NL80211_ATTR_SMPS_MODE = C.NL80211_ATTR_SMPS_MODE
NL80211_ATTR_SOCKET_OWNER = C.NL80211_ATTR_SOCKET_OWNER
NL80211_ATTR_SOFTWARE_IFTYPES = C.NL80211_ATTR_SOFTWARE_IFTYPES
NL80211_ATTR_SPLIT_WIPHY_DUMP = C.NL80211_ATTR_SPLIT_WIPHY_DUMP
NL80211_ATTR_SSID = C.NL80211_ATTR_SSID
NL80211_ATTR_STA_AID = C.NL80211_ATTR_STA_AID
NL80211_ATTR_STA_CAPABILITY = C.NL80211_ATTR_STA_CAPABILITY
NL80211_ATTR_STA_EXT_CAPABILITY = C.NL80211_ATTR_STA_EXT_CAPABILITY
NL80211_ATTR_STA_FLAGS2 = C.NL80211_ATTR_STA_FLAGS2
NL80211_ATTR_STA_FLAGS = C.NL80211_ATTR_STA_FLAGS
NL80211_ATTR_STA_INFO = C.NL80211_ATTR_STA_INFO
NL80211_ATTR_STA_LISTEN_INTERVAL = C.NL80211_ATTR_STA_LISTEN_INTERVAL
NL80211_ATTR_STA_PLINK_ACTION = C.NL80211_ATTR_STA_PLINK_ACTION
NL80211_ATTR_STA_PLINK_STATE = C.NL80211_ATTR_STA_PLINK_STATE
NL80211_ATTR_STA_SUPPORTED_CHANNELS = C.NL80211_ATTR_STA_SUPPORTED_CHANNELS
NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES = C.NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES
NL80211_ATTR_STA_SUPPORTED_RATES = C.NL80211_ATTR_STA_SUPPORTED_RATES
NL80211_ATTR_STA_SUPPORT_P2P_PS = C.NL80211_ATTR_STA_SUPPORT_P2P_PS
NL80211_ATTR_STATUS_CODE = C.NL80211_ATTR_STATUS_CODE
NL80211_ATTR_STA_TX_POWER = C.NL80211_ATTR_STA_TX_POWER
NL80211_ATTR_STA_TX_POWER_SETTING = C.NL80211_ATTR_STA_TX_POWER_SETTING
NL80211_ATTR_STA_VLAN = C.NL80211_ATTR_STA_VLAN
NL80211_ATTR_STA_WME = C.NL80211_ATTR_STA_WME
NL80211_ATTR_SUPPORT_10_MHZ = C.NL80211_ATTR_SUPPORT_10_MHZ
NL80211_ATTR_SUPPORT_5_MHZ = C.NL80211_ATTR_SUPPORT_5_MHZ
NL80211_ATTR_SUPPORT_AP_UAPSD = C.NL80211_ATTR_SUPPORT_AP_UAPSD
NL80211_ATTR_SUPPORTED_COMMANDS = C.NL80211_ATTR_SUPPORTED_COMMANDS
NL80211_ATTR_SUPPORTED_IFTYPES = C.NL80211_ATTR_SUPPORTED_IFTYPES
NL80211_ATTR_SUPPORT_IBSS_RSN = C.NL80211_ATTR_SUPPORT_IBSS_RSN
NL80211_ATTR_SUPPORT_MESH_AUTH = C.NL80211_ATTR_SUPPORT_MESH_AUTH
NL80211_ATTR_SURVEY_INFO = C.NL80211_ATTR_SURVEY_INFO
NL80211_ATTR_SURVEY_RADIO_STATS = C.NL80211_ATTR_SURVEY_RADIO_STATS
NL80211_ATTR_TD_BITMAP = C.NL80211_ATTR_TD_BITMAP
NL80211_ATTR_TDLS_ACTION = C.NL80211_ATTR_TDLS_ACTION
NL80211_ATTR_TDLS_DIALOG_TOKEN = C.NL80211_ATTR_TDLS_DIALOG_TOKEN
NL80211_ATTR_TDLS_EXTERNAL_SETUP = C.NL80211_ATTR_TDLS_EXTERNAL_SETUP
NL80211_ATTR_TDLS_INITIATOR = C.NL80211_ATTR_TDLS_INITIATOR
NL80211_ATTR_TDLS_OPERATION = C.NL80211_ATTR_TDLS_OPERATION
NL80211_ATTR_TDLS_PEER_CAPABILITY = C.NL80211_ATTR_TDLS_PEER_CAPABILITY
NL80211_ATTR_TDLS_SUPPORT = C.NL80211_ATTR_TDLS_SUPPORT
NL80211_ATTR_TESTDATA = C.NL80211_ATTR_TESTDATA
NL80211_ATTR_TID_CONFIG = C.NL80211_ATTR_TID_CONFIG
NL80211_ATTR_TIMED_OUT = C.NL80211_ATTR_TIMED_OUT
NL80211_ATTR_TIMEOUT = C.NL80211_ATTR_TIMEOUT
NL80211_ATTR_TIMEOUT_REASON = C.NL80211_ATTR_TIMEOUT_REASON
NL80211_ATTR_TSID = C.NL80211_ATTR_TSID
NL80211_ATTR_TWT_RESPONDER = C.NL80211_ATTR_TWT_RESPONDER
NL80211_ATTR_TX_FRAME_TYPES = C.NL80211_ATTR_TX_FRAME_TYPES
NL80211_ATTR_TX_HW_TIMESTAMP = C.NL80211_ATTR_TX_HW_TIMESTAMP
NL80211_ATTR_TX_NO_CCK_RATE = C.NL80211_ATTR_TX_NO_CCK_RATE
NL80211_ATTR_TXQ_LIMIT = C.NL80211_ATTR_TXQ_LIMIT
NL80211_ATTR_TXQ_MEMORY_LIMIT = C.NL80211_ATTR_TXQ_MEMORY_LIMIT
NL80211_ATTR_TXQ_QUANTUM = C.NL80211_ATTR_TXQ_QUANTUM
NL80211_ATTR_TXQ_STATS = C.NL80211_ATTR_TXQ_STATS
NL80211_ATTR_TX_RATES = C.NL80211_ATTR_TX_RATES
NL80211_ATTR_UNSOL_BCAST_PROBE_RESP = C.NL80211_ATTR_UNSOL_BCAST_PROBE_RESP
NL80211_ATTR_UNSPEC = C.NL80211_ATTR_UNSPEC
NL80211_ATTR_USE_MFP = C.NL80211_ATTR_USE_MFP
NL80211_ATTR_USER_PRIO = C.NL80211_ATTR_USER_PRIO
NL80211_ATTR_USER_REG_HINT_TYPE = C.NL80211_ATTR_USER_REG_HINT_TYPE
NL80211_ATTR_USE_RRM = C.NL80211_ATTR_USE_RRM
NL80211_ATTR_VENDOR_DATA = C.NL80211_ATTR_VENDOR_DATA
NL80211_ATTR_VENDOR_EVENTS = C.NL80211_ATTR_VENDOR_EVENTS
NL80211_ATTR_VENDOR_ID = C.NL80211_ATTR_VENDOR_ID
NL80211_ATTR_VENDOR_SUBCMD = C.NL80211_ATTR_VENDOR_SUBCMD
NL80211_ATTR_VHT_CAPABILITY = C.NL80211_ATTR_VHT_CAPABILITY
NL80211_ATTR_VHT_CAPABILITY_MASK = C.NL80211_ATTR_VHT_CAPABILITY_MASK
NL80211_ATTR_VLAN_ID = C.NL80211_ATTR_VLAN_ID
NL80211_ATTR_WANT_1X_4WAY_HS = C.NL80211_ATTR_WANT_1X_4WAY_HS
NL80211_ATTR_WDEV = C.NL80211_ATTR_WDEV
NL80211_ATTR_WIPHY_ANTENNA_AVAIL_RX = C.NL80211_ATTR_WIPHY_ANTENNA_AVAIL_RX
NL80211_ATTR_WIPHY_ANTENNA_AVAIL_TX = C.NL80211_ATTR_WIPHY_ANTENNA_AVAIL_TX
NL80211_ATTR_WIPHY_ANTENNA_RX = C.NL80211_ATTR_WIPHY_ANTENNA_RX
NL80211_ATTR_WIPHY_ANTENNA_TX = C.NL80211_ATTR_WIPHY_ANTENNA_TX
NL80211_ATTR_WIPHY_BANDS = C.NL80211_ATTR_WIPHY_BANDS
NL80211_ATTR_WIPHY_CHANNEL_TYPE = C.NL80211_ATTR_WIPHY_CHANNEL_TYPE
NL80211_ATTR_WIPHY = C.NL80211_ATTR_WIPHY
NL80211_ATTR_WIPHY_COVERAGE_CLASS = C.NL80211_ATTR_WIPHY_COVERAGE_CLASS
NL80211_ATTR_WIPHY_DYN_ACK = C.NL80211_ATTR_WIPHY_DYN_ACK
NL80211_ATTR_WIPHY_EDMG_BW_CONFIG = C.NL80211_ATTR_WIPHY_EDMG_BW_CONFIG
NL80211_ATTR_WIPHY_EDMG_CHANNELS = C.NL80211_ATTR_WIPHY_EDMG_CHANNELS
NL80211_ATTR_WIPHY_FRAG_THRESHOLD = C.NL80211_ATTR_WIPHY_FRAG_THRESHOLD
NL80211_ATTR_WIPHY_FREQ = C.NL80211_ATTR_WIPHY_FREQ
NL80211_ATTR_WIPHY_FREQ_HINT = C.NL80211_ATTR_WIPHY_FREQ_HINT
NL80211_ATTR_WIPHY_FREQ_OFFSET = C.NL80211_ATTR_WIPHY_FREQ_OFFSET
NL80211_ATTR_WIPHY_NAME = C.NL80211_ATTR_WIPHY_NAME
NL80211_ATTR_WIPHY_RETRY_LONG = C.NL80211_ATTR_WIPHY_RETRY_LONG
NL80211_ATTR_WIPHY_RETRY_SHORT = C.NL80211_ATTR_WIPHY_RETRY_SHORT
NL80211_ATTR_WIPHY_RTS_THRESHOLD = C.NL80211_ATTR_WIPHY_RTS_THRESHOLD
NL80211_ATTR_WIPHY_SELF_MANAGED_REG = C.NL80211_ATTR_WIPHY_SELF_MANAGED_REG
NL80211_ATTR_WIPHY_TX_POWER_LEVEL = C.NL80211_ATTR_WIPHY_TX_POWER_LEVEL
NL80211_ATTR_WIPHY_TX_POWER_SETTING = C.NL80211_ATTR_WIPHY_TX_POWER_SETTING
NL80211_ATTR_WIPHY_TXQ_PARAMS = C.NL80211_ATTR_WIPHY_TXQ_PARAMS
NL80211_ATTR_WOWLAN_TRIGGERS = C.NL80211_ATTR_WOWLAN_TRIGGERS
NL80211_ATTR_WOWLAN_TRIGGERS_SUPPORTED = C.NL80211_ATTR_WOWLAN_TRIGGERS_SUPPORTED
NL80211_ATTR_WPA_VERSIONS = C.NL80211_ATTR_WPA_VERSIONS
NL80211_AUTHTYPE_AUTOMATIC = C.NL80211_AUTHTYPE_AUTOMATIC
NL80211_AUTHTYPE_FILS_PK = C.NL80211_AUTHTYPE_FILS_PK
NL80211_AUTHTYPE_FILS_SK = C.NL80211_AUTHTYPE_FILS_SK
NL80211_AUTHTYPE_FILS_SK_PFS = C.NL80211_AUTHTYPE_FILS_SK_PFS
NL80211_AUTHTYPE_FT = C.NL80211_AUTHTYPE_FT
NL80211_AUTHTYPE_MAX = C.NL80211_AUTHTYPE_MAX
NL80211_AUTHTYPE_NETWORK_EAP = C.NL80211_AUTHTYPE_NETWORK_EAP
NL80211_AUTHTYPE_OPEN_SYSTEM = C.NL80211_AUTHTYPE_OPEN_SYSTEM
NL80211_AUTHTYPE_SAE = C.NL80211_AUTHTYPE_SAE
NL80211_AUTHTYPE_SHARED_KEY = C.NL80211_AUTHTYPE_SHARED_KEY
NL80211_BAND_2GHZ = C.NL80211_BAND_2GHZ
NL80211_BAND_5GHZ = C.NL80211_BAND_5GHZ
NL80211_BAND_60GHZ = C.NL80211_BAND_60GHZ
NL80211_BAND_6GHZ = C.NL80211_BAND_6GHZ
NL80211_BAND_ATTR_EDMG_BW_CONFIG = C.NL80211_BAND_ATTR_EDMG_BW_CONFIG
NL80211_BAND_ATTR_EDMG_CHANNELS = C.NL80211_BAND_ATTR_EDMG_CHANNELS
NL80211_BAND_ATTR_FREQS = C.NL80211_BAND_ATTR_FREQS
NL80211_BAND_ATTR_HT_AMPDU_DENSITY = C.NL80211_BAND_ATTR_HT_AMPDU_DENSITY
NL80211_BAND_ATTR_HT_AMPDU_FACTOR = C.NL80211_BAND_ATTR_HT_AMPDU_FACTOR
NL80211_BAND_ATTR_HT_CAPA = C.NL80211_BAND_ATTR_HT_CAPA
NL80211_BAND_ATTR_HT_MCS_SET = C.NL80211_BAND_ATTR_HT_MCS_SET
NL80211_BAND_ATTR_IFTYPE_DATA = C.NL80211_BAND_ATTR_IFTYPE_DATA
NL80211_BAND_ATTR_MAX = C.NL80211_BAND_ATTR_MAX
NL80211_BAND_ATTR_RATES = C.NL80211_BAND_ATTR_RATES
NL80211_BAND_ATTR_VHT_CAPA = C.NL80211_BAND_ATTR_VHT_CAPA
NL80211_BAND_ATTR_VHT_MCS_SET = C.NL80211_BAND_ATTR_VHT_MCS_SET
NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MAC = C.NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MAC
NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MCS_SET = C.NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MCS_SET
NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PHY = C.NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PHY
NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PPE = C.NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PPE
NL80211_BAND_IFTYPE_ATTR_HE_6GHZ_CAPA = C.NL80211_BAND_IFTYPE_ATTR_HE_6GHZ_CAPA
NL80211_BAND_IFTYPE_ATTR_HE_CAP_MAC = C.NL80211_BAND_IFTYPE_ATTR_HE_CAP_MAC
NL80211_BAND_IFTYPE_ATTR_HE_CAP_MCS_SET = C.NL80211_BAND_IFTYPE_ATTR_HE_CAP_MCS_SET
NL80211_BAND_IFTYPE_ATTR_HE_CAP_PHY = C.NL80211_BAND_IFTYPE_ATTR_HE_CAP_PHY
NL80211_BAND_IFTYPE_ATTR_HE_CAP_PPE = C.NL80211_BAND_IFTYPE_ATTR_HE_CAP_PPE
NL80211_BAND_IFTYPE_ATTR_IFTYPES = C.NL80211_BAND_IFTYPE_ATTR_IFTYPES
NL80211_BAND_IFTYPE_ATTR_MAX = C.NL80211_BAND_IFTYPE_ATTR_MAX
NL80211_BAND_IFTYPE_ATTR_VENDOR_ELEMS = C.NL80211_BAND_IFTYPE_ATTR_VENDOR_ELEMS
NL80211_BAND_LC = C.NL80211_BAND_LC
NL80211_BAND_S1GHZ = C.NL80211_BAND_S1GHZ
NL80211_BITRATE_ATTR_2GHZ_SHORTPREAMBLE = C.NL80211_BITRATE_ATTR_2GHZ_SHORTPREAMBLE
NL80211_BITRATE_ATTR_MAX = C.NL80211_BITRATE_ATTR_MAX
NL80211_BITRATE_ATTR_RATE = C.NL80211_BITRATE_ATTR_RATE
NL80211_BSS_BEACON_IES = C.NL80211_BSS_BEACON_IES
NL80211_BSS_BEACON_INTERVAL = C.NL80211_BSS_BEACON_INTERVAL
NL80211_BSS_BEACON_TSF = C.NL80211_BSS_BEACON_TSF
NL80211_BSS_BSSID = C.NL80211_BSS_BSSID
NL80211_BSS_CAPABILITY = C.NL80211_BSS_CAPABILITY
NL80211_BSS_CHAIN_SIGNAL = C.NL80211_BSS_CHAIN_SIGNAL
NL80211_BSS_CHAN_WIDTH_10 = C.NL80211_BSS_CHAN_WIDTH_10
NL80211_BSS_CHAN_WIDTH_1 = C.NL80211_BSS_CHAN_WIDTH_1
NL80211_BSS_CHAN_WIDTH_20 = C.NL80211_BSS_CHAN_WIDTH_20
NL80211_BSS_CHAN_WIDTH_2 = C.NL80211_BSS_CHAN_WIDTH_2
NL80211_BSS_CHAN_WIDTH_5 = C.NL80211_BSS_CHAN_WIDTH_5
NL80211_BSS_CHAN_WIDTH = C.NL80211_BSS_CHAN_WIDTH
NL80211_BSS_FREQUENCY = C.NL80211_BSS_FREQUENCY
NL80211_BSS_FREQUENCY_OFFSET = C.NL80211_BSS_FREQUENCY_OFFSET
NL80211_BSS_INFORMATION_ELEMENTS = C.NL80211_BSS_INFORMATION_ELEMENTS
NL80211_BSS_LAST_SEEN_BOOTTIME = C.NL80211_BSS_LAST_SEEN_BOOTTIME
NL80211_BSS_MAX = C.NL80211_BSS_MAX
NL80211_BSS_MLD_ADDR = C.NL80211_BSS_MLD_ADDR
NL80211_BSS_MLO_LINK_ID = C.NL80211_BSS_MLO_LINK_ID
NL80211_BSS_PAD = C.NL80211_BSS_PAD
NL80211_BSS_PARENT_BSSID = C.NL80211_BSS_PARENT_BSSID
NL80211_BSS_PARENT_TSF = C.NL80211_BSS_PARENT_TSF
NL80211_BSS_PRESP_DATA = C.NL80211_BSS_PRESP_DATA
NL80211_BSS_SEEN_MS_AGO = C.NL80211_BSS_SEEN_MS_AGO
NL80211_BSS_SELECT_ATTR_BAND_PREF = C.NL80211_BSS_SELECT_ATTR_BAND_PREF
NL80211_BSS_SELECT_ATTR_MAX = C.NL80211_BSS_SELECT_ATTR_MAX
NL80211_BSS_SELECT_ATTR_RSSI_ADJUST = C.NL80211_BSS_SELECT_ATTR_RSSI_ADJUST
NL80211_BSS_SELECT_ATTR_RSSI = C.NL80211_BSS_SELECT_ATTR_RSSI
NL80211_BSS_SIGNAL_MBM = C.NL80211_BSS_SIGNAL_MBM
NL80211_BSS_SIGNAL_UNSPEC = C.NL80211_BSS_SIGNAL_UNSPEC
NL80211_BSS_STATUS_ASSOCIATED = C.NL80211_BSS_STATUS_ASSOCIATED
NL80211_BSS_STATUS_AUTHENTICATED = C.NL80211_BSS_STATUS_AUTHENTICATED
NL80211_BSS_STATUS = C.NL80211_BSS_STATUS
NL80211_BSS_STATUS_IBSS_JOINED = C.NL80211_BSS_STATUS_IBSS_JOINED
NL80211_BSS_TSF = C.NL80211_BSS_TSF
NL80211_CHAN_HT20 = C.NL80211_CHAN_HT20
NL80211_CHAN_HT40MINUS = C.NL80211_CHAN_HT40MINUS
NL80211_CHAN_HT40PLUS = C.NL80211_CHAN_HT40PLUS
NL80211_CHAN_NO_HT = C.NL80211_CHAN_NO_HT
NL80211_CHAN_WIDTH_10 = C.NL80211_CHAN_WIDTH_10
NL80211_CHAN_WIDTH_160 = C.NL80211_CHAN_WIDTH_160
NL80211_CHAN_WIDTH_16 = C.NL80211_CHAN_WIDTH_16
NL80211_CHAN_WIDTH_1 = C.NL80211_CHAN_WIDTH_1
NL80211_CHAN_WIDTH_20 = C.NL80211_CHAN_WIDTH_20
NL80211_CHAN_WIDTH_20_NOHT = C.NL80211_CHAN_WIDTH_20_NOHT
NL80211_CHAN_WIDTH_2 = C.NL80211_CHAN_WIDTH_2
NL80211_CHAN_WIDTH_320 = C.NL80211_CHAN_WIDTH_320
NL80211_CHAN_WIDTH_40 = C.NL80211_CHAN_WIDTH_40
NL80211_CHAN_WIDTH_4 = C.NL80211_CHAN_WIDTH_4
NL80211_CHAN_WIDTH_5 = C.NL80211_CHAN_WIDTH_5
NL80211_CHAN_WIDTH_80 = C.NL80211_CHAN_WIDTH_80
NL80211_CHAN_WIDTH_80P80 = C.NL80211_CHAN_WIDTH_80P80
NL80211_CHAN_WIDTH_8 = C.NL80211_CHAN_WIDTH_8
NL80211_CMD_ABORT_SCAN = C.NL80211_CMD_ABORT_SCAN
NL80211_CMD_ACTION = C.NL80211_CMD_ACTION
NL80211_CMD_ACTION_TX_STATUS = C.NL80211_CMD_ACTION_TX_STATUS
NL80211_CMD_ADD_LINK = C.NL80211_CMD_ADD_LINK
NL80211_CMD_ADD_LINK_STA = C.NL80211_CMD_ADD_LINK_STA
NL80211_CMD_ADD_NAN_FUNCTION = C.NL80211_CMD_ADD_NAN_FUNCTION
NL80211_CMD_ADD_TX_TS = C.NL80211_CMD_ADD_TX_TS
NL80211_CMD_ASSOC_COMEBACK = C.NL80211_CMD_ASSOC_COMEBACK
NL80211_CMD_ASSOCIATE = C.NL80211_CMD_ASSOCIATE
NL80211_CMD_AUTHENTICATE = C.NL80211_CMD_AUTHENTICATE
NL80211_CMD_CANCEL_REMAIN_ON_CHANNEL = C.NL80211_CMD_CANCEL_REMAIN_ON_CHANNEL
NL80211_CMD_CHANGE_NAN_CONFIG = C.NL80211_CMD_CHANGE_NAN_CONFIG
NL80211_CMD_CHANNEL_SWITCH = C.NL80211_CMD_CHANNEL_SWITCH
NL80211_CMD_CH_SWITCH_NOTIFY = C.NL80211_CMD_CH_SWITCH_NOTIFY
NL80211_CMD_CH_SWITCH_STARTED_NOTIFY = C.NL80211_CMD_CH_SWITCH_STARTED_NOTIFY
NL80211_CMD_COLOR_CHANGE_ABORTED = C.NL80211_CMD_COLOR_CHANGE_ABORTED
NL80211_CMD_COLOR_CHANGE_COMPLETED = C.NL80211_CMD_COLOR_CHANGE_COMPLETED
NL80211_CMD_COLOR_CHANGE_REQUEST = C.NL80211_CMD_COLOR_CHANGE_REQUEST
NL80211_CMD_COLOR_CHANGE_STARTED = C.NL80211_CMD_COLOR_CHANGE_STARTED
NL80211_CMD_CONNECT = C.NL80211_CMD_CONNECT
NL80211_CMD_CONN_FAILED = C.NL80211_CMD_CONN_FAILED
NL80211_CMD_CONTROL_PORT_FRAME = C.NL80211_CMD_CONTROL_PORT_FRAME
NL80211_CMD_CONTROL_PORT_FRAME_TX_STATUS = C.NL80211_CMD_CONTROL_PORT_FRAME_TX_STATUS
NL80211_CMD_CRIT_PROTOCOL_START = C.NL80211_CMD_CRIT_PROTOCOL_START
NL80211_CMD_CRIT_PROTOCOL_STOP = C.NL80211_CMD_CRIT_PROTOCOL_STOP
NL80211_CMD_DEAUTHENTICATE = C.NL80211_CMD_DEAUTHENTICATE
NL80211_CMD_DEL_BEACON = C.NL80211_CMD_DEL_BEACON
NL80211_CMD_DEL_INTERFACE = C.NL80211_CMD_DEL_INTERFACE
NL80211_CMD_DEL_KEY = C.NL80211_CMD_DEL_KEY
NL80211_CMD_DEL_MPATH = C.NL80211_CMD_DEL_MPATH
NL80211_CMD_DEL_NAN_FUNCTION = C.NL80211_CMD_DEL_NAN_FUNCTION
NL80211_CMD_DEL_PMK = C.NL80211_CMD_DEL_PMK
NL80211_CMD_DEL_PMKSA = C.NL80211_CMD_DEL_PMKSA
NL80211_CMD_DEL_STATION = C.NL80211_CMD_DEL_STATION
NL80211_CMD_DEL_TX_TS = C.NL80211_CMD_DEL_TX_TS
NL80211_CMD_DEL_WIPHY = C.NL80211_CMD_DEL_WIPHY
NL80211_CMD_DISASSOCIATE = C.NL80211_CMD_DISASSOCIATE
NL80211_CMD_DISCONNECT = C.NL80211_CMD_DISCONNECT
NL80211_CMD_EXTERNAL_AUTH = C.NL80211_CMD_EXTERNAL_AUTH
NL80211_CMD_FLUSH_PMKSA = C.NL80211_CMD_FLUSH_PMKSA
NL80211_CMD_FRAME = C.NL80211_CMD_FRAME
NL80211_CMD_FRAME_TX_STATUS = C.NL80211_CMD_FRAME_TX_STATUS
NL80211_CMD_FRAME_WAIT_CANCEL = C.NL80211_CMD_FRAME_WAIT_CANCEL
NL80211_CMD_FT_EVENT = C.NL80211_CMD_FT_EVENT
NL80211_CMD_GET_BEACON = C.NL80211_CMD_GET_BEACON
NL80211_CMD_GET_COALESCE = C.NL80211_CMD_GET_COALESCE
NL80211_CMD_GET_FTM_RESPONDER_STATS = C.NL80211_CMD_GET_FTM_RESPONDER_STATS
NL80211_CMD_GET_INTERFACE = C.NL80211_CMD_GET_INTERFACE
NL80211_CMD_GET_KEY = C.NL80211_CMD_GET_KEY
NL80211_CMD_GET_MESH_CONFIG = C.NL80211_CMD_GET_MESH_CONFIG
NL80211_CMD_GET_MESH_PARAMS = C.NL80211_CMD_GET_MESH_PARAMS
NL80211_CMD_GET_MPATH = C.NL80211_CMD_GET_MPATH
NL80211_CMD_GET_MPP = C.NL80211_CMD_GET_MPP
NL80211_CMD_GET_POWER_SAVE = C.NL80211_CMD_GET_POWER_SAVE
NL80211_CMD_GET_PROTOCOL_FEATURES = C.NL80211_CMD_GET_PROTOCOL_FEATURES
NL80211_CMD_GET_REG = C.NL80211_CMD_GET_REG
NL80211_CMD_GET_SCAN = C.NL80211_CMD_GET_SCAN
NL80211_CMD_GET_STATION = C.NL80211_CMD_GET_STATION
NL80211_CMD_GET_SURVEY = C.NL80211_CMD_GET_SURVEY
NL80211_CMD_GET_WIPHY = C.NL80211_CMD_GET_WIPHY
NL80211_CMD_GET_WOWLAN = C.NL80211_CMD_GET_WOWLAN
NL80211_CMD_JOIN_IBSS = C.NL80211_CMD_JOIN_IBSS
NL80211_CMD_JOIN_MESH = C.NL80211_CMD_JOIN_MESH
NL80211_CMD_JOIN_OCB = C.NL80211_CMD_JOIN_OCB
NL80211_CMD_LEAVE_IBSS = C.NL80211_CMD_LEAVE_IBSS
NL80211_CMD_LEAVE_MESH = C.NL80211_CMD_LEAVE_MESH
NL80211_CMD_LEAVE_OCB = C.NL80211_CMD_LEAVE_OCB
NL80211_CMD_MAX = C.NL80211_CMD_MAX
NL80211_CMD_MICHAEL_MIC_FAILURE = C.NL80211_CMD_MICHAEL_MIC_FAILURE
NL80211_CMD_MODIFY_LINK_STA = C.NL80211_CMD_MODIFY_LINK_STA
NL80211_CMD_NAN_MATCH = C.NL80211_CMD_NAN_MATCH
NL80211_CMD_NEW_BEACON = C.NL80211_CMD_NEW_BEACON
NL80211_CMD_NEW_INTERFACE = C.NL80211_CMD_NEW_INTERFACE
NL80211_CMD_NEW_KEY = C.NL80211_CMD_NEW_KEY
NL80211_CMD_NEW_MPATH = C.NL80211_CMD_NEW_MPATH
NL80211_CMD_NEW_PEER_CANDIDATE = C.NL80211_CMD_NEW_PEER_CANDIDATE
NL80211_CMD_NEW_SCAN_RESULTS = C.NL80211_CMD_NEW_SCAN_RESULTS
NL80211_CMD_NEW_STATION = C.NL80211_CMD_NEW_STATION
NL80211_CMD_NEW_SURVEY_RESULTS = C.NL80211_CMD_NEW_SURVEY_RESULTS
NL80211_CMD_NEW_WIPHY = C.NL80211_CMD_NEW_WIPHY
NL80211_CMD_NOTIFY_CQM = C.NL80211_CMD_NOTIFY_CQM
NL80211_CMD_NOTIFY_RADAR = C.NL80211_CMD_NOTIFY_RADAR
NL80211_CMD_OBSS_COLOR_COLLISION = C.NL80211_CMD_OBSS_COLOR_COLLISION
NL80211_CMD_PEER_MEASUREMENT_COMPLETE = C.NL80211_CMD_PEER_MEASUREMENT_COMPLETE
NL80211_CMD_PEER_MEASUREMENT_RESULT = C.NL80211_CMD_PEER_MEASUREMENT_RESULT
NL80211_CMD_PEER_MEASUREMENT_START = C.NL80211_CMD_PEER_MEASUREMENT_START
NL80211_CMD_PMKSA_CANDIDATE = C.NL80211_CMD_PMKSA_CANDIDATE
NL80211_CMD_PORT_AUTHORIZED = C.NL80211_CMD_PORT_AUTHORIZED
NL80211_CMD_PROBE_CLIENT = C.NL80211_CMD_PROBE_CLIENT
NL80211_CMD_PROBE_MESH_LINK = C.NL80211_CMD_PROBE_MESH_LINK
NL80211_CMD_RADAR_DETECT = C.NL80211_CMD_RADAR_DETECT
NL80211_CMD_REG_BEACON_HINT = C.NL80211_CMD_REG_BEACON_HINT
NL80211_CMD_REG_CHANGE = C.NL80211_CMD_REG_CHANGE
NL80211_CMD_REGISTER_ACTION = C.NL80211_CMD_REGISTER_ACTION
NL80211_CMD_REGISTER_BEACONS = C.NL80211_CMD_REGISTER_BEACONS
NL80211_CMD_REGISTER_FRAME = C.NL80211_CMD_REGISTER_FRAME
NL80211_CMD_RELOAD_REGDB = C.NL80211_CMD_RELOAD_REGDB
NL80211_CMD_REMAIN_ON_CHANNEL = C.NL80211_CMD_REMAIN_ON_CHANNEL
NL80211_CMD_REMOVE_LINK = C.NL80211_CMD_REMOVE_LINK
NL80211_CMD_REMOVE_LINK_STA = C.NL80211_CMD_REMOVE_LINK_STA
NL80211_CMD_REQ_SET_REG = C.NL80211_CMD_REQ_SET_REG
NL80211_CMD_ROAM = C.NL80211_CMD_ROAM
NL80211_CMD_SCAN_ABORTED = C.NL80211_CMD_SCAN_ABORTED
NL80211_CMD_SCHED_SCAN_RESULTS = C.NL80211_CMD_SCHED_SCAN_RESULTS
NL80211_CMD_SCHED_SCAN_STOPPED = C.NL80211_CMD_SCHED_SCAN_STOPPED
NL80211_CMD_SET_BEACON = C.NL80211_CMD_SET_BEACON
NL80211_CMD_SET_BSS = C.NL80211_CMD_SET_BSS
NL80211_CMD_SET_CHANNEL = C.NL80211_CMD_SET_CHANNEL
NL80211_CMD_SET_COALESCE = C.NL80211_CMD_SET_COALESCE
NL80211_CMD_SET_CQM = C.NL80211_CMD_SET_CQM
NL80211_CMD_SET_FILS_AAD = C.NL80211_CMD_SET_FILS_AAD
NL80211_CMD_SET_INTERFACE = C.NL80211_CMD_SET_INTERFACE
NL80211_CMD_SET_KEY = C.NL80211_CMD_SET_KEY
NL80211_CMD_SET_MAC_ACL = C.NL80211_CMD_SET_MAC_ACL
NL80211_CMD_SET_MCAST_RATE = C.NL80211_CMD_SET_MCAST_RATE
NL80211_CMD_SET_MESH_CONFIG = C.NL80211_CMD_SET_MESH_CONFIG
NL80211_CMD_SET_MESH_PARAMS = C.NL80211_CMD_SET_MESH_PARAMS
NL80211_CMD_SET_MGMT_EXTRA_IE = C.NL80211_CMD_SET_MGMT_EXTRA_IE
NL80211_CMD_SET_MPATH = C.NL80211_CMD_SET_MPATH
NL80211_CMD_SET_MULTICAST_TO_UNICAST = C.NL80211_CMD_SET_MULTICAST_TO_UNICAST
NL80211_CMD_SET_NOACK_MAP = C.NL80211_CMD_SET_NOACK_MAP
NL80211_CMD_SET_PMK = C.NL80211_CMD_SET_PMK
NL80211_CMD_SET_PMKSA = C.NL80211_CMD_SET_PMKSA
NL80211_CMD_SET_POWER_SAVE = C.NL80211_CMD_SET_POWER_SAVE
NL80211_CMD_SET_QOS_MAP = C.NL80211_CMD_SET_QOS_MAP
NL80211_CMD_SET_REG = C.NL80211_CMD_SET_REG
NL80211_CMD_SET_REKEY_OFFLOAD = C.NL80211_CMD_SET_REKEY_OFFLOAD
NL80211_CMD_SET_SAR_SPECS = C.NL80211_CMD_SET_SAR_SPECS
NL80211_CMD_SET_STATION = C.NL80211_CMD_SET_STATION
NL80211_CMD_SET_TID_CONFIG = C.NL80211_CMD_SET_TID_CONFIG
NL80211_CMD_SET_TX_BITRATE_MASK = C.NL80211_CMD_SET_TX_BITRATE_MASK
NL80211_CMD_SET_WDS_PEER = C.NL80211_CMD_SET_WDS_PEER
NL80211_CMD_SET_WIPHY = C.NL80211_CMD_SET_WIPHY
NL80211_CMD_SET_WIPHY_NETNS = C.NL80211_CMD_SET_WIPHY_NETNS
NL80211_CMD_SET_WOWLAN = C.NL80211_CMD_SET_WOWLAN
NL80211_CMD_STA_OPMODE_CHANGED = C.NL80211_CMD_STA_OPMODE_CHANGED
NL80211_CMD_START_AP = C.NL80211_CMD_START_AP
NL80211_CMD_START_NAN = C.NL80211_CMD_START_NAN
NL80211_CMD_START_P2P_DEVICE = C.NL80211_CMD_START_P2P_DEVICE
NL80211_CMD_START_SCHED_SCAN = C.NL80211_CMD_START_SCHED_SCAN
NL80211_CMD_STOP_AP = C.NL80211_CMD_STOP_AP
NL80211_CMD_STOP_NAN = C.NL80211_CMD_STOP_NAN
NL80211_CMD_STOP_P2P_DEVICE = C.NL80211_CMD_STOP_P2P_DEVICE
NL80211_CMD_STOP_SCHED_SCAN = C.NL80211_CMD_STOP_SCHED_SCAN
NL80211_CMD_TDLS_CANCEL_CHANNEL_SWITCH = C.NL80211_CMD_TDLS_CANCEL_CHANNEL_SWITCH
NL80211_CMD_TDLS_CHANNEL_SWITCH = C.NL80211_CMD_TDLS_CHANNEL_SWITCH
NL80211_CMD_TDLS_MGMT = C.NL80211_CMD_TDLS_MGMT
NL80211_CMD_TDLS_OPER = C.NL80211_CMD_TDLS_OPER
NL80211_CMD_TESTMODE = C.NL80211_CMD_TESTMODE
NL80211_CMD_TRIGGER_SCAN = C.NL80211_CMD_TRIGGER_SCAN
NL80211_CMD_UNEXPECTED_4ADDR_FRAME = C.NL80211_CMD_UNEXPECTED_4ADDR_FRAME
NL80211_CMD_UNEXPECTED_FRAME = C.NL80211_CMD_UNEXPECTED_FRAME
NL80211_CMD_UNPROT_BEACON = C.NL80211_CMD_UNPROT_BEACON
NL80211_CMD_UNPROT_DEAUTHENTICATE = C.NL80211_CMD_UNPROT_DEAUTHENTICATE
NL80211_CMD_UNPROT_DISASSOCIATE = C.NL80211_CMD_UNPROT_DISASSOCIATE
NL80211_CMD_UNSPEC = C.NL80211_CMD_UNSPEC
NL80211_CMD_UPDATE_CONNECT_PARAMS = C.NL80211_CMD_UPDATE_CONNECT_PARAMS
NL80211_CMD_UPDATE_FT_IES = C.NL80211_CMD_UPDATE_FT_IES
NL80211_CMD_UPDATE_OWE_INFO = C.NL80211_CMD_UPDATE_OWE_INFO
NL80211_CMD_VENDOR = C.NL80211_CMD_VENDOR
NL80211_CMD_WIPHY_REG_CHANGE = C.NL80211_CMD_WIPHY_REG_CHANGE
NL80211_COALESCE_CONDITION_MATCH = C.NL80211_COALESCE_CONDITION_MATCH
NL80211_COALESCE_CONDITION_NO_MATCH = C.NL80211_COALESCE_CONDITION_NO_MATCH
NL80211_CONN_FAIL_BLOCKED_CLIENT = C.NL80211_CONN_FAIL_BLOCKED_CLIENT
NL80211_CONN_FAIL_MAX_CLIENTS = C.NL80211_CONN_FAIL_MAX_CLIENTS
NL80211_CQM_RSSI_BEACON_LOSS_EVENT = C.NL80211_CQM_RSSI_BEACON_LOSS_EVENT
NL80211_CQM_RSSI_THRESHOLD_EVENT_HIGH = C.NL80211_CQM_RSSI_THRESHOLD_EVENT_HIGH
NL80211_CQM_RSSI_THRESHOLD_EVENT_LOW = C.NL80211_CQM_RSSI_THRESHOLD_EVENT_LOW
NL80211_CQM_TXE_MAX_INTVL = C.NL80211_CQM_TXE_MAX_INTVL
NL80211_CRIT_PROTO_APIPA = C.NL80211_CRIT_PROTO_APIPA
NL80211_CRIT_PROTO_DHCP = C.NL80211_CRIT_PROTO_DHCP
NL80211_CRIT_PROTO_EAPOL = C.NL80211_CRIT_PROTO_EAPOL
NL80211_CRIT_PROTO_MAX_DURATION = C.NL80211_CRIT_PROTO_MAX_DURATION
NL80211_CRIT_PROTO_UNSPEC = C.NL80211_CRIT_PROTO_UNSPEC
NL80211_DFS_AVAILABLE = C.NL80211_DFS_AVAILABLE
NL80211_DFS_ETSI = C.NL80211_DFS_ETSI
NL80211_DFS_FCC = C.NL80211_DFS_FCC
NL80211_DFS_JP = C.NL80211_DFS_JP
NL80211_DFS_UNAVAILABLE = C.NL80211_DFS_UNAVAILABLE
NL80211_DFS_UNSET = C.NL80211_DFS_UNSET
NL80211_DFS_USABLE = C.NL80211_DFS_USABLE
NL80211_EDMG_BW_CONFIG_MAX = C.NL80211_EDMG_BW_CONFIG_MAX
NL80211_EDMG_BW_CONFIG_MIN = C.NL80211_EDMG_BW_CONFIG_MIN
NL80211_EDMG_CHANNELS_MAX = C.NL80211_EDMG_CHANNELS_MAX
NL80211_EDMG_CHANNELS_MIN = C.NL80211_EDMG_CHANNELS_MIN
NL80211_EHT_MAX_CAPABILITY_LEN = C.NL80211_EHT_MAX_CAPABILITY_LEN
NL80211_EHT_MIN_CAPABILITY_LEN = C.NL80211_EHT_MIN_CAPABILITY_LEN
NL80211_EXTERNAL_AUTH_ABORT = C.NL80211_EXTERNAL_AUTH_ABORT
NL80211_EXTERNAL_AUTH_START = C.NL80211_EXTERNAL_AUTH_START
NL80211_EXT_FEATURE_4WAY_HANDSHAKE_AP_PSK = C.NL80211_EXT_FEATURE_4WAY_HANDSHAKE_AP_PSK
NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X = C.NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X
NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_PSK = C.NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_PSK
NL80211_EXT_FEATURE_ACCEPT_BCAST_PROBE_RESP = C.NL80211_EXT_FEATURE_ACCEPT_BCAST_PROBE_RESP
NL80211_EXT_FEATURE_ACK_SIGNAL_SUPPORT = C.NL80211_EXT_FEATURE_ACK_SIGNAL_SUPPORT
NL80211_EXT_FEATURE_AIRTIME_FAIRNESS = C.NL80211_EXT_FEATURE_AIRTIME_FAIRNESS
NL80211_EXT_FEATURE_AP_PMKSA_CACHING = C.NL80211_EXT_FEATURE_AP_PMKSA_CACHING
NL80211_EXT_FEATURE_AQL = C.NL80211_EXT_FEATURE_AQL
NL80211_EXT_FEATURE_BEACON_PROTECTION_CLIENT = C.NL80211_EXT_FEATURE_BEACON_PROTECTION_CLIENT
NL80211_EXT_FEATURE_BEACON_PROTECTION = C.NL80211_EXT_FEATURE_BEACON_PROTECTION
NL80211_EXT_FEATURE_BEACON_RATE_HE = C.NL80211_EXT_FEATURE_BEACON_RATE_HE
NL80211_EXT_FEATURE_BEACON_RATE_HT = C.NL80211_EXT_FEATURE_BEACON_RATE_HT
NL80211_EXT_FEATURE_BEACON_RATE_LEGACY = C.NL80211_EXT_FEATURE_BEACON_RATE_LEGACY
NL80211_EXT_FEATURE_BEACON_RATE_VHT = C.NL80211_EXT_FEATURE_BEACON_RATE_VHT
NL80211_EXT_FEATURE_BSS_COLOR = C.NL80211_EXT_FEATURE_BSS_COLOR
NL80211_EXT_FEATURE_BSS_PARENT_TSF = C.NL80211_EXT_FEATURE_BSS_PARENT_TSF
NL80211_EXT_FEATURE_CAN_REPLACE_PTK0 = C.NL80211_EXT_FEATURE_CAN_REPLACE_PTK0
NL80211_EXT_FEATURE_CONTROL_PORT_NO_PREAUTH = C.NL80211_EXT_FEATURE_CONTROL_PORT_NO_PREAUTH
NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211 = C.NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211
NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211_TX_STATUS = C.NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211_TX_STATUS
NL80211_EXT_FEATURE_CQM_RSSI_LIST = C.NL80211_EXT_FEATURE_CQM_RSSI_LIST
NL80211_EXT_FEATURE_DATA_ACK_SIGNAL_SUPPORT = C.NL80211_EXT_FEATURE_DATA_ACK_SIGNAL_SUPPORT
NL80211_EXT_FEATURE_DEL_IBSS_STA = C.NL80211_EXT_FEATURE_DEL_IBSS_STA
NL80211_EXT_FEATURE_DFS_OFFLOAD = C.NL80211_EXT_FEATURE_DFS_OFFLOAD
NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER = C.NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER
NL80211_EXT_FEATURE_EXT_KEY_ID = C.NL80211_EXT_FEATURE_EXT_KEY_ID
NL80211_EXT_FEATURE_FILS_CRYPTO_OFFLOAD = C.NL80211_EXT_FEATURE_FILS_CRYPTO_OFFLOAD
NL80211_EXT_FEATURE_FILS_DISCOVERY = C.NL80211_EXT_FEATURE_FILS_DISCOVERY
NL80211_EXT_FEATURE_FILS_MAX_CHANNEL_TIME = C.NL80211_EXT_FEATURE_FILS_MAX_CHANNEL_TIME
NL80211_EXT_FEATURE_FILS_SK_OFFLOAD = C.NL80211_EXT_FEATURE_FILS_SK_OFFLOAD
NL80211_EXT_FEATURE_FILS_STA = C.NL80211_EXT_FEATURE_FILS_STA
NL80211_EXT_FEATURE_HIGH_ACCURACY_SCAN = C.NL80211_EXT_FEATURE_HIGH_ACCURACY_SCAN
NL80211_EXT_FEATURE_LOW_POWER_SCAN = C.NL80211_EXT_FEATURE_LOW_POWER_SCAN
NL80211_EXT_FEATURE_LOW_SPAN_SCAN = C.NL80211_EXT_FEATURE_LOW_SPAN_SCAN
NL80211_EXT_FEATURE_MFP_OPTIONAL = C.NL80211_EXT_FEATURE_MFP_OPTIONAL
NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA = C.NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA
NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA_CONNECTED = C.NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA_CONNECTED
NL80211_EXT_FEATURE_MULTICAST_REGISTRATIONS = C.NL80211_EXT_FEATURE_MULTICAST_REGISTRATIONS
NL80211_EXT_FEATURE_MU_MIMO_AIR_SNIFFER = C.NL80211_EXT_FEATURE_MU_MIMO_AIR_SNIFFER
NL80211_EXT_FEATURE_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION = C.NL80211_EXT_FEATURE_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION
NL80211_EXT_FEATURE_OCE_PROBE_REQ_HIGH_TX_RATE = C.NL80211_EXT_FEATURE_OCE_PROBE_REQ_HIGH_TX_RATE
NL80211_EXT_FEATURE_OPERATING_CHANNEL_VALIDATION = C.NL80211_EXT_FEATURE_OPERATING_CHANNEL_VALIDATION
NL80211_EXT_FEATURE_POWERED_ADDR_CHANGE = C.NL80211_EXT_FEATURE_POWERED_ADDR_CHANGE
NL80211_EXT_FEATURE_PROTECTED_TWT = C.NL80211_EXT_FEATURE_PROTECTED_TWT
NL80211_EXT_FEATURE_PROT_RANGE_NEGO_AND_MEASURE = C.NL80211_EXT_FEATURE_PROT_RANGE_NEGO_AND_MEASURE
NL80211_EXT_FEATURE_RADAR_BACKGROUND = C.NL80211_EXT_FEATURE_RADAR_BACKGROUND
NL80211_EXT_FEATURE_RRM = C.NL80211_EXT_FEATURE_RRM
NL80211_EXT_FEATURE_SAE_OFFLOAD_AP = C.NL80211_EXT_FEATURE_SAE_OFFLOAD_AP
NL80211_EXT_FEATURE_SAE_OFFLOAD = C.NL80211_EXT_FEATURE_SAE_OFFLOAD
NL80211_EXT_FEATURE_SCAN_FREQ_KHZ = C.NL80211_EXT_FEATURE_SCAN_FREQ_KHZ
NL80211_EXT_FEATURE_SCAN_MIN_PREQ_CONTENT = C.NL80211_EXT_FEATURE_SCAN_MIN_PREQ_CONTENT
NL80211_EXT_FEATURE_SCAN_RANDOM_SN = C.NL80211_EXT_FEATURE_SCAN_RANDOM_SN
NL80211_EXT_FEATURE_SCAN_START_TIME = C.NL80211_EXT_FEATURE_SCAN_START_TIME
NL80211_EXT_FEATURE_SCHED_SCAN_BAND_SPECIFIC_RSSI_THOLD = C.NL80211_EXT_FEATURE_SCHED_SCAN_BAND_SPECIFIC_RSSI_THOLD
NL80211_EXT_FEATURE_SCHED_SCAN_RELATIVE_RSSI = C.NL80211_EXT_FEATURE_SCHED_SCAN_RELATIVE_RSSI
NL80211_EXT_FEATURE_SECURE_LTF = C.NL80211_EXT_FEATURE_SECURE_LTF
NL80211_EXT_FEATURE_SECURE_RTT = C.NL80211_EXT_FEATURE_SECURE_RTT
NL80211_EXT_FEATURE_SET_SCAN_DWELL = C.NL80211_EXT_FEATURE_SET_SCAN_DWELL
NL80211_EXT_FEATURE_STA_TX_PWR = C.NL80211_EXT_FEATURE_STA_TX_PWR
NL80211_EXT_FEATURE_TXQS = C.NL80211_EXT_FEATURE_TXQS
NL80211_EXT_FEATURE_UNSOL_BCAST_PROBE_RESP = C.NL80211_EXT_FEATURE_UNSOL_BCAST_PROBE_RESP
NL80211_EXT_FEATURE_VHT_IBSS = C.NL80211_EXT_FEATURE_VHT_IBSS
NL80211_EXT_FEATURE_VLAN_OFFLOAD = C.NL80211_EXT_FEATURE_VLAN_OFFLOAD
NL80211_FEATURE_ACKTO_ESTIMATION = C.NL80211_FEATURE_ACKTO_ESTIMATION
NL80211_FEATURE_ACTIVE_MONITOR = C.NL80211_FEATURE_ACTIVE_MONITOR
NL80211_FEATURE_ADVERTISE_CHAN_LIMITS = C.NL80211_FEATURE_ADVERTISE_CHAN_LIMITS
NL80211_FEATURE_AP_MODE_CHAN_WIDTH_CHANGE = C.NL80211_FEATURE_AP_MODE_CHAN_WIDTH_CHANGE
NL80211_FEATURE_AP_SCAN = C.NL80211_FEATURE_AP_SCAN
NL80211_FEATURE_CELL_BASE_REG_HINTS = C.NL80211_FEATURE_CELL_BASE_REG_HINTS
NL80211_FEATURE_DS_PARAM_SET_IE_IN_PROBES = C.NL80211_FEATURE_DS_PARAM_SET_IE_IN_PROBES
NL80211_FEATURE_DYNAMIC_SMPS = C.NL80211_FEATURE_DYNAMIC_SMPS
NL80211_FEATURE_FULL_AP_CLIENT_STATE = C.NL80211_FEATURE_FULL_AP_CLIENT_STATE
NL80211_FEATURE_HT_IBSS = C.NL80211_FEATURE_HT_IBSS
NL80211_FEATURE_INACTIVITY_TIMER = C.NL80211_FEATURE_INACTIVITY_TIMER
NL80211_FEATURE_LOW_PRIORITY_SCAN = C.NL80211_FEATURE_LOW_PRIORITY_SCAN
NL80211_FEATURE_MAC_ON_CREATE = C.NL80211_FEATURE_MAC_ON_CREATE
NL80211_FEATURE_ND_RANDOM_MAC_ADDR = C.NL80211_FEATURE_ND_RANDOM_MAC_ADDR
NL80211_FEATURE_NEED_OBSS_SCAN = C.NL80211_FEATURE_NEED_OBSS_SCAN
NL80211_FEATURE_P2P_DEVICE_NEEDS_CHANNEL = C.NL80211_FEATURE_P2P_DEVICE_NEEDS_CHANNEL
NL80211_FEATURE_P2P_GO_CTWIN = C.NL80211_FEATURE_P2P_GO_CTWIN
NL80211_FEATURE_P2P_GO_OPPPS = C.NL80211_FEATURE_P2P_GO_OPPPS
NL80211_FEATURE_QUIET = C.NL80211_FEATURE_QUIET
NL80211_FEATURE_SAE = C.NL80211_FEATURE_SAE
NL80211_FEATURE_SCAN_FLUSH = C.NL80211_FEATURE_SCAN_FLUSH
NL80211_FEATURE_SCAN_RANDOM_MAC_ADDR = C.NL80211_FEATURE_SCAN_RANDOM_MAC_ADDR
NL80211_FEATURE_SCHED_SCAN_RANDOM_MAC_ADDR = C.NL80211_FEATURE_SCHED_SCAN_RANDOM_MAC_ADDR
NL80211_FEATURE_SK_TX_STATUS = C.NL80211_FEATURE_SK_TX_STATUS
NL80211_FEATURE_STATIC_SMPS = C.NL80211_FEATURE_STATIC_SMPS
NL80211_FEATURE_SUPPORTS_WMM_ADMISSION = C.NL80211_FEATURE_SUPPORTS_WMM_ADMISSION
NL80211_FEATURE_TDLS_CHANNEL_SWITCH = C.NL80211_FEATURE_TDLS_CHANNEL_SWITCH
NL80211_FEATURE_TX_POWER_INSERTION = C.NL80211_FEATURE_TX_POWER_INSERTION
NL80211_FEATURE_USERSPACE_MPM = C.NL80211_FEATURE_USERSPACE_MPM
NL80211_FEATURE_VIF_TXPOWER = C.NL80211_FEATURE_VIF_TXPOWER
NL80211_FEATURE_WFA_TPC_IE_IN_PROBES = C.NL80211_FEATURE_WFA_TPC_IE_IN_PROBES
NL80211_FILS_DISCOVERY_ATTR_INT_MAX = C.NL80211_FILS_DISCOVERY_ATTR_INT_MAX
NL80211_FILS_DISCOVERY_ATTR_INT_MIN = C.NL80211_FILS_DISCOVERY_ATTR_INT_MIN
NL80211_FILS_DISCOVERY_ATTR_MAX = C.NL80211_FILS_DISCOVERY_ATTR_MAX
NL80211_FILS_DISCOVERY_ATTR_TMPL = C.NL80211_FILS_DISCOVERY_ATTR_TMPL
NL80211_FILS_DISCOVERY_TMPL_MIN_LEN = C.NL80211_FILS_DISCOVERY_TMPL_MIN_LEN
NL80211_FREQUENCY_ATTR_16MHZ = C.NL80211_FREQUENCY_ATTR_16MHZ
NL80211_FREQUENCY_ATTR_1MHZ = C.NL80211_FREQUENCY_ATTR_1MHZ
NL80211_FREQUENCY_ATTR_2MHZ = C.NL80211_FREQUENCY_ATTR_2MHZ
NL80211_FREQUENCY_ATTR_4MHZ = C.NL80211_FREQUENCY_ATTR_4MHZ
NL80211_FREQUENCY_ATTR_8MHZ = C.NL80211_FREQUENCY_ATTR_8MHZ
NL80211_FREQUENCY_ATTR_DFS_CAC_TIME = C.NL80211_FREQUENCY_ATTR_DFS_CAC_TIME
NL80211_FREQUENCY_ATTR_DFS_STATE = C.NL80211_FREQUENCY_ATTR_DFS_STATE
NL80211_FREQUENCY_ATTR_DFS_TIME = C.NL80211_FREQUENCY_ATTR_DFS_TIME
NL80211_FREQUENCY_ATTR_DISABLED = C.NL80211_FREQUENCY_ATTR_DISABLED
NL80211_FREQUENCY_ATTR_FREQ = C.NL80211_FREQUENCY_ATTR_FREQ
NL80211_FREQUENCY_ATTR_GO_CONCURRENT = C.NL80211_FREQUENCY_ATTR_GO_CONCURRENT
NL80211_FREQUENCY_ATTR_INDOOR_ONLY = C.NL80211_FREQUENCY_ATTR_INDOOR_ONLY
NL80211_FREQUENCY_ATTR_IR_CONCURRENT = C.NL80211_FREQUENCY_ATTR_IR_CONCURRENT
NL80211_FREQUENCY_ATTR_MAX = C.NL80211_FREQUENCY_ATTR_MAX
NL80211_FREQUENCY_ATTR_MAX_TX_POWER = C.NL80211_FREQUENCY_ATTR_MAX_TX_POWER
NL80211_FREQUENCY_ATTR_NO_10MHZ = C.NL80211_FREQUENCY_ATTR_NO_10MHZ
NL80211_FREQUENCY_ATTR_NO_160MHZ = C.NL80211_FREQUENCY_ATTR_NO_160MHZ
NL80211_FREQUENCY_ATTR_NO_20MHZ = C.NL80211_FREQUENCY_ATTR_NO_20MHZ
NL80211_FREQUENCY_ATTR_NO_320MHZ = C.NL80211_FREQUENCY_ATTR_NO_320MHZ
NL80211_FREQUENCY_ATTR_NO_80MHZ = C.NL80211_FREQUENCY_ATTR_NO_80MHZ
NL80211_FREQUENCY_ATTR_NO_EHT = C.NL80211_FREQUENCY_ATTR_NO_EHT
NL80211_FREQUENCY_ATTR_NO_HE = C.NL80211_FREQUENCY_ATTR_NO_HE
NL80211_FREQUENCY_ATTR_NO_HT40_MINUS = C.NL80211_FREQUENCY_ATTR_NO_HT40_MINUS
NL80211_FREQUENCY_ATTR_NO_HT40_PLUS = C.NL80211_FREQUENCY_ATTR_NO_HT40_PLUS
NL80211_FREQUENCY_ATTR_NO_IBSS = C.NL80211_FREQUENCY_ATTR_NO_IBSS
NL80211_FREQUENCY_ATTR_NO_IR = C.NL80211_FREQUENCY_ATTR_NO_IR
NL80211_FREQUENCY_ATTR_OFFSET = C.NL80211_FREQUENCY_ATTR_OFFSET
NL80211_FREQUENCY_ATTR_PASSIVE_SCAN = C.NL80211_FREQUENCY_ATTR_PASSIVE_SCAN
NL80211_FREQUENCY_ATTR_RADAR = C.NL80211_FREQUENCY_ATTR_RADAR
NL80211_FREQUENCY_ATTR_WMM = C.NL80211_FREQUENCY_ATTR_WMM
NL80211_FTM_RESP_ATTR_CIVICLOC = C.NL80211_FTM_RESP_ATTR_CIVICLOC
NL80211_FTM_RESP_ATTR_ENABLED = C.NL80211_FTM_RESP_ATTR_ENABLED
NL80211_FTM_RESP_ATTR_LCI = C.NL80211_FTM_RESP_ATTR_LCI
NL80211_FTM_RESP_ATTR_MAX = C.NL80211_FTM_RESP_ATTR_MAX
NL80211_FTM_STATS_ASAP_NUM = C.NL80211_FTM_STATS_ASAP_NUM
NL80211_FTM_STATS_FAILED_NUM = C.NL80211_FTM_STATS_FAILED_NUM
NL80211_FTM_STATS_MAX = C.NL80211_FTM_STATS_MAX
NL80211_FTM_STATS_NON_ASAP_NUM = C.NL80211_FTM_STATS_NON_ASAP_NUM
NL80211_FTM_STATS_OUT_OF_WINDOW_TRIGGERS_NUM = C.NL80211_FTM_STATS_OUT_OF_WINDOW_TRIGGERS_NUM
NL80211_FTM_STATS_PAD = C.NL80211_FTM_STATS_PAD
NL80211_FTM_STATS_PARTIAL_NUM = C.NL80211_FTM_STATS_PARTIAL_NUM
NL80211_FTM_STATS_RESCHEDULE_REQUESTS_NUM = C.NL80211_FTM_STATS_RESCHEDULE_REQUESTS_NUM
NL80211_FTM_STATS_SUCCESS_NUM = C.NL80211_FTM_STATS_SUCCESS_NUM
NL80211_FTM_STATS_TOTAL_DURATION_MSEC = C.NL80211_FTM_STATS_TOTAL_DURATION_MSEC
NL80211_FTM_STATS_UNKNOWN_TRIGGERS_NUM = C.NL80211_FTM_STATS_UNKNOWN_TRIGGERS_NUM
NL80211_GENL_NAME = C.NL80211_GENL_NAME
NL80211_HE_BSS_COLOR_ATTR_COLOR = C.NL80211_HE_BSS_COLOR_ATTR_COLOR
NL80211_HE_BSS_COLOR_ATTR_DISABLED = C.NL80211_HE_BSS_COLOR_ATTR_DISABLED
NL80211_HE_BSS_COLOR_ATTR_MAX = C.NL80211_HE_BSS_COLOR_ATTR_MAX
NL80211_HE_BSS_COLOR_ATTR_PARTIAL = C.NL80211_HE_BSS_COLOR_ATTR_PARTIAL
NL80211_HE_MAX_CAPABILITY_LEN = C.NL80211_HE_MAX_CAPABILITY_LEN
NL80211_HE_MIN_CAPABILITY_LEN = C.NL80211_HE_MIN_CAPABILITY_LEN
NL80211_HE_NSS_MAX = C.NL80211_HE_NSS_MAX
NL80211_HE_OBSS_PD_ATTR_BSS_COLOR_BITMAP = C.NL80211_HE_OBSS_PD_ATTR_BSS_COLOR_BITMAP
NL80211_HE_OBSS_PD_ATTR_MAX = C.NL80211_HE_OBSS_PD_ATTR_MAX
NL80211_HE_OBSS_PD_ATTR_MAX_OFFSET = C.NL80211_HE_OBSS_PD_ATTR_MAX_OFFSET
NL80211_HE_OBSS_PD_ATTR_MIN_OFFSET = C.NL80211_HE_OBSS_PD_ATTR_MIN_OFFSET
NL80211_HE_OBSS_PD_ATTR_NON_SRG_MAX_OFFSET = C.NL80211_HE_OBSS_PD_ATTR_NON_SRG_MAX_OFFSET
NL80211_HE_OBSS_PD_ATTR_PARTIAL_BSSID_BITMAP = C.NL80211_HE_OBSS_PD_ATTR_PARTIAL_BSSID_BITMAP
NL80211_HE_OBSS_PD_ATTR_SR_CTRL = C.NL80211_HE_OBSS_PD_ATTR_SR_CTRL
NL80211_HIDDEN_SSID_NOT_IN_USE = C.NL80211_HIDDEN_SSID_NOT_IN_USE
NL80211_HIDDEN_SSID_ZERO_CONTENTS = C.NL80211_HIDDEN_SSID_ZERO_CONTENTS
NL80211_HIDDEN_SSID_ZERO_LEN = C.NL80211_HIDDEN_SSID_ZERO_LEN
NL80211_HT_CAPABILITY_LEN = C.NL80211_HT_CAPABILITY_LEN
NL80211_IFACE_COMB_BI_MIN_GCD = C.NL80211_IFACE_COMB_BI_MIN_GCD
NL80211_IFACE_COMB_LIMITS = C.NL80211_IFACE_COMB_LIMITS
NL80211_IFACE_COMB_MAXNUM = C.NL80211_IFACE_COMB_MAXNUM
NL80211_IFACE_COMB_NUM_CHANNELS = C.NL80211_IFACE_COMB_NUM_CHANNELS
NL80211_IFACE_COMB_RADAR_DETECT_REGIONS = C.NL80211_IFACE_COMB_RADAR_DETECT_REGIONS
NL80211_IFACE_COMB_RADAR_DETECT_WIDTHS = C.NL80211_IFACE_COMB_RADAR_DETECT_WIDTHS
NL80211_IFACE_COMB_STA_AP_BI_MATCH = C.NL80211_IFACE_COMB_STA_AP_BI_MATCH
NL80211_IFACE_COMB_UNSPEC = C.NL80211_IFACE_COMB_UNSPEC
NL80211_IFACE_LIMIT_MAX = C.NL80211_IFACE_LIMIT_MAX
NL80211_IFACE_LIMIT_TYPES = C.NL80211_IFACE_LIMIT_TYPES
NL80211_IFACE_LIMIT_UNSPEC = C.NL80211_IFACE_LIMIT_UNSPEC
NL80211_IFTYPE_ADHOC = C.NL80211_IFTYPE_ADHOC
NL80211_IFTYPE_AKM_ATTR_IFTYPES = C.NL80211_IFTYPE_AKM_ATTR_IFTYPES
NL80211_IFTYPE_AKM_ATTR_MAX = C.NL80211_IFTYPE_AKM_ATTR_MAX
NL80211_IFTYPE_AKM_ATTR_SUITES = C.NL80211_IFTYPE_AKM_ATTR_SUITES
NL80211_IFTYPE_AP = C.NL80211_IFTYPE_AP
NL80211_IFTYPE_AP_VLAN = C.NL80211_IFTYPE_AP_VLAN
NL80211_IFTYPE_MAX = C.NL80211_IFTYPE_MAX
NL80211_IFTYPE_MESH_POINT = C.NL80211_IFTYPE_MESH_POINT
NL80211_IFTYPE_MONITOR = C.NL80211_IFTYPE_MONITOR
NL80211_IFTYPE_NAN = C.NL80211_IFTYPE_NAN
NL80211_IFTYPE_OCB = C.NL80211_IFTYPE_OCB
NL80211_IFTYPE_P2P_CLIENT = C.NL80211_IFTYPE_P2P_CLIENT
NL80211_IFTYPE_P2P_DEVICE = C.NL80211_IFTYPE_P2P_DEVICE
NL80211_IFTYPE_P2P_GO = C.NL80211_IFTYPE_P2P_GO
NL80211_IFTYPE_STATION = C.NL80211_IFTYPE_STATION
NL80211_IFTYPE_UNSPECIFIED = C.NL80211_IFTYPE_UNSPECIFIED
NL80211_IFTYPE_WDS = C.NL80211_IFTYPE_WDS
NL80211_KCK_EXT_LEN = C.NL80211_KCK_EXT_LEN
NL80211_KCK_LEN = C.NL80211_KCK_LEN
NL80211_KEK_EXT_LEN = C.NL80211_KEK_EXT_LEN
NL80211_KEK_LEN = C.NL80211_KEK_LEN
NL80211_KEY_CIPHER = C.NL80211_KEY_CIPHER
NL80211_KEY_DATA = C.NL80211_KEY_DATA
NL80211_KEY_DEFAULT_BEACON = C.NL80211_KEY_DEFAULT_BEACON
NL80211_KEY_DEFAULT = C.NL80211_KEY_DEFAULT
NL80211_KEY_DEFAULT_MGMT = C.NL80211_KEY_DEFAULT_MGMT
NL80211_KEY_DEFAULT_TYPE_MULTICAST = C.NL80211_KEY_DEFAULT_TYPE_MULTICAST
NL80211_KEY_DEFAULT_TYPES = C.NL80211_KEY_DEFAULT_TYPES
NL80211_KEY_DEFAULT_TYPE_UNICAST = C.NL80211_KEY_DEFAULT_TYPE_UNICAST
NL80211_KEY_IDX = C.NL80211_KEY_IDX
NL80211_KEY_MAX = C.NL80211_KEY_MAX
NL80211_KEY_MODE = C.NL80211_KEY_MODE
NL80211_KEY_NO_TX = C.NL80211_KEY_NO_TX
NL80211_KEY_RX_TX = C.NL80211_KEY_RX_TX
NL80211_KEY_SEQ = C.NL80211_KEY_SEQ
NL80211_KEY_SET_TX = C.NL80211_KEY_SET_TX
NL80211_KEY_TYPE = C.NL80211_KEY_TYPE
NL80211_KEYTYPE_GROUP = C.NL80211_KEYTYPE_GROUP
NL80211_KEYTYPE_PAIRWISE = C.NL80211_KEYTYPE_PAIRWISE
NL80211_KEYTYPE_PEERKEY = C.NL80211_KEYTYPE_PEERKEY
NL80211_MAX_NR_AKM_SUITES = C.NL80211_MAX_NR_AKM_SUITES
NL80211_MAX_NR_CIPHER_SUITES = C.NL80211_MAX_NR_CIPHER_SUITES
NL80211_MAX_SUPP_HT_RATES = C.NL80211_MAX_SUPP_HT_RATES
NL80211_MAX_SUPP_RATES = C.NL80211_MAX_SUPP_RATES
NL80211_MAX_SUPP_REG_RULES = C.NL80211_MAX_SUPP_REG_RULES
NL80211_MBSSID_CONFIG_ATTR_EMA = C.NL80211_MBSSID_CONFIG_ATTR_EMA
NL80211_MBSSID_CONFIG_ATTR_INDEX = C.NL80211_MBSSID_CONFIG_ATTR_INDEX
NL80211_MBSSID_CONFIG_ATTR_MAX = C.NL80211_MBSSID_CONFIG_ATTR_MAX
NL80211_MBSSID_CONFIG_ATTR_MAX_EMA_PROFILE_PERIODICITY = C.NL80211_MBSSID_CONFIG_ATTR_MAX_EMA_PROFILE_PERIODICITY
NL80211_MBSSID_CONFIG_ATTR_MAX_INTERFACES = C.NL80211_MBSSID_CONFIG_ATTR_MAX_INTERFACES
NL80211_MBSSID_CONFIG_ATTR_TX_IFINDEX = C.NL80211_MBSSID_CONFIG_ATTR_TX_IFINDEX
NL80211_MESHCONF_ATTR_MAX = C.NL80211_MESHCONF_ATTR_MAX
NL80211_MESHCONF_AUTO_OPEN_PLINKS = C.NL80211_MESHCONF_AUTO_OPEN_PLINKS
NL80211_MESHCONF_AWAKE_WINDOW = C.NL80211_MESHCONF_AWAKE_WINDOW
NL80211_MESHCONF_CONFIRM_TIMEOUT = C.NL80211_MESHCONF_CONFIRM_TIMEOUT
NL80211_MESHCONF_CONNECTED_TO_AS = C.NL80211_MESHCONF_CONNECTED_TO_AS
NL80211_MESHCONF_CONNECTED_TO_GATE = C.NL80211_MESHCONF_CONNECTED_TO_GATE
NL80211_MESHCONF_ELEMENT_TTL = C.NL80211_MESHCONF_ELEMENT_TTL
NL80211_MESHCONF_FORWARDING = C.NL80211_MESHCONF_FORWARDING
NL80211_MESHCONF_GATE_ANNOUNCEMENTS = C.NL80211_MESHCONF_GATE_ANNOUNCEMENTS
NL80211_MESHCONF_HOLDING_TIMEOUT = C.NL80211_MESHCONF_HOLDING_TIMEOUT
NL80211_MESHCONF_HT_OPMODE = C.NL80211_MESHCONF_HT_OPMODE
NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT = C.NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT
NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL = C.NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL
NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES = C.NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES
NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME = C.NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME
NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT = C.NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT
NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL = C.NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL
NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL = C.NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL
NL80211_MESHCONF_HWMP_RANN_INTERVAL = C.NL80211_MESHCONF_HWMP_RANN_INTERVAL
NL80211_MESHCONF_HWMP_ROOT_INTERVAL = C.NL80211_MESHCONF_HWMP_ROOT_INTERVAL
NL80211_MESHCONF_HWMP_ROOTMODE = C.NL80211_MESHCONF_HWMP_ROOTMODE
NL80211_MESHCONF_MAX_PEER_LINKS = C.NL80211_MESHCONF_MAX_PEER_LINKS
NL80211_MESHCONF_MAX_RETRIES = C.NL80211_MESHCONF_MAX_RETRIES
NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT = C.NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT
NL80211_MESHCONF_NOLEARN = C.NL80211_MESHCONF_NOLEARN
NL80211_MESHCONF_PATH_REFRESH_TIME = C.NL80211_MESHCONF_PATH_REFRESH_TIME
NL80211_MESHCONF_PLINK_TIMEOUT = C.NL80211_MESHCONF_PLINK_TIMEOUT
NL80211_MESHCONF_POWER_MODE = C.NL80211_MESHCONF_POWER_MODE
NL80211_MESHCONF_RETRY_TIMEOUT = C.NL80211_MESHCONF_RETRY_TIMEOUT
NL80211_MESHCONF_RSSI_THRESHOLD = C.NL80211_MESHCONF_RSSI_THRESHOLD
NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR = C.NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR
NL80211_MESHCONF_TTL = C.NL80211_MESHCONF_TTL
NL80211_MESH_POWER_ACTIVE = C.NL80211_MESH_POWER_ACTIVE
NL80211_MESH_POWER_DEEP_SLEEP = C.NL80211_MESH_POWER_DEEP_SLEEP
NL80211_MESH_POWER_LIGHT_SLEEP = C.NL80211_MESH_POWER_LIGHT_SLEEP
NL80211_MESH_POWER_MAX = C.NL80211_MESH_POWER_MAX
NL80211_MESH_POWER_UNKNOWN = C.NL80211_MESH_POWER_UNKNOWN
NL80211_MESH_SETUP_ATTR_MAX = C.NL80211_MESH_SETUP_ATTR_MAX
NL80211_MESH_SETUP_AUTH_PROTOCOL = C.NL80211_MESH_SETUP_AUTH_PROTOCOL
NL80211_MESH_SETUP_ENABLE_VENDOR_METRIC = C.NL80211_MESH_SETUP_ENABLE_VENDOR_METRIC
NL80211_MESH_SETUP_ENABLE_VENDOR_PATH_SEL = C.NL80211_MESH_SETUP_ENABLE_VENDOR_PATH_SEL
NL80211_MESH_SETUP_ENABLE_VENDOR_SYNC = C.NL80211_MESH_SETUP_ENABLE_VENDOR_SYNC
NL80211_MESH_SETUP_IE = C.NL80211_MESH_SETUP_IE
NL80211_MESH_SETUP_USERSPACE_AMPE = C.NL80211_MESH_SETUP_USERSPACE_AMPE
NL80211_MESH_SETUP_USERSPACE_AUTH = C.NL80211_MESH_SETUP_USERSPACE_AUTH
NL80211_MESH_SETUP_USERSPACE_MPM = C.NL80211_MESH_SETUP_USERSPACE_MPM
NL80211_MESH_SETUP_VENDOR_PATH_SEL_IE = C.NL80211_MESH_SETUP_VENDOR_PATH_SEL_IE
NL80211_MFP_NO = C.NL80211_MFP_NO
NL80211_MFP_OPTIONAL = C.NL80211_MFP_OPTIONAL
NL80211_MFP_REQUIRED = C.NL80211_MFP_REQUIRED
NL80211_MIN_REMAIN_ON_CHANNEL_TIME = C.NL80211_MIN_REMAIN_ON_CHANNEL_TIME
NL80211_MNTR_FLAG_ACTIVE = C.NL80211_MNTR_FLAG_ACTIVE
NL80211_MNTR_FLAG_CONTROL = C.NL80211_MNTR_FLAG_CONTROL
NL80211_MNTR_FLAG_COOK_FRAMES = C.NL80211_MNTR_FLAG_COOK_FRAMES
NL80211_MNTR_FLAG_FCSFAIL = C.NL80211_MNTR_FLAG_FCSFAIL
NL80211_MNTR_FLAG_MAX = C.NL80211_MNTR_FLAG_MAX
NL80211_MNTR_FLAG_OTHER_BSS = C.NL80211_MNTR_FLAG_OTHER_BSS
NL80211_MNTR_FLAG_PLCPFAIL = C.NL80211_MNTR_FLAG_PLCPFAIL
NL80211_MPATH_FLAG_ACTIVE = C.NL80211_MPATH_FLAG_ACTIVE
NL80211_MPATH_FLAG_FIXED = C.NL80211_MPATH_FLAG_FIXED
NL80211_MPATH_FLAG_RESOLVED = C.NL80211_MPATH_FLAG_RESOLVED
NL80211_MPATH_FLAG_RESOLVING = C.NL80211_MPATH_FLAG_RESOLVING
NL80211_MPATH_FLAG_SN_VALID = C.NL80211_MPATH_FLAG_SN_VALID
NL80211_MPATH_INFO_DISCOVERY_RETRIES = C.NL80211_MPATH_INFO_DISCOVERY_RETRIES
NL80211_MPATH_INFO_DISCOVERY_TIMEOUT = C.NL80211_MPATH_INFO_DISCOVERY_TIMEOUT
NL80211_MPATH_INFO_EXPTIME = C.NL80211_MPATH_INFO_EXPTIME
NL80211_MPATH_INFO_FLAGS = C.NL80211_MPATH_INFO_FLAGS
NL80211_MPATH_INFO_FRAME_QLEN = C.NL80211_MPATH_INFO_FRAME_QLEN
NL80211_MPATH_INFO_HOP_COUNT = C.NL80211_MPATH_INFO_HOP_COUNT
NL80211_MPATH_INFO_MAX = C.NL80211_MPATH_INFO_MAX
NL80211_MPATH_INFO_METRIC = C.NL80211_MPATH_INFO_METRIC
NL80211_MPATH_INFO_PATH_CHANGE = C.NL80211_MPATH_INFO_PATH_CHANGE
NL80211_MPATH_INFO_SN = C.NL80211_MPATH_INFO_SN
NL80211_MULTICAST_GROUP_CONFIG = C.NL80211_MULTICAST_GROUP_CONFIG
NL80211_MULTICAST_GROUP_MLME = C.NL80211_MULTICAST_GROUP_MLME
NL80211_MULTICAST_GROUP_NAN = C.NL80211_MULTICAST_GROUP_NAN
NL80211_MULTICAST_GROUP_REG = C.NL80211_MULTICAST_GROUP_REG
NL80211_MULTICAST_GROUP_SCAN = C.NL80211_MULTICAST_GROUP_SCAN
NL80211_MULTICAST_GROUP_TESTMODE = C.NL80211_MULTICAST_GROUP_TESTMODE
NL80211_MULTICAST_GROUP_VENDOR = C.NL80211_MULTICAST_GROUP_VENDOR
NL80211_NAN_FUNC_ATTR_MAX = C.NL80211_NAN_FUNC_ATTR_MAX
NL80211_NAN_FUNC_CLOSE_RANGE = C.NL80211_NAN_FUNC_CLOSE_RANGE
NL80211_NAN_FUNC_FOLLOW_UP = C.NL80211_NAN_FUNC_FOLLOW_UP
NL80211_NAN_FUNC_FOLLOW_UP_DEST = C.NL80211_NAN_FUNC_FOLLOW_UP_DEST
NL80211_NAN_FUNC_FOLLOW_UP_ID = C.NL80211_NAN_FUNC_FOLLOW_UP_ID
NL80211_NAN_FUNC_FOLLOW_UP_REQ_ID = C.NL80211_NAN_FUNC_FOLLOW_UP_REQ_ID
NL80211_NAN_FUNC_INSTANCE_ID = C.NL80211_NAN_FUNC_INSTANCE_ID
NL80211_NAN_FUNC_MAX_TYPE = C.NL80211_NAN_FUNC_MAX_TYPE
NL80211_NAN_FUNC_PUBLISH_BCAST = C.NL80211_NAN_FUNC_PUBLISH_BCAST
NL80211_NAN_FUNC_PUBLISH = C.NL80211_NAN_FUNC_PUBLISH
NL80211_NAN_FUNC_PUBLISH_TYPE = C.NL80211_NAN_FUNC_PUBLISH_TYPE
NL80211_NAN_FUNC_RX_MATCH_FILTER = C.NL80211_NAN_FUNC_RX_MATCH_FILTER
NL80211_NAN_FUNC_SERVICE_ID = C.NL80211_NAN_FUNC_SERVICE_ID
NL80211_NAN_FUNC_SERVICE_ID_LEN = C.NL80211_NAN_FUNC_SERVICE_ID_LEN
NL80211_NAN_FUNC_SERVICE_INFO = C.NL80211_NAN_FUNC_SERVICE_INFO
NL80211_NAN_FUNC_SERVICE_SPEC_INFO_MAX_LEN = C.NL80211_NAN_FUNC_SERVICE_SPEC_INFO_MAX_LEN
NL80211_NAN_FUNC_SRF = C.NL80211_NAN_FUNC_SRF
NL80211_NAN_FUNC_SRF_MAX_LEN = C.NL80211_NAN_FUNC_SRF_MAX_LEN
NL80211_NAN_FUNC_SUBSCRIBE_ACTIVE = C.NL80211_NAN_FUNC_SUBSCRIBE_ACTIVE
NL80211_NAN_FUNC_SUBSCRIBE = C.NL80211_NAN_FUNC_SUBSCRIBE
NL80211_NAN_FUNC_TERM_REASON = C.NL80211_NAN_FUNC_TERM_REASON
NL80211_NAN_FUNC_TERM_REASON_ERROR = C.NL80211_NAN_FUNC_TERM_REASON_ERROR
NL80211_NAN_FUNC_TERM_REASON_TTL_EXPIRED = C.NL80211_NAN_FUNC_TERM_REASON_TTL_EXPIRED
NL80211_NAN_FUNC_TERM_REASON_USER_REQUEST = C.NL80211_NAN_FUNC_TERM_REASON_USER_REQUEST
NL80211_NAN_FUNC_TTL = C.NL80211_NAN_FUNC_TTL
NL80211_NAN_FUNC_TX_MATCH_FILTER = C.NL80211_NAN_FUNC_TX_MATCH_FILTER
NL80211_NAN_FUNC_TYPE = C.NL80211_NAN_FUNC_TYPE
NL80211_NAN_MATCH_ATTR_MAX = C.NL80211_NAN_MATCH_ATTR_MAX
NL80211_NAN_MATCH_FUNC_LOCAL = C.NL80211_NAN_MATCH_FUNC_LOCAL
NL80211_NAN_MATCH_FUNC_PEER = C.NL80211_NAN_MATCH_FUNC_PEER
NL80211_NAN_SOLICITED_PUBLISH = C.NL80211_NAN_SOLICITED_PUBLISH
NL80211_NAN_SRF_ATTR_MAX = C.NL80211_NAN_SRF_ATTR_MAX
NL80211_NAN_SRF_BF = C.NL80211_NAN_SRF_BF
NL80211_NAN_SRF_BF_IDX = C.NL80211_NAN_SRF_BF_IDX
NL80211_NAN_SRF_INCLUDE = C.NL80211_NAN_SRF_INCLUDE
NL80211_NAN_SRF_MAC_ADDRS = C.NL80211_NAN_SRF_MAC_ADDRS
NL80211_NAN_UNSOLICITED_PUBLISH = C.NL80211_NAN_UNSOLICITED_PUBLISH
NL80211_NUM_ACS = C.NL80211_NUM_ACS
NL80211_P2P_PS_SUPPORTED = C.NL80211_P2P_PS_SUPPORTED
NL80211_P2P_PS_UNSUPPORTED = C.NL80211_P2P_PS_UNSUPPORTED
NL80211_PKTPAT_MASK = C.NL80211_PKTPAT_MASK
NL80211_PKTPAT_OFFSET = C.NL80211_PKTPAT_OFFSET
NL80211_PKTPAT_PATTERN = C.NL80211_PKTPAT_PATTERN
NL80211_PLINK_ACTION_BLOCK = C.NL80211_PLINK_ACTION_BLOCK
NL80211_PLINK_ACTION_NO_ACTION = C.NL80211_PLINK_ACTION_NO_ACTION
NL80211_PLINK_ACTION_OPEN = C.NL80211_PLINK_ACTION_OPEN
NL80211_PLINK_BLOCKED = C.NL80211_PLINK_BLOCKED
NL80211_PLINK_CNF_RCVD = C.NL80211_PLINK_CNF_RCVD
NL80211_PLINK_ESTAB = C.NL80211_PLINK_ESTAB
NL80211_PLINK_HOLDING = C.NL80211_PLINK_HOLDING
NL80211_PLINK_LISTEN = C.NL80211_PLINK_LISTEN
NL80211_PLINK_OPN_RCVD = C.NL80211_PLINK_OPN_RCVD
NL80211_PLINK_OPN_SNT = C.NL80211_PLINK_OPN_SNT
NL80211_PMKSA_CANDIDATE_BSSID = C.NL80211_PMKSA_CANDIDATE_BSSID
NL80211_PMKSA_CANDIDATE_INDEX = C.NL80211_PMKSA_CANDIDATE_INDEX
NL80211_PMKSA_CANDIDATE_PREAUTH = C.NL80211_PMKSA_CANDIDATE_PREAUTH
NL80211_PMSR_ATTR_MAX = C.NL80211_PMSR_ATTR_MAX
NL80211_PMSR_ATTR_MAX_PEERS = C.NL80211_PMSR_ATTR_MAX_PEERS
NL80211_PMSR_ATTR_PEERS = C.NL80211_PMSR_ATTR_PEERS
NL80211_PMSR_ATTR_RANDOMIZE_MAC_ADDR = C.NL80211_PMSR_ATTR_RANDOMIZE_MAC_ADDR
NL80211_PMSR_ATTR_REPORT_AP_TSF = C.NL80211_PMSR_ATTR_REPORT_AP_TSF
NL80211_PMSR_ATTR_TYPE_CAPA = C.NL80211_PMSR_ATTR_TYPE_CAPA
NL80211_PMSR_FTM_CAPA_ATTR_ASAP = C.NL80211_PMSR_FTM_CAPA_ATTR_ASAP
NL80211_PMSR_FTM_CAPA_ATTR_BANDWIDTHS = C.NL80211_PMSR_FTM_CAPA_ATTR_BANDWIDTHS
NL80211_PMSR_FTM_CAPA_ATTR_MAX_BURSTS_EXPONENT = C.NL80211_PMSR_FTM_CAPA_ATTR_MAX_BURSTS_EXPONENT
NL80211_PMSR_FTM_CAPA_ATTR_MAX = C.NL80211_PMSR_FTM_CAPA_ATTR_MAX
NL80211_PMSR_FTM_CAPA_ATTR_MAX_FTMS_PER_BURST = C.NL80211_PMSR_FTM_CAPA_ATTR_MAX_FTMS_PER_BURST
NL80211_PMSR_FTM_CAPA_ATTR_NON_ASAP = C.NL80211_PMSR_FTM_CAPA_ATTR_NON_ASAP
NL80211_PMSR_FTM_CAPA_ATTR_NON_TRIGGER_BASED = C.NL80211_PMSR_FTM_CAPA_ATTR_NON_TRIGGER_BASED
NL80211_PMSR_FTM_CAPA_ATTR_PREAMBLES = C.NL80211_PMSR_FTM_CAPA_ATTR_PREAMBLES
NL80211_PMSR_FTM_CAPA_ATTR_REQ_CIVICLOC = C.NL80211_PMSR_FTM_CAPA_ATTR_REQ_CIVICLOC
NL80211_PMSR_FTM_CAPA_ATTR_REQ_LCI = C.NL80211_PMSR_FTM_CAPA_ATTR_REQ_LCI
NL80211_PMSR_FTM_CAPA_ATTR_TRIGGER_BASED = C.NL80211_PMSR_FTM_CAPA_ATTR_TRIGGER_BASED
NL80211_PMSR_FTM_FAILURE_BAD_CHANGED_PARAMS = C.NL80211_PMSR_FTM_FAILURE_BAD_CHANGED_PARAMS
NL80211_PMSR_FTM_FAILURE_INVALID_TIMESTAMP = C.NL80211_PMSR_FTM_FAILURE_INVALID_TIMESTAMP
NL80211_PMSR_FTM_FAILURE_NO_RESPONSE = C.NL80211_PMSR_FTM_FAILURE_NO_RESPONSE
NL80211_PMSR_FTM_FAILURE_PEER_BUSY = C.NL80211_PMSR_FTM_FAILURE_PEER_BUSY
NL80211_PMSR_FTM_FAILURE_PEER_NOT_CAPABLE = C.NL80211_PMSR_FTM_FAILURE_PEER_NOT_CAPABLE
NL80211_PMSR_FTM_FAILURE_REJECTED = C.NL80211_PMSR_FTM_FAILURE_REJECTED
NL80211_PMSR_FTM_FAILURE_UNSPECIFIED = C.NL80211_PMSR_FTM_FAILURE_UNSPECIFIED
NL80211_PMSR_FTM_FAILURE_WRONG_CHANNEL = C.NL80211_PMSR_FTM_FAILURE_WRONG_CHANNEL
NL80211_PMSR_FTM_REQ_ATTR_ASAP = C.NL80211_PMSR_FTM_REQ_ATTR_ASAP
NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR = C.NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR
NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION = C.NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION
NL80211_PMSR_FTM_REQ_ATTR_BURST_PERIOD = C.NL80211_PMSR_FTM_REQ_ATTR_BURST_PERIOD
NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST = C.NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST
NL80211_PMSR_FTM_REQ_ATTR_LMR_FEEDBACK = C.NL80211_PMSR_FTM_REQ_ATTR_LMR_FEEDBACK
NL80211_PMSR_FTM_REQ_ATTR_MAX = C.NL80211_PMSR_FTM_REQ_ATTR_MAX
NL80211_PMSR_FTM_REQ_ATTR_NON_TRIGGER_BASED = C.NL80211_PMSR_FTM_REQ_ATTR_NON_TRIGGER_BASED
NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP = C.NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP
NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES = C.NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES
NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE = C.NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE
NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC = C.NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC
NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI = C.NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI
NL80211_PMSR_FTM_REQ_ATTR_TRIGGER_BASED = C.NL80211_PMSR_FTM_REQ_ATTR_TRIGGER_BASED
NL80211_PMSR_FTM_RESP_ATTR_BURST_DURATION = C.NL80211_PMSR_FTM_RESP_ATTR_BURST_DURATION
NL80211_PMSR_FTM_RESP_ATTR_BURST_INDEX = C.NL80211_PMSR_FTM_RESP_ATTR_BURST_INDEX
NL80211_PMSR_FTM_RESP_ATTR_BUSY_RETRY_TIME = C.NL80211_PMSR_FTM_RESP_ATTR_BUSY_RETRY_TIME
NL80211_PMSR_FTM_RESP_ATTR_CIVICLOC = C.NL80211_PMSR_FTM_RESP_ATTR_CIVICLOC
NL80211_PMSR_FTM_RESP_ATTR_DIST_AVG = C.NL80211_PMSR_FTM_RESP_ATTR_DIST_AVG
NL80211_PMSR_FTM_RESP_ATTR_DIST_SPREAD = C.NL80211_PMSR_FTM_RESP_ATTR_DIST_SPREAD
NL80211_PMSR_FTM_RESP_ATTR_DIST_VARIANCE = C.NL80211_PMSR_FTM_RESP_ATTR_DIST_VARIANCE
NL80211_PMSR_FTM_RESP_ATTR_FAIL_REASON = C.NL80211_PMSR_FTM_RESP_ATTR_FAIL_REASON
NL80211_PMSR_FTM_RESP_ATTR_FTMS_PER_BURST = C.NL80211_PMSR_FTM_RESP_ATTR_FTMS_PER_BURST
NL80211_PMSR_FTM_RESP_ATTR_LCI = C.NL80211_PMSR_FTM_RESP_ATTR_LCI
NL80211_PMSR_FTM_RESP_ATTR_MAX = C.NL80211_PMSR_FTM_RESP_ATTR_MAX
NL80211_PMSR_FTM_RESP_ATTR_NUM_BURSTS_EXP = C.NL80211_PMSR_FTM_RESP_ATTR_NUM_BURSTS_EXP
NL80211_PMSR_FTM_RESP_ATTR_NUM_FTMR_ATTEMPTS = C.NL80211_PMSR_FTM_RESP_ATTR_NUM_FTMR_ATTEMPTS
NL80211_PMSR_FTM_RESP_ATTR_NUM_FTMR_SUCCESSES = C.NL80211_PMSR_FTM_RESP_ATTR_NUM_FTMR_SUCCESSES
NL80211_PMSR_FTM_RESP_ATTR_PAD = C.NL80211_PMSR_FTM_RESP_ATTR_PAD
NL80211_PMSR_FTM_RESP_ATTR_RSSI_AVG = C.NL80211_PMSR_FTM_RESP_ATTR_RSSI_AVG
NL80211_PMSR_FTM_RESP_ATTR_RSSI_SPREAD = C.NL80211_PMSR_FTM_RESP_ATTR_RSSI_SPREAD
NL80211_PMSR_FTM_RESP_ATTR_RTT_AVG = C.NL80211_PMSR_FTM_RESP_ATTR_RTT_AVG
NL80211_PMSR_FTM_RESP_ATTR_RTT_SPREAD = C.NL80211_PMSR_FTM_RESP_ATTR_RTT_SPREAD
NL80211_PMSR_FTM_RESP_ATTR_RTT_VARIANCE = C.NL80211_PMSR_FTM_RESP_ATTR_RTT_VARIANCE
NL80211_PMSR_FTM_RESP_ATTR_RX_RATE = C.NL80211_PMSR_FTM_RESP_ATTR_RX_RATE
NL80211_PMSR_FTM_RESP_ATTR_TX_RATE = C.NL80211_PMSR_FTM_RESP_ATTR_TX_RATE
NL80211_PMSR_PEER_ATTR_ADDR = C.NL80211_PMSR_PEER_ATTR_ADDR
NL80211_PMSR_PEER_ATTR_CHAN = C.NL80211_PMSR_PEER_ATTR_CHAN
NL80211_PMSR_PEER_ATTR_MAX = C.NL80211_PMSR_PEER_ATTR_MAX
NL80211_PMSR_PEER_ATTR_REQ = C.NL80211_PMSR_PEER_ATTR_REQ
NL80211_PMSR_PEER_ATTR_RESP = C.NL80211_PMSR_PEER_ATTR_RESP
NL80211_PMSR_REQ_ATTR_DATA = C.NL80211_PMSR_REQ_ATTR_DATA
NL80211_PMSR_REQ_ATTR_GET_AP_TSF = C.NL80211_PMSR_REQ_ATTR_GET_AP_TSF
NL80211_PMSR_REQ_ATTR_MAX = C.NL80211_PMSR_REQ_ATTR_MAX
NL80211_PMSR_RESP_ATTR_AP_TSF = C.NL80211_PMSR_RESP_ATTR_AP_TSF
NL80211_PMSR_RESP_ATTR_DATA = C.NL80211_PMSR_RESP_ATTR_DATA
NL80211_PMSR_RESP_ATTR_FINAL = C.NL80211_PMSR_RESP_ATTR_FINAL
NL80211_PMSR_RESP_ATTR_HOST_TIME = C.NL80211_PMSR_RESP_ATTR_HOST_TIME
NL80211_PMSR_RESP_ATTR_MAX = C.NL80211_PMSR_RESP_ATTR_MAX
NL80211_PMSR_RESP_ATTR_PAD = C.NL80211_PMSR_RESP_ATTR_PAD
NL80211_PMSR_RESP_ATTR_STATUS = C.NL80211_PMSR_RESP_ATTR_STATUS
NL80211_PMSR_STATUS_FAILURE = C.NL80211_PMSR_STATUS_FAILURE
NL80211_PMSR_STATUS_REFUSED = C.NL80211_PMSR_STATUS_REFUSED
NL80211_PMSR_STATUS_SUCCESS = C.NL80211_PMSR_STATUS_SUCCESS
NL80211_PMSR_STATUS_TIMEOUT = C.NL80211_PMSR_STATUS_TIMEOUT
NL80211_PMSR_TYPE_FTM = C.NL80211_PMSR_TYPE_FTM
NL80211_PMSR_TYPE_INVALID = C.NL80211_PMSR_TYPE_INVALID
NL80211_PMSR_TYPE_MAX = C.NL80211_PMSR_TYPE_MAX
NL80211_PREAMBLE_DMG = C.NL80211_PREAMBLE_DMG
NL80211_PREAMBLE_HE = C.NL80211_PREAMBLE_HE
NL80211_PREAMBLE_HT = C.NL80211_PREAMBLE_HT
NL80211_PREAMBLE_LEGACY = C.NL80211_PREAMBLE_LEGACY
NL80211_PREAMBLE_VHT = C.NL80211_PREAMBLE_VHT
NL80211_PROBE_RESP_OFFLOAD_SUPPORT_80211U = C.NL80211_PROBE_RESP_OFFLOAD_SUPPORT_80211U
NL80211_PROBE_RESP_OFFLOAD_SUPPORT_P2P = C.NL80211_PROBE_RESP_OFFLOAD_SUPPORT_P2P
NL80211_PROBE_RESP_OFFLOAD_SUPPORT_WPS2 = C.NL80211_PROBE_RESP_OFFLOAD_SUPPORT_WPS2
NL80211_PROBE_RESP_OFFLOAD_SUPPORT_WPS = C.NL80211_PROBE_RESP_OFFLOAD_SUPPORT_WPS
NL80211_PROTOCOL_FEATURE_SPLIT_WIPHY_DUMP = C.NL80211_PROTOCOL_FEATURE_SPLIT_WIPHY_DUMP
NL80211_PS_DISABLED = C.NL80211_PS_DISABLED
NL80211_PS_ENABLED = C.NL80211_PS_ENABLED
NL80211_RADAR_CAC_ABORTED = C.NL80211_RADAR_CAC_ABORTED
NL80211_RADAR_CAC_FINISHED = C.NL80211_RADAR_CAC_FINISHED
NL80211_RADAR_CAC_STARTED = C.NL80211_RADAR_CAC_STARTED
NL80211_RADAR_DETECTED = C.NL80211_RADAR_DETECTED
NL80211_RADAR_NOP_FINISHED = C.NL80211_RADAR_NOP_FINISHED
NL80211_RADAR_PRE_CAC_EXPIRED = C.NL80211_RADAR_PRE_CAC_EXPIRED
NL80211_RATE_INFO_10_MHZ_WIDTH = C.NL80211_RATE_INFO_10_MHZ_WIDTH
NL80211_RATE_INFO_160_MHZ_WIDTH = C.NL80211_RATE_INFO_160_MHZ_WIDTH
NL80211_RATE_INFO_320_MHZ_WIDTH = C.NL80211_RATE_INFO_320_MHZ_WIDTH
NL80211_RATE_INFO_40_MHZ_WIDTH = C.NL80211_RATE_INFO_40_MHZ_WIDTH
NL80211_RATE_INFO_5_MHZ_WIDTH = C.NL80211_RATE_INFO_5_MHZ_WIDTH
NL80211_RATE_INFO_80_MHZ_WIDTH = C.NL80211_RATE_INFO_80_MHZ_WIDTH
NL80211_RATE_INFO_80P80_MHZ_WIDTH = C.NL80211_RATE_INFO_80P80_MHZ_WIDTH
NL80211_RATE_INFO_BITRATE32 = C.NL80211_RATE_INFO_BITRATE32
NL80211_RATE_INFO_BITRATE = C.NL80211_RATE_INFO_BITRATE
NL80211_RATE_INFO_EHT_GI_0_8 = C.NL80211_RATE_INFO_EHT_GI_0_8
NL80211_RATE_INFO_EHT_GI_1_6 = C.NL80211_RATE_INFO_EHT_GI_1_6
NL80211_RATE_INFO_EHT_GI_3_2 = C.NL80211_RATE_INFO_EHT_GI_3_2
NL80211_RATE_INFO_EHT_GI = C.NL80211_RATE_INFO_EHT_GI
NL80211_RATE_INFO_EHT_MCS = C.NL80211_RATE_INFO_EHT_MCS
NL80211_RATE_INFO_EHT_NSS = C.NL80211_RATE_INFO_EHT_NSS
NL80211_RATE_INFO_EHT_RU_ALLOC_106 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_106
NL80211_RATE_INFO_EHT_RU_ALLOC_106P26 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_106P26
NL80211_RATE_INFO_EHT_RU_ALLOC_242 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_242
NL80211_RATE_INFO_EHT_RU_ALLOC_26 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_26
NL80211_RATE_INFO_EHT_RU_ALLOC_2x996 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_2x996
NL80211_RATE_INFO_EHT_RU_ALLOC_2x996P484 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_2x996P484
NL80211_RATE_INFO_EHT_RU_ALLOC_3x996 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_3x996
NL80211_RATE_INFO_EHT_RU_ALLOC_3x996P484 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_3x996P484
NL80211_RATE_INFO_EHT_RU_ALLOC_484 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_484
NL80211_RATE_INFO_EHT_RU_ALLOC_484P242 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_484P242
NL80211_RATE_INFO_EHT_RU_ALLOC_4x996 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_4x996
NL80211_RATE_INFO_EHT_RU_ALLOC_52 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_52
NL80211_RATE_INFO_EHT_RU_ALLOC_52P26 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_52P26
NL80211_RATE_INFO_EHT_RU_ALLOC_996 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_996
NL80211_RATE_INFO_EHT_RU_ALLOC_996P484 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_996P484
NL80211_RATE_INFO_EHT_RU_ALLOC_996P484P242 = C.NL80211_RATE_INFO_EHT_RU_ALLOC_996P484P242
NL80211_RATE_INFO_EHT_RU_ALLOC = C.NL80211_RATE_INFO_EHT_RU_ALLOC
NL80211_RATE_INFO_HE_1XLTF = C.NL80211_RATE_INFO_HE_1XLTF
NL80211_RATE_INFO_HE_2XLTF = C.NL80211_RATE_INFO_HE_2XLTF
NL80211_RATE_INFO_HE_4XLTF = C.NL80211_RATE_INFO_HE_4XLTF
NL80211_RATE_INFO_HE_DCM = C.NL80211_RATE_INFO_HE_DCM
NL80211_RATE_INFO_HE_GI_0_8 = C.NL80211_RATE_INFO_HE_GI_0_8
NL80211_RATE_INFO_HE_GI_1_6 = C.NL80211_RATE_INFO_HE_GI_1_6
NL80211_RATE_INFO_HE_GI_3_2 = C.NL80211_RATE_INFO_HE_GI_3_2
NL80211_RATE_INFO_HE_GI = C.NL80211_RATE_INFO_HE_GI
NL80211_RATE_INFO_HE_MCS = C.NL80211_RATE_INFO_HE_MCS
NL80211_RATE_INFO_HE_NSS = C.NL80211_RATE_INFO_HE_NSS
NL80211_RATE_INFO_HE_RU_ALLOC_106 = C.NL80211_RATE_INFO_HE_RU_ALLOC_106
NL80211_RATE_INFO_HE_RU_ALLOC_242 = C.NL80211_RATE_INFO_HE_RU_ALLOC_242
NL80211_RATE_INFO_HE_RU_ALLOC_26 = C.NL80211_RATE_INFO_HE_RU_ALLOC_26
NL80211_RATE_INFO_HE_RU_ALLOC_2x996 = C.NL80211_RATE_INFO_HE_RU_ALLOC_2x996
NL80211_RATE_INFO_HE_RU_ALLOC_484 = C.NL80211_RATE_INFO_HE_RU_ALLOC_484
NL80211_RATE_INFO_HE_RU_ALLOC_52 = C.NL80211_RATE_INFO_HE_RU_ALLOC_52
NL80211_RATE_INFO_HE_RU_ALLOC_996 = C.NL80211_RATE_INFO_HE_RU_ALLOC_996
NL80211_RATE_INFO_HE_RU_ALLOC = C.NL80211_RATE_INFO_HE_RU_ALLOC
NL80211_RATE_INFO_MAX = C.NL80211_RATE_INFO_MAX
NL80211_RATE_INFO_MCS = C.NL80211_RATE_INFO_MCS
NL80211_RATE_INFO_SHORT_GI = C.NL80211_RATE_INFO_SHORT_GI
NL80211_RATE_INFO_VHT_MCS = C.NL80211_RATE_INFO_VHT_MCS
NL80211_RATE_INFO_VHT_NSS = C.NL80211_RATE_INFO_VHT_NSS
NL80211_REGDOM_SET_BY_CORE = C.NL80211_REGDOM_SET_BY_CORE
NL80211_REGDOM_SET_BY_COUNTRY_IE = C.NL80211_REGDOM_SET_BY_COUNTRY_IE
NL80211_REGDOM_SET_BY_DRIVER = C.NL80211_REGDOM_SET_BY_DRIVER
NL80211_REGDOM_SET_BY_USER = C.NL80211_REGDOM_SET_BY_USER
NL80211_REGDOM_TYPE_COUNTRY = C.NL80211_REGDOM_TYPE_COUNTRY
NL80211_REGDOM_TYPE_CUSTOM_WORLD = C.NL80211_REGDOM_TYPE_CUSTOM_WORLD
NL80211_REGDOM_TYPE_INTERSECTION = C.NL80211_REGDOM_TYPE_INTERSECTION
NL80211_REGDOM_TYPE_WORLD = C.NL80211_REGDOM_TYPE_WORLD
NL80211_REG_RULE_ATTR_MAX = C.NL80211_REG_RULE_ATTR_MAX
NL80211_REKEY_DATA_AKM = C.NL80211_REKEY_DATA_AKM
NL80211_REKEY_DATA_KCK = C.NL80211_REKEY_DATA_KCK
NL80211_REKEY_DATA_KEK = C.NL80211_REKEY_DATA_KEK
NL80211_REKEY_DATA_REPLAY_CTR = C.NL80211_REKEY_DATA_REPLAY_CTR
NL80211_REPLAY_CTR_LEN = C.NL80211_REPLAY_CTR_LEN
NL80211_RRF_AUTO_BW = C.NL80211_RRF_AUTO_BW
NL80211_RRF_DFS = C.NL80211_RRF_DFS
NL80211_RRF_GO_CONCURRENT = C.NL80211_RRF_GO_CONCURRENT
NL80211_RRF_IR_CONCURRENT = C.NL80211_RRF_IR_CONCURRENT
NL80211_RRF_NO_160MHZ = C.NL80211_RRF_NO_160MHZ
NL80211_RRF_NO_320MHZ = C.NL80211_RRF_NO_320MHZ
NL80211_RRF_NO_80MHZ = C.NL80211_RRF_NO_80MHZ
NL80211_RRF_NO_CCK = C.NL80211_RRF_NO_CCK
NL80211_RRF_NO_HE = C.NL80211_RRF_NO_HE
NL80211_RRF_NO_HT40 = C.NL80211_RRF_NO_HT40
NL80211_RRF_NO_HT40MINUS = C.NL80211_RRF_NO_HT40MINUS
NL80211_RRF_NO_HT40PLUS = C.NL80211_RRF_NO_HT40PLUS
NL80211_RRF_NO_IBSS = C.NL80211_RRF_NO_IBSS
NL80211_RRF_NO_INDOOR = C.NL80211_RRF_NO_INDOOR
NL80211_RRF_NO_IR_ALL = C.NL80211_RRF_NO_IR_ALL
NL80211_RRF_NO_IR = C.NL80211_RRF_NO_IR
NL80211_RRF_NO_OFDM = C.NL80211_RRF_NO_OFDM
NL80211_RRF_NO_OUTDOOR = C.NL80211_RRF_NO_OUTDOOR
NL80211_RRF_PASSIVE_SCAN = C.NL80211_RRF_PASSIVE_SCAN
NL80211_RRF_PTMP_ONLY = C.NL80211_RRF_PTMP_ONLY
NL80211_RRF_PTP_ONLY = C.NL80211_RRF_PTP_ONLY
NL80211_RXMGMT_FLAG_ANSWERED = C.NL80211_RXMGMT_FLAG_ANSWERED
NL80211_RXMGMT_FLAG_EXTERNAL_AUTH = C.NL80211_RXMGMT_FLAG_EXTERNAL_AUTH
NL80211_SAE_PWE_BOTH = C.NL80211_SAE_PWE_BOTH
NL80211_SAE_PWE_HASH_TO_ELEMENT = C.NL80211_SAE_PWE_HASH_TO_ELEMENT
NL80211_SAE_PWE_HUNT_AND_PECK = C.NL80211_SAE_PWE_HUNT_AND_PECK
NL80211_SAE_PWE_UNSPECIFIED = C.NL80211_SAE_PWE_UNSPECIFIED
NL80211_SAR_ATTR_MAX = C.NL80211_SAR_ATTR_MAX
NL80211_SAR_ATTR_SPECS = C.NL80211_SAR_ATTR_SPECS
NL80211_SAR_ATTR_SPECS_END_FREQ = C.NL80211_SAR_ATTR_SPECS_END_FREQ
NL80211_SAR_ATTR_SPECS_MAX = C.NL80211_SAR_ATTR_SPECS_MAX
NL80211_SAR_ATTR_SPECS_POWER = C.NL80211_SAR_ATTR_SPECS_POWER
NL80211_SAR_ATTR_SPECS_RANGE_INDEX = C.NL80211_SAR_ATTR_SPECS_RANGE_INDEX
NL80211_SAR_ATTR_SPECS_START_FREQ = C.NL80211_SAR_ATTR_SPECS_START_FREQ
NL80211_SAR_ATTR_TYPE = C.NL80211_SAR_ATTR_TYPE
NL80211_SAR_TYPE_POWER = C.NL80211_SAR_TYPE_POWER
NL80211_SCAN_FLAG_ACCEPT_BCAST_PROBE_RESP = C.NL80211_SCAN_FLAG_ACCEPT_BCAST_PROBE_RESP
NL80211_SCAN_FLAG_AP = C.NL80211_SCAN_FLAG_AP
NL80211_SCAN_FLAG_COLOCATED_6GHZ = C.NL80211_SCAN_FLAG_COLOCATED_6GHZ
NL80211_SCAN_FLAG_FILS_MAX_CHANNEL_TIME = C.NL80211_SCAN_FLAG_FILS_MAX_CHANNEL_TIME
NL80211_SCAN_FLAG_FLUSH = C.NL80211_SCAN_FLAG_FLUSH
NL80211_SCAN_FLAG_FREQ_KHZ = C.NL80211_SCAN_FLAG_FREQ_KHZ
NL80211_SCAN_FLAG_HIGH_ACCURACY = C.NL80211_SCAN_FLAG_HIGH_ACCURACY
NL80211_SCAN_FLAG_LOW_POWER = C.NL80211_SCAN_FLAG_LOW_POWER
NL80211_SCAN_FLAG_LOW_PRIORITY = C.NL80211_SCAN_FLAG_LOW_PRIORITY
NL80211_SCAN_FLAG_LOW_SPAN = C.NL80211_SCAN_FLAG_LOW_SPAN
NL80211_SCAN_FLAG_MIN_PREQ_CONTENT = C.NL80211_SCAN_FLAG_MIN_PREQ_CONTENT
NL80211_SCAN_FLAG_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION = C.NL80211_SCAN_FLAG_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION
NL80211_SCAN_FLAG_OCE_PROBE_REQ_HIGH_TX_RATE = C.NL80211_SCAN_FLAG_OCE_PROBE_REQ_HIGH_TX_RATE
NL80211_SCAN_FLAG_RANDOM_ADDR = C.NL80211_SCAN_FLAG_RANDOM_ADDR
NL80211_SCAN_FLAG_RANDOM_SN = C.NL80211_SCAN_FLAG_RANDOM_SN
NL80211_SCAN_RSSI_THOLD_OFF = C.NL80211_SCAN_RSSI_THOLD_OFF
NL80211_SCHED_SCAN_MATCH_ATTR_BSSID = C.NL80211_SCHED_SCAN_MATCH_ATTR_BSSID
NL80211_SCHED_SCAN_MATCH_ATTR_MAX = C.NL80211_SCHED_SCAN_MATCH_ATTR_MAX
NL80211_SCHED_SCAN_MATCH_ATTR_RELATIVE_RSSI = C.NL80211_SCHED_SCAN_MATCH_ATTR_RELATIVE_RSSI
NL80211_SCHED_SCAN_MATCH_ATTR_RSSI_ADJUST = C.NL80211_SCHED_SCAN_MATCH_ATTR_RSSI_ADJUST
NL80211_SCHED_SCAN_MATCH_ATTR_RSSI = C.NL80211_SCHED_SCAN_MATCH_ATTR_RSSI
NL80211_SCHED_SCAN_MATCH_ATTR_SSID = C.NL80211_SCHED_SCAN_MATCH_ATTR_SSID
NL80211_SCHED_SCAN_MATCH_PER_BAND_RSSI = C.NL80211_SCHED_SCAN_MATCH_PER_BAND_RSSI
NL80211_SCHED_SCAN_PLAN_INTERVAL = C.NL80211_SCHED_SCAN_PLAN_INTERVAL
NL80211_SCHED_SCAN_PLAN_ITERATIONS = C.NL80211_SCHED_SCAN_PLAN_ITERATIONS
NL80211_SCHED_SCAN_PLAN_MAX = C.NL80211_SCHED_SCAN_PLAN_MAX
NL80211_SMPS_DYNAMIC = C.NL80211_SMPS_DYNAMIC
NL80211_SMPS_MAX = C.NL80211_SMPS_MAX
NL80211_SMPS_OFF = C.NL80211_SMPS_OFF
NL80211_SMPS_STATIC = C.NL80211_SMPS_STATIC
NL80211_STA_BSS_PARAM_BEACON_INTERVAL = C.NL80211_STA_BSS_PARAM_BEACON_INTERVAL
NL80211_STA_BSS_PARAM_CTS_PROT = C.NL80211_STA_BSS_PARAM_CTS_PROT
NL80211_STA_BSS_PARAM_DTIM_PERIOD = C.NL80211_STA_BSS_PARAM_DTIM_PERIOD
NL80211_STA_BSS_PARAM_MAX = C.NL80211_STA_BSS_PARAM_MAX
NL80211_STA_BSS_PARAM_SHORT_PREAMBLE = C.NL80211_STA_BSS_PARAM_SHORT_PREAMBLE
NL80211_STA_BSS_PARAM_SHORT_SLOT_TIME = C.NL80211_STA_BSS_PARAM_SHORT_SLOT_TIME
NL80211_STA_FLAG_ASSOCIATED = C.NL80211_STA_FLAG_ASSOCIATED
NL80211_STA_FLAG_AUTHENTICATED = C.NL80211_STA_FLAG_AUTHENTICATED
NL80211_STA_FLAG_AUTHORIZED = C.NL80211_STA_FLAG_AUTHORIZED
NL80211_STA_FLAG_MAX = C.NL80211_STA_FLAG_MAX
NL80211_STA_FLAG_MAX_OLD_API = C.NL80211_STA_FLAG_MAX_OLD_API
NL80211_STA_FLAG_MFP = C.NL80211_STA_FLAG_MFP
NL80211_STA_FLAG_SHORT_PREAMBLE = C.NL80211_STA_FLAG_SHORT_PREAMBLE
NL80211_STA_FLAG_TDLS_PEER = C.NL80211_STA_FLAG_TDLS_PEER
NL80211_STA_FLAG_WME = C.NL80211_STA_FLAG_WME
NL80211_STA_INFO_ACK_SIGNAL_AVG = C.NL80211_STA_INFO_ACK_SIGNAL_AVG
NL80211_STA_INFO_ACK_SIGNAL = C.NL80211_STA_INFO_ACK_SIGNAL
NL80211_STA_INFO_AIRTIME_LINK_METRIC = C.NL80211_STA_INFO_AIRTIME_LINK_METRIC
NL80211_STA_INFO_AIRTIME_WEIGHT = C.NL80211_STA_INFO_AIRTIME_WEIGHT
NL80211_STA_INFO_ASSOC_AT_BOOTTIME = C.NL80211_STA_INFO_ASSOC_AT_BOOTTIME
NL80211_STA_INFO_BEACON_LOSS = C.NL80211_STA_INFO_BEACON_LOSS
NL80211_STA_INFO_BEACON_RX = C.NL80211_STA_INFO_BEACON_RX
NL80211_STA_INFO_BEACON_SIGNAL_AVG = C.NL80211_STA_INFO_BEACON_SIGNAL_AVG
NL80211_STA_INFO_BSS_PARAM = C.NL80211_STA_INFO_BSS_PARAM
NL80211_STA_INFO_CHAIN_SIGNAL_AVG = C.NL80211_STA_INFO_CHAIN_SIGNAL_AVG
NL80211_STA_INFO_CHAIN_SIGNAL = C.NL80211_STA_INFO_CHAIN_SIGNAL
NL80211_STA_INFO_CONNECTED_TIME = C.NL80211_STA_INFO_CONNECTED_TIME
NL80211_STA_INFO_CONNECTED_TO_AS = C.NL80211_STA_INFO_CONNECTED_TO_AS
NL80211_STA_INFO_CONNECTED_TO_GATE = C.NL80211_STA_INFO_CONNECTED_TO_GATE
NL80211_STA_INFO_DATA_ACK_SIGNAL_AVG = C.NL80211_STA_INFO_DATA_ACK_SIGNAL_AVG
NL80211_STA_INFO_EXPECTED_THROUGHPUT = C.NL80211_STA_INFO_EXPECTED_THROUGHPUT
NL80211_STA_INFO_FCS_ERROR_COUNT = C.NL80211_STA_INFO_FCS_ERROR_COUNT
NL80211_STA_INFO_INACTIVE_TIME = C.NL80211_STA_INFO_INACTIVE_TIME
NL80211_STA_INFO_LLID = C.NL80211_STA_INFO_LLID
NL80211_STA_INFO_LOCAL_PM = C.NL80211_STA_INFO_LOCAL_PM
NL80211_STA_INFO_MAX = C.NL80211_STA_INFO_MAX
NL80211_STA_INFO_NONPEER_PM = C.NL80211_STA_INFO_NONPEER_PM
NL80211_STA_INFO_PAD = C.NL80211_STA_INFO_PAD
NL80211_STA_INFO_PEER_PM = C.NL80211_STA_INFO_PEER_PM
NL80211_STA_INFO_PLID = C.NL80211_STA_INFO_PLID
NL80211_STA_INFO_PLINK_STATE = C.NL80211_STA_INFO_PLINK_STATE
NL80211_STA_INFO_RX_BITRATE = C.NL80211_STA_INFO_RX_BITRATE
NL80211_STA_INFO_RX_BYTES64 = C.NL80211_STA_INFO_RX_BYTES64
NL80211_STA_INFO_RX_BYTES = C.NL80211_STA_INFO_RX_BYTES
NL80211_STA_INFO_RX_DROP_MISC = C.NL80211_STA_INFO_RX_DROP_MISC
NL80211_STA_INFO_RX_DURATION = C.NL80211_STA_INFO_RX_DURATION
NL80211_STA_INFO_RX_MPDUS = C.NL80211_STA_INFO_RX_MPDUS
NL80211_STA_INFO_RX_PACKETS = C.NL80211_STA_INFO_RX_PACKETS
NL80211_STA_INFO_SIGNAL_AVG = C.NL80211_STA_INFO_SIGNAL_AVG
NL80211_STA_INFO_SIGNAL = C.NL80211_STA_INFO_SIGNAL
NL80211_STA_INFO_STA_FLAGS = C.NL80211_STA_INFO_STA_FLAGS
NL80211_STA_INFO_TID_STATS = C.NL80211_STA_INFO_TID_STATS
NL80211_STA_INFO_T_OFFSET = C.NL80211_STA_INFO_T_OFFSET
NL80211_STA_INFO_TX_BITRATE = C.NL80211_STA_INFO_TX_BITRATE
NL80211_STA_INFO_TX_BYTES64 = C.NL80211_STA_INFO_TX_BYTES64
NL80211_STA_INFO_TX_BYTES = C.NL80211_STA_INFO_TX_BYTES
NL80211_STA_INFO_TX_DURATION = C.NL80211_STA_INFO_TX_DURATION
NL80211_STA_INFO_TX_FAILED = C.NL80211_STA_INFO_TX_FAILED
NL80211_STA_INFO_TX_PACKETS = C.NL80211_STA_INFO_TX_PACKETS
NL80211_STA_INFO_TX_RETRIES = C.NL80211_STA_INFO_TX_RETRIES
NL80211_STA_WME_MAX = C.NL80211_STA_WME_MAX
NL80211_STA_WME_MAX_SP = C.NL80211_STA_WME_MAX_SP
NL80211_STA_WME_UAPSD_QUEUES = C.NL80211_STA_WME_UAPSD_QUEUES
NL80211_SURVEY_INFO_CHANNEL_TIME_BUSY = C.NL80211_SURVEY_INFO_CHANNEL_TIME_BUSY
NL80211_SURVEY_INFO_CHANNEL_TIME = C.NL80211_SURVEY_INFO_CHANNEL_TIME
NL80211_SURVEY_INFO_CHANNEL_TIME_EXT_BUSY = C.NL80211_SURVEY_INFO_CHANNEL_TIME_EXT_BUSY
NL80211_SURVEY_INFO_CHANNEL_TIME_RX = C.NL80211_SURVEY_INFO_CHANNEL_TIME_RX
NL80211_SURVEY_INFO_CHANNEL_TIME_TX = C.NL80211_SURVEY_INFO_CHANNEL_TIME_TX
NL80211_SURVEY_INFO_FREQUENCY = C.NL80211_SURVEY_INFO_FREQUENCY
NL80211_SURVEY_INFO_FREQUENCY_OFFSET = C.NL80211_SURVEY_INFO_FREQUENCY_OFFSET
NL80211_SURVEY_INFO_IN_USE = C.NL80211_SURVEY_INFO_IN_USE
NL80211_SURVEY_INFO_MAX = C.NL80211_SURVEY_INFO_MAX
NL80211_SURVEY_INFO_NOISE = C.NL80211_SURVEY_INFO_NOISE
NL80211_SURVEY_INFO_PAD = C.NL80211_SURVEY_INFO_PAD
NL80211_SURVEY_INFO_TIME_BSS_RX = C.NL80211_SURVEY_INFO_TIME_BSS_RX
NL80211_SURVEY_INFO_TIME_BUSY = C.NL80211_SURVEY_INFO_TIME_BUSY
NL80211_SURVEY_INFO_TIME = C.NL80211_SURVEY_INFO_TIME
NL80211_SURVEY_INFO_TIME_EXT_BUSY = C.NL80211_SURVEY_INFO_TIME_EXT_BUSY
NL80211_SURVEY_INFO_TIME_RX = C.NL80211_SURVEY_INFO_TIME_RX
NL80211_SURVEY_INFO_TIME_SCAN = C.NL80211_SURVEY_INFO_TIME_SCAN
NL80211_SURVEY_INFO_TIME_TX = C.NL80211_SURVEY_INFO_TIME_TX
NL80211_TDLS_DISABLE_LINK = C.NL80211_TDLS_DISABLE_LINK
NL80211_TDLS_DISCOVERY_REQ = C.NL80211_TDLS_DISCOVERY_REQ
NL80211_TDLS_ENABLE_LINK = C.NL80211_TDLS_ENABLE_LINK
NL80211_TDLS_PEER_HE = C.NL80211_TDLS_PEER_HE
NL80211_TDLS_PEER_HT = C.NL80211_TDLS_PEER_HT
NL80211_TDLS_PEER_VHT = C.NL80211_TDLS_PEER_VHT
NL80211_TDLS_PEER_WMM = C.NL80211_TDLS_PEER_WMM
NL80211_TDLS_SETUP = C.NL80211_TDLS_SETUP
NL80211_TDLS_TEARDOWN = C.NL80211_TDLS_TEARDOWN
NL80211_TID_CONFIG_ATTR_AMPDU_CTRL = C.NL80211_TID_CONFIG_ATTR_AMPDU_CTRL
NL80211_TID_CONFIG_ATTR_AMSDU_CTRL = C.NL80211_TID_CONFIG_ATTR_AMSDU_CTRL
NL80211_TID_CONFIG_ATTR_MAX = C.NL80211_TID_CONFIG_ATTR_MAX
NL80211_TID_CONFIG_ATTR_NOACK = C.NL80211_TID_CONFIG_ATTR_NOACK
NL80211_TID_CONFIG_ATTR_OVERRIDE = C.NL80211_TID_CONFIG_ATTR_OVERRIDE
NL80211_TID_CONFIG_ATTR_PAD = C.NL80211_TID_CONFIG_ATTR_PAD
NL80211_TID_CONFIG_ATTR_PEER_SUPP = C.NL80211_TID_CONFIG_ATTR_PEER_SUPP
NL80211_TID_CONFIG_ATTR_RETRY_LONG = C.NL80211_TID_CONFIG_ATTR_RETRY_LONG
NL80211_TID_CONFIG_ATTR_RETRY_SHORT = C.NL80211_TID_CONFIG_ATTR_RETRY_SHORT
NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL = C.NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL
NL80211_TID_CONFIG_ATTR_TIDS = C.NL80211_TID_CONFIG_ATTR_TIDS
NL80211_TID_CONFIG_ATTR_TX_RATE = C.NL80211_TID_CONFIG_ATTR_TX_RATE
NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE = C.NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE
NL80211_TID_CONFIG_ATTR_VIF_SUPP = C.NL80211_TID_CONFIG_ATTR_VIF_SUPP
NL80211_TID_CONFIG_DISABLE = C.NL80211_TID_CONFIG_DISABLE
NL80211_TID_CONFIG_ENABLE = C.NL80211_TID_CONFIG_ENABLE
NL80211_TID_STATS_MAX = C.NL80211_TID_STATS_MAX
NL80211_TID_STATS_PAD = C.NL80211_TID_STATS_PAD
NL80211_TID_STATS_RX_MSDU = C.NL80211_TID_STATS_RX_MSDU
NL80211_TID_STATS_TX_MSDU = C.NL80211_TID_STATS_TX_MSDU
NL80211_TID_STATS_TX_MSDU_FAILED = C.NL80211_TID_STATS_TX_MSDU_FAILED
NL80211_TID_STATS_TX_MSDU_RETRIES = C.NL80211_TID_STATS_TX_MSDU_RETRIES
NL80211_TID_STATS_TXQ_STATS = C.NL80211_TID_STATS_TXQ_STATS
NL80211_TIMEOUT_ASSOC = C.NL80211_TIMEOUT_ASSOC
NL80211_TIMEOUT_AUTH = C.NL80211_TIMEOUT_AUTH
NL80211_TIMEOUT_SCAN = C.NL80211_TIMEOUT_SCAN
NL80211_TIMEOUT_UNSPECIFIED = C.NL80211_TIMEOUT_UNSPECIFIED
NL80211_TKIP_DATA_OFFSET_ENCR_KEY = C.NL80211_TKIP_DATA_OFFSET_ENCR_KEY
NL80211_TKIP_DATA_OFFSET_RX_MIC_KEY = C.NL80211_TKIP_DATA_OFFSET_RX_MIC_KEY
NL80211_TKIP_DATA_OFFSET_TX_MIC_KEY = C.NL80211_TKIP_DATA_OFFSET_TX_MIC_KEY
NL80211_TX_POWER_AUTOMATIC = C.NL80211_TX_POWER_AUTOMATIC
NL80211_TX_POWER_FIXED = C.NL80211_TX_POWER_FIXED
NL80211_TX_POWER_LIMITED = C.NL80211_TX_POWER_LIMITED
NL80211_TXQ_ATTR_AC = C.NL80211_TXQ_ATTR_AC
NL80211_TXQ_ATTR_AIFS = C.NL80211_TXQ_ATTR_AIFS
NL80211_TXQ_ATTR_CWMAX = C.NL80211_TXQ_ATTR_CWMAX
NL80211_TXQ_ATTR_CWMIN = C.NL80211_TXQ_ATTR_CWMIN
NL80211_TXQ_ATTR_MAX = C.NL80211_TXQ_ATTR_MAX
NL80211_TXQ_ATTR_QUEUE = C.NL80211_TXQ_ATTR_QUEUE
NL80211_TXQ_ATTR_TXOP = C.NL80211_TXQ_ATTR_TXOP
NL80211_TXQ_Q_BE = C.NL80211_TXQ_Q_BE
NL80211_TXQ_Q_BK = C.NL80211_TXQ_Q_BK
NL80211_TXQ_Q_VI = C.NL80211_TXQ_Q_VI
NL80211_TXQ_Q_VO = C.NL80211_TXQ_Q_VO
NL80211_TXQ_STATS_BACKLOG_BYTES = C.NL80211_TXQ_STATS_BACKLOG_BYTES
NL80211_TXQ_STATS_BACKLOG_PACKETS = C.NL80211_TXQ_STATS_BACKLOG_PACKETS
NL80211_TXQ_STATS_COLLISIONS = C.NL80211_TXQ_STATS_COLLISIONS
NL80211_TXQ_STATS_DROPS = C.NL80211_TXQ_STATS_DROPS
NL80211_TXQ_STATS_ECN_MARKS = C.NL80211_TXQ_STATS_ECN_MARKS
NL80211_TXQ_STATS_FLOWS = C.NL80211_TXQ_STATS_FLOWS
NL80211_TXQ_STATS_MAX = C.NL80211_TXQ_STATS_MAX
NL80211_TXQ_STATS_MAX_FLOWS = C.NL80211_TXQ_STATS_MAX_FLOWS
NL80211_TXQ_STATS_OVERLIMIT = C.NL80211_TXQ_STATS_OVERLIMIT
NL80211_TXQ_STATS_OVERMEMORY = C.NL80211_TXQ_STATS_OVERMEMORY
NL80211_TXQ_STATS_TX_BYTES = C.NL80211_TXQ_STATS_TX_BYTES
NL80211_TXQ_STATS_TX_PACKETS = C.NL80211_TXQ_STATS_TX_PACKETS
NL80211_TX_RATE_AUTOMATIC = C.NL80211_TX_RATE_AUTOMATIC
NL80211_TXRATE_DEFAULT_GI = C.NL80211_TXRATE_DEFAULT_GI
NL80211_TX_RATE_FIXED = C.NL80211_TX_RATE_FIXED
NL80211_TXRATE_FORCE_LGI = C.NL80211_TXRATE_FORCE_LGI
NL80211_TXRATE_FORCE_SGI = C.NL80211_TXRATE_FORCE_SGI
NL80211_TXRATE_GI = C.NL80211_TXRATE_GI
NL80211_TXRATE_HE = C.NL80211_TXRATE_HE
NL80211_TXRATE_HE_GI = C.NL80211_TXRATE_HE_GI
NL80211_TXRATE_HE_LTF = C.NL80211_TXRATE_HE_LTF
NL80211_TXRATE_HT = C.NL80211_TXRATE_HT
NL80211_TXRATE_LEGACY = C.NL80211_TXRATE_LEGACY
NL80211_TX_RATE_LIMITED = C.NL80211_TX_RATE_LIMITED
NL80211_TXRATE_MAX = C.NL80211_TXRATE_MAX
NL80211_TXRATE_MCS = C.NL80211_TXRATE_MCS
NL80211_TXRATE_VHT = C.NL80211_TXRATE_VHT
NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_INT = C.NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_INT
NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_MAX = C.NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_MAX
NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL = C.NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL
NL80211_USER_REG_HINT_CELL_BASE = C.NL80211_USER_REG_HINT_CELL_BASE
NL80211_USER_REG_HINT_INDOOR = C.NL80211_USER_REG_HINT_INDOOR
NL80211_USER_REG_HINT_USER = C.NL80211_USER_REG_HINT_USER
NL80211_VENDOR_ID_IS_LINUX = C.NL80211_VENDOR_ID_IS_LINUX
NL80211_VHT_CAPABILITY_LEN = C.NL80211_VHT_CAPABILITY_LEN
NL80211_VHT_NSS_MAX = C.NL80211_VHT_NSS_MAX
NL80211_WIPHY_NAME_MAXLEN = C.NL80211_WIPHY_NAME_MAXLEN
NL80211_WMMR_AIFSN = C.NL80211_WMMR_AIFSN
NL80211_WMMR_CW_MAX = C.NL80211_WMMR_CW_MAX
NL80211_WMMR_CW_MIN = C.NL80211_WMMR_CW_MIN
NL80211_WMMR_MAX = C.NL80211_WMMR_MAX
NL80211_WMMR_TXOP = C.NL80211_WMMR_TXOP
NL80211_WOWLAN_PKTPAT_MASK = C.NL80211_WOWLAN_PKTPAT_MASK
NL80211_WOWLAN_PKTPAT_OFFSET = C.NL80211_WOWLAN_PKTPAT_OFFSET
NL80211_WOWLAN_PKTPAT_PATTERN = C.NL80211_WOWLAN_PKTPAT_PATTERN
NL80211_WOWLAN_TCP_DATA_INTERVAL = C.NL80211_WOWLAN_TCP_DATA_INTERVAL
NL80211_WOWLAN_TCP_DATA_PAYLOAD = C.NL80211_WOWLAN_TCP_DATA_PAYLOAD
NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ = C.NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ
NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN = C.NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN
NL80211_WOWLAN_TCP_DST_IPV4 = C.NL80211_WOWLAN_TCP_DST_IPV4
NL80211_WOWLAN_TCP_DST_MAC = C.NL80211_WOWLAN_TCP_DST_MAC
NL80211_WOWLAN_TCP_DST_PORT = C.NL80211_WOWLAN_TCP_DST_PORT
NL80211_WOWLAN_TCP_SRC_IPV4 = C.NL80211_WOWLAN_TCP_SRC_IPV4
NL80211_WOWLAN_TCP_SRC_PORT = C.NL80211_WOWLAN_TCP_SRC_PORT
NL80211_WOWLAN_TCP_WAKE_MASK = C.NL80211_WOWLAN_TCP_WAKE_MASK
NL80211_WOWLAN_TCP_WAKE_PAYLOAD = C.NL80211_WOWLAN_TCP_WAKE_PAYLOAD
NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE = C.NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE
NL80211_WOWLAN_TRIG_ANY = C.NL80211_WOWLAN_TRIG_ANY
NL80211_WOWLAN_TRIG_DISCONNECT = C.NL80211_WOWLAN_TRIG_DISCONNECT
NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST = C.NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST
NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE = C.NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE
NL80211_WOWLAN_TRIG_GTK_REKEY_SUPPORTED = C.NL80211_WOWLAN_TRIG_GTK_REKEY_SUPPORTED
NL80211_WOWLAN_TRIG_MAGIC_PKT = C.NL80211_WOWLAN_TRIG_MAGIC_PKT
NL80211_WOWLAN_TRIG_NET_DETECT = C.NL80211_WOWLAN_TRIG_NET_DETECT
NL80211_WOWLAN_TRIG_NET_DETECT_RESULTS = C.NL80211_WOWLAN_TRIG_NET_DETECT_RESULTS
NL80211_WOWLAN_TRIG_PKT_PATTERN = C.NL80211_WOWLAN_TRIG_PKT_PATTERN
NL80211_WOWLAN_TRIG_RFKILL_RELEASE = C.NL80211_WOWLAN_TRIG_RFKILL_RELEASE
NL80211_WOWLAN_TRIG_TCP_CONNECTION = C.NL80211_WOWLAN_TRIG_TCP_CONNECTION
NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211 = C.NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211
NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211_LEN = C.NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211_LEN
NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023 = C.NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023
NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023_LEN = C.NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023_LEN
NL80211_WOWLAN_TRIG_WAKEUP_TCP_CONNLOST = C.NL80211_WOWLAN_TRIG_WAKEUP_TCP_CONNLOST
NL80211_WOWLAN_TRIG_WAKEUP_TCP_MATCH = C.NL80211_WOWLAN_TRIG_WAKEUP_TCP_MATCH
NL80211_WOWLAN_TRIG_WAKEUP_TCP_NOMORETOKENS = C.NL80211_WOWLAN_TRIG_WAKEUP_TCP_NOMORETOKENS
NL80211_WPA_VERSION_1 = C.NL80211_WPA_VERSION_1
NL80211_WPA_VERSION_2 = C.NL80211_WPA_VERSION_2
NL80211_WPA_VERSION_3 = C.NL80211_WPA_VERSION_3
)
// FRA_* and FR_ACT_* constants from the header named in the generator
// command below. Each Go constant is bound through cgo to the C constant of
// the same name.
// generated by:
// perl -nlE '/^\s*((FR_ACT|FRA)_\w+)/ && say "$1 = C.$1"' include/uapi/linux/fib_rules.h
const (
FRA_UNSPEC = C.FRA_UNSPEC
FRA_DST = C.FRA_DST
FRA_SRC = C.FRA_SRC
FRA_IIFNAME = C.FRA_IIFNAME
FRA_GOTO = C.FRA_GOTO
FRA_UNUSED2 = C.FRA_UNUSED2
FRA_PRIORITY = C.FRA_PRIORITY
FRA_UNUSED3 = C.FRA_UNUSED3
FRA_UNUSED4 = C.FRA_UNUSED4
FRA_UNUSED5 = C.FRA_UNUSED5
FRA_FWMARK = C.FRA_FWMARK
FRA_FLOW = C.FRA_FLOW
FRA_TUN_ID = C.FRA_TUN_ID
FRA_SUPPRESS_IFGROUP = C.FRA_SUPPRESS_IFGROUP
FRA_SUPPRESS_PREFIXLEN = C.FRA_SUPPRESS_PREFIXLEN
FRA_TABLE = C.FRA_TABLE
FRA_FWMASK = C.FRA_FWMASK
FRA_OIFNAME = C.FRA_OIFNAME
FRA_PAD = C.FRA_PAD
FRA_L3MDEV = C.FRA_L3MDEV
FRA_UID_RANGE = C.FRA_UID_RANGE
FRA_PROTOCOL = C.FRA_PROTOCOL
FRA_IP_PROTO = C.FRA_IP_PROTO
FRA_SPORT_RANGE = C.FRA_SPORT_RANGE
FRA_DPORT_RANGE = C.FRA_DPORT_RANGE
FR_ACT_UNSPEC = C.FR_ACT_UNSPEC
FR_ACT_TO_TBL = C.FR_ACT_TO_TBL
FR_ACT_GOTO = C.FR_ACT_GOTO
FR_ACT_NOP = C.FR_ACT_NOP
FR_ACT_RES3 = C.FR_ACT_RES3
FR_ACT_RES4 = C.FR_ACT_RES4
FR_ACT_BLACKHOLE = C.FR_ACT_BLACKHOLE
FR_ACT_UNREACHABLE = C.FR_ACT_UNREACHABLE
FR_ACT_PROHIBIT = C.FR_ACT_PROHIBIT
)
// AUDIT_NLGRP_* constants, each bound through cgo to the C constant of the
// same name from the header named in the generator command below.
// generated by:
// perl -nlE '/^\s*(AUDIT_NLGRP_\w+)/ && say "$1 = C.$1"' audit.h
const (
AUDIT_NLGRP_NONE = C.AUDIT_NLGRP_NONE
AUDIT_NLGRP_READLOG = C.AUDIT_NLGRP_READLOG
)
// TUN_F_* constants, each bound through cgo to the C #define of the same
// name from the header named in the generator command below.
// generated by:
// perl -nlE '/^#define (TUN_F_\w+)/ && say "$1 = C.$1"' include/uapi/linux/if_tun.h
const (
TUN_F_CSUM = C.TUN_F_CSUM
TUN_F_TSO4 = C.TUN_F_TSO4
TUN_F_TSO6 = C.TUN_F_TSO6
TUN_F_TSO_ECN = C.TUN_F_TSO_ECN
TUN_F_UFO = C.TUN_F_UFO
TUN_F_USO4 = C.TUN_F_USO4
TUN_F_USO6 = C.TUN_F_USO6
)
// VIRTIO_NET_HDR_F_* constants, each bound through cgo to the C #define of
// the same name from the header named in the generator command below.
// generated by:
// perl -nlE '/^#define (VIRTIO_NET_HDR_F_\w+)/ && say "$1 = C.$1"' include/uapi/linux/virtio_net.h
const (
VIRTIO_NET_HDR_F_NEEDS_CSUM = C.VIRTIO_NET_HDR_F_NEEDS_CSUM
VIRTIO_NET_HDR_F_DATA_VALID = C.VIRTIO_NET_HDR_F_DATA_VALID
VIRTIO_NET_HDR_F_RSC_INFO = C.VIRTIO_NET_HDR_F_RSC_INFO
)
// VIRTIO_NET_HDR_GSO_* constants, each bound through cgo to the C #define of
// the same name from the header named in the generator command below.
// generated by:
// perl -nlE '/^#define (VIRTIO_NET_HDR_GSO_\w+)/ && say "$1 = C.$1"' include/uapi/linux/virtio_net.h
const (
VIRTIO_NET_HDR_GSO_NONE = C.VIRTIO_NET_HDR_GSO_NONE
VIRTIO_NET_HDR_GSO_TCPV4 = C.VIRTIO_NET_HDR_GSO_TCPV4
VIRTIO_NET_HDR_GSO_UDP = C.VIRTIO_NET_HDR_GSO_UDP
VIRTIO_NET_HDR_GSO_TCPV6 = C.VIRTIO_NET_HDR_GSO_TCPV6
VIRTIO_NET_HDR_GSO_UDP_L4 = C.VIRTIO_NET_HDR_GSO_UDP_L4
VIRTIO_NET_HDR_GSO_ECN = C.VIRTIO_NET_HDR_GSO_ECN
)
// RISCVHWProbePairs is the Go name given to the C struct riscv_hwprobe.
type RISCVHWProbePairs C.struct_riscv_hwprobe
// RISCV_HWPROBE_* constants, each bound through cgo to the C #define of the
// same name from the header named in the generator command below.
// Filtered out for non RISC-V architectures in mkpost.go
// generated by:
// perl -nlE '/^#define\s+(RISCV_HWPROBE_\w+)/ && say "$1 = C.$1"' /tmp/riscv64/include/asm/hwprobe.h
const (
RISCV_HWPROBE_KEY_MVENDORID = C.RISCV_HWPROBE_KEY_MVENDORID
RISCV_HWPROBE_KEY_MARCHID = C.RISCV_HWPROBE_KEY_MARCHID
RISCV_HWPROBE_KEY_MIMPID = C.RISCV_HWPROBE_KEY_MIMPID
RISCV_HWPROBE_KEY_BASE_BEHAVIOR = C.RISCV_HWPROBE_KEY_BASE_BEHAVIOR
RISCV_HWPROBE_BASE_BEHAVIOR_IMA = C.RISCV_HWPROBE_BASE_BEHAVIOR_IMA
RISCV_HWPROBE_KEY_IMA_EXT_0 = C.RISCV_HWPROBE_KEY_IMA_EXT_0
RISCV_HWPROBE_IMA_FD = C.RISCV_HWPROBE_IMA_FD
RISCV_HWPROBE_IMA_C = C.RISCV_HWPROBE_IMA_C
RISCV_HWPROBE_IMA_V = C.RISCV_HWPROBE_IMA_V
RISCV_HWPROBE_EXT_ZBA = C.RISCV_HWPROBE_EXT_ZBA
RISCV_HWPROBE_EXT_ZBB = C.RISCV_HWPROBE_EXT_ZBB
RISCV_HWPROBE_EXT_ZBS = C.RISCV_HWPROBE_EXT_ZBS
RISCV_HWPROBE_KEY_CPUPERF_0 = C.RISCV_HWPROBE_KEY_CPUPERF_0
RISCV_HWPROBE_MISALIGNED_UNKNOWN = C.RISCV_HWPROBE_MISALIGNED_UNKNOWN
RISCV_HWPROBE_MISALIGNED_EMULATED = C.RISCV_HWPROBE_MISALIGNED_EMULATED
RISCV_HWPROBE_MISALIGNED_SLOW = C.RISCV_HWPROBE_MISALIGNED_SLOW
RISCV_HWPROBE_MISALIGNED_FAST = C.RISCV_HWPROBE_MISALIGNED_FAST
RISCV_HWPROBE_MISALIGNED_UNSUPPORTED = C.RISCV_HWPROBE_MISALIGNED_UNSUPPORTED
RISCV_HWPROBE_MISALIGNED_MASK = C.RISCV_HWPROBE_MISALIGNED_MASK
)
// SchedAttr is the Go name given to the C struct sched_attr.
type SchedAttr C.struct_sched_attr
// SizeofSchedAttr is the size of the C struct sched_attr as reported by cgo.
const SizeofSchedAttr = C.sizeof_struct_sched_attr
// Cachestat_t is the Go name given to the C struct cachestat.
type Cachestat_t C.struct_cachestat
// CachestatRange is the Go name given to the C struct cachestat_range.
type CachestatRange C.struct_cachestat_range
================================================
FILE: gossh/sys/unix/mkall.sh
================================================
#!/usr/bin/env bash
# Copyright 2009 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
# This script runs or (given -n) prints suggested commands to generate files for
# the Architecture/OS specified by the GOARCH and GOOS environment variables.
# See README.md for more information about how the build system works.
# Target identifier of the form GOOS_GOARCH; it selects the per-platform
# settings further down and appears in the generated file names.
GOOSARCH="${GOOS}_${GOARCH}"
# defaults
# Generator command lines and output names. The per-platform case statement
# overrides these; a generator left empty is skipped when the command list is
# emitted at the end of the script.
mksyscall="go run mksyscall.go"
mkerrors="./mkerrors.sh"
zerrors="zerrors_$GOOSARCH.go"
mksysctl=""
zsysctl="zsysctl_$GOOSARCH.go"
mksysnum=
mktypes=
mkasm=
# $run consumes the generated command list (sh executes it; -n swaps in cat).
run="sh"
# $cmd prefixes the docker commands (empty runs them; -n swaps in echo).
cmd=""
# Handle the optional first argument:
#   -syscalls  regenerate every zsyscall*go file in place, then exit 0
#   -n         dry run: print the commands instead of executing them
case "$1" in
-syscalls)
for i in zsyscall*go
do
# Run the command line that appears in the first line
# of the generated file to regenerate it.
# The result goes to a temporary _$i, is gofmt'ed back into $i, and the
# temporary is then removed.
sed 1q $i | sed 's;^// ;;' | sh > _$i && gofmt < _$i > $i
rm _$i
done
exit 0
;;
-n)
run="cat"
cmd="echo"
shift
esac
# Any argument still present at this point is unexpected: print the usage
# line to stderr and exit with status 2.
case "$#" in
0)
;;
*)
echo 'usage: mkall.sh [-n]' 1>&2
exit 2
esac
# Linux targets are generated inside a Docker container rather than by the
# per-platform commands below, so the script exits once the container is done.
if [[ "$GOOS" = "linux" ]]; then
	# Use the Docker-based build system
	# Files generated through docker (use $cmd so you can Ctl-C the build or run)
	# $cmd is deliberately unquoted: when empty it must expand to nothing.
	$cmd docker build --tag "generate:$GOOS" "$GOOS"
	# The volume argument is quoted so that a checkout path containing
	# whitespace or glob characters is passed to docker as a single word.
	$cmd docker run --interactive --tty --volume "$(cd -- "$(dirname -- "$0")/.." && pwd):/build" "generate:$GOOS"
	exit
fi
# Architecture-specific syscall source file handed to mksyscall when the
# command list is emitted.
GOOSARCH_in=syscall_$GOOSARCH.go
# Per-platform generator settings. Each arm overrides the defaults for one
# GOOS_GOARCH pair; an unset or unrecognized pair aborts with exit status 1.
case "$GOOSARCH" in
# GOOS and/or GOARCH was empty, leaving a bare or dangling underscore.
_* | *_ | _)
echo 'undefined $GOOS_$GOARCH:' "$GOOSARCH" 1>&2
exit 1
;;
aix_ppc)
mkerrors="$mkerrors -maix32"
mksyscall="go run mksyscall_aix_ppc.go -aix"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
aix_ppc64)
mkerrors="$mkerrors -maix64"
mksyscall="go run mksyscall_aix_ppc64.go -aix"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
darwin_amd64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm.go"
;;
darwin_arm64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm.go"
;;
dragonfly_amd64)
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -dragonfly"
mksysnum="go run mksysnum.go 'https://gitweb.dragonflybsd.org/dragonfly.git/blob_plain/HEAD:/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
freebsd_386)
mkerrors="$mkerrors -m32"
mksyscall="go run mksyscall.go -l32"
mksysnum="go run mksysnum.go 'https://cgit.freebsd.org/src/plain/sys/kern/syscalls.master?h=stable/12'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
freebsd_amd64)
mkerrors="$mkerrors -m64"
mksysnum="go run mksysnum.go 'https://cgit.freebsd.org/src/plain/sys/kern/syscalls.master?h=stable/12'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
freebsd_arm)
mkerrors="$mkerrors"
mksyscall="go run mksyscall.go -l32 -arm"
mksysnum="go run mksysnum.go 'https://cgit.freebsd.org/src/plain/sys/kern/syscalls.master?h=stable/12'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
freebsd_arm64)
mkerrors="$mkerrors -m64"
mksysnum="go run mksysnum.go 'https://cgit.freebsd.org/src/plain/sys/kern/syscalls.master?h=stable/12'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
freebsd_riscv64)
mkerrors="$mkerrors -m64"
mksysnum="go run mksysnum.go 'https://cgit.freebsd.org/src/plain/sys/kern/syscalls.master?h=stable/12'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
netbsd_386)
mkerrors="$mkerrors -m32"
mksyscall="go run mksyscall.go -l32 -netbsd"
mksysnum="go run mksysnum.go 'http://cvsweb.netbsd.org/bsdweb.cgi/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
netbsd_amd64)
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -netbsd"
mksysnum="go run mksysnum.go 'http://cvsweb.netbsd.org/bsdweb.cgi/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
netbsd_arm)
mkerrors="$mkerrors"
mksyscall="go run mksyscall.go -l32 -netbsd -arm"
mksysnum="go run mksysnum.go 'http://cvsweb.netbsd.org/bsdweb.cgi/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
netbsd_arm64)
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -netbsd"
mksysnum="go run mksysnum.go 'http://cvsweb.netbsd.org/bsdweb.cgi/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_386)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m32"
mksyscall="go run mksyscall.go -l32 -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_amd64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_arm)
mkasm="go run mkasm.go"
mkerrors="$mkerrors"
mksyscall="go run mksyscall.go -l32 -openbsd -arm -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_arm64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_mips64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_ppc64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_riscv64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
solaris_amd64)
mksyscall="go run mksyscall_solaris.go"
mkerrors="$mkerrors -m64"
mksysnum=
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
illumos_amd64)
mksyscall="go run mksyscall_solaris.go"
# Emptying mkerrors skips the zerrors step for this target.
mkerrors=
mksysnum=
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
*)
echo 'unrecognized $GOOS_$GOARCH: ' "$GOOSARCH" 1>&2
exit 1
;;
esac
# Emit one shell command line per enabled generator and pipe the list to
# $run: sh executes the commands, cat (selected by -n) only prints them.
(
if [ -n "$mkerrors" ]; then echo "$mkerrors |gofmt >$zerrors"; fi
# The single catch-all arm means this case runs for every GOOS.
case "$GOOS" in
*)
syscall_goos="syscall_$GOOS.go"
# The BSD-family targets additionally feed syscall_bsd.go to mksyscall.
case "$GOOS" in
darwin | dragonfly | freebsd | netbsd | openbsd)
syscall_goos="syscall_bsd.go $syscall_goos"
;;
esac
if [ -n "$mksyscall" ]; then
if [ "$GOOSARCH" == "aix_ppc64" ]; then
# aix/ppc64 script generates files instead of writing to stdout,
# so the generated files are formatted in place with gofmt -w.
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ;
elif [ "$GOOS" == "illumos" ]; then
# illumos code generation requires the -illumos switch
echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go";
# illumos implies solaris, so solaris code generation is also required
echo "$mksyscall -tags solaris,$GOARCH syscall_solaris.go syscall_solaris_$GOARCH.go |gofmt >zsyscall_solaris_$GOARCH.go";
else
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go";
fi
fi
esac
if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi
if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi
if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go | go run mkpost.go > ztypes_$GOOSARCH.go"; fi
if [ -n "$mkasm" ]; then echo "$mkasm $GOOS $GOARCH"; fi
) | $run
================================================
FILE: gossh/sys/unix/mkasm.go
================================================
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
// mkasm.go generates assembly trampolines to call library routines from Go.
// This program must be run after mksyscall.go.
package main
import (
"bytes"
"fmt"
"log"
"os"
"sort"
"strings"
)
// archPtrSize returns the pointer size in bytes used for arch: 4 for the
// 32-bit architectures and 8 for the 64-bit ones. An architecture that is
// not listed terminates the program through log.Fatalf.
func archPtrSize(arch string) int {
	if arch == "386" || arch == "arm" {
		return 4
	}
	for _, wide := range []string{"amd64", "arm64", "mips64", "ppc64", "riscv64"} {
		if arch == wide {
			return 8
		}
	}
	log.Fatalf("Unknown arch %q", arch)
	return 0 // not reached; log.Fatalf exits
}
// generateASMFile scans the files in inFileNames for lines of the form
//
//	var NAME_trampoline_addr uintptr
//
// and writes outFileName, an assembly file holding one trampoline plus its
// address symbol per distinct NAME, in order of first appearance. The file
// header records the program's command-line arguments. It returns the set of
// names found. A read or write failure terminates the program through
// log.Fatalf.
func generateASMFile(goos, arch string, inFileNames []string, outFileName string) map[string]bool {
	const (
		declPrefix = "var "
		declSuffix = "_trampoline_addr uintptr"
	)

	seen := map[string]bool{}
	var names []string // distinct trampoline names in first-seen order
	for _, path := range inFileNames {
		src, err := os.ReadFile(path)
		if err != nil {
			log.Fatalf("Failed to read file: %v", err)
		}
		for _, decl := range strings.Split(string(src), "\n") {
			if !(strings.HasPrefix(decl, declPrefix) && strings.HasSuffix(decl, declSuffix)) {
				continue
			}
			name := strings.TrimSuffix(strings.TrimPrefix(decl, declPrefix), declSuffix)
			if seen[name] {
				continue
			}
			seen[name] = true
			names = append(names, name)
		}
	}

	width := archPtrSize(arch)
	// openbsd/ppc64 trampolines CALL the routine and then RET; every other
	// target JMPs straight to it.
	callAndReturn := goos == "openbsd" && arch == "ppc64"

	var asm bytes.Buffer
	fmt.Fprintf(&asm, "// go run mkasm.go %s\n"+
		"// Code generated by the command above; DO NOT EDIT.\n"+
		"\n"+
		"#include \"textflag.h\"\n", strings.Join(os.Args[1:], " "))
	for _, name := range names {
		asm.WriteString("\nTEXT " + name + "_trampoline<>(SB),NOSPLIT,$0-0\n")
		if callAndReturn {
			asm.WriteString("\tCALL\t" + name + "(SB)\n\tRET\n")
		} else {
			asm.WriteString("\tJMP\t" + name + "(SB)\n")
		}
		fmt.Fprintf(&asm, "GLOBL\t·%s_trampoline_addr(SB), RODATA, $%d\n", name, width)
		fmt.Fprintf(&asm, "DATA\t·%s_trampoline_addr(SB)/%d, $%s_trampoline<>(SB)\n", name, width, name)
	}

	if err := os.WriteFile(outFileName, asm.Bytes(), 0644); err != nil {
		log.Fatalf("Failed to write assembly file %q: %v", outFileName, err)
	}
	return seen
}
// darwinTestTemplate is the skeleton of the generated darwin test file. Its
// three %s verbs receive, in order: the command-line arguments, the
// architecture, and the pre-rendered table entries.
const darwinTestTemplate = `// go run mkasm.go %s
// Code generated by the command above; DO NOT EDIT.
//go:build darwin
package unix
// All the _trampoline functions in zsyscall_darwin_%s.s.
var darwinTests = [...]darwinTest{
%s}
`

// writeDarwinTest writes fileName, a Go source file declaring the darwinTests
// table with one entry per key of trampolines, sorted by name. Each entry
// pairs the name with its "libc_" prefix removed and the matching
// NAME_trampoline_addr symbol. arch is substituted into the template's
// comment. A write failure terminates the program through log.Fatalf.
func writeDarwinTest(trampolines map[string]bool, fileName, arch string) {
	sortedTrampolines := make([]string, 0, len(trampolines))
	for fn := range trampolines {
		sortedTrampolines = append(sortedTrampolines, fn)
	}
	// Map iteration order is random; sort so the generated file is stable.
	sort.Strings(sortedTrampolines)

	var out bytes.Buffer
	const prefix = "libc_"
	for _, fn := range sortedTrampolines {
		// Format exactly once. Feeding already-formatted text back in as a
		// format string would reinterpret any '%' it contains.
		fmt.Fprintf(&out, "\t{%q, %s_trampoline_addr},\n", strings.TrimPrefix(fn, prefix), fn)
	}
	lines := out.String()

	out.Reset()
	fmt.Fprintf(&out, darwinTestTemplate, strings.Join(os.Args[1:], " "), arch, lines)
	if err := os.WriteFile(fileName, out.Bytes(), 0644); err != nil {
		log.Fatalf("Failed to write test file %q: %v", fileName, err)
	}
}
// main expects exactly two arguments, the target GOOS and GOARCH. It reads
// syscall_GOOS.go, syscall_GOOS_GOARCH.go and zsyscall_GOOS_GOARCH.go, writes
// the trampolines found there to zsyscall_GOOS_GOARCH.s, and for darwin also
// writes darwin_GOARCH_test.go.
func main() {
	if len(os.Args) != 3 {
		// Name the required arguments so the message is actionable.
		log.Fatalf("Usage: %s <goos> <arch>", os.Args[0])
	}
	goos, arch := os.Args[1], os.Args[2]

	syscallFilename := fmt.Sprintf("syscall_%s.go", goos)
	syscallArchFilename := fmt.Sprintf("syscall_%s_%s.go", goos, arch)
	zsyscallArchFilename := fmt.Sprintf("zsyscall_%s_%s.go", goos, arch)
	zsyscallASMFileName := fmt.Sprintf("zsyscall_%s_%s.s", goos, arch)

	inFileNames := []string{
		syscallFilename,
		syscallArchFilename,
		zsyscallArchFilename,
	}

	trampolines := generateASMFile(goos, arch, inFileNames, zsyscallASMFileName)

	if goos == "darwin" {
		writeDarwinTest(trampolines, fmt.Sprintf("darwin_%s_test.go", arch), arch)
	}
}
================================================
FILE: gossh/sys/unix/mkerrors.sh
================================================
#!/usr/bin/env bash
# Copyright 2009 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
# Generate Go code listing errors and other #defined constant
# values (ENAMETOOLONG etc.), by asking the preprocessor
# about the definitions.
# Pin the locale to C for everything run below (presumably so that tool
# output does not vary with the caller's locale settings).
unset LANG
export LC_ALL=C
export LC_CTYPE=C
# Both GOARCH and GOOS must be set; otherwise report on stderr and exit 1.
if test -z "$GOARCH" -o -z "$GOOS"; then
echo 1>&2 "GOARCH or GOOS not defined in environment"
exit 1
fi
# Check that we are using the new build system if we should
# For linux this script only proceeds when GOLANG_SYS_BUILD is "docker".
if [[ "$GOOS" = "linux" ]] && [[ "$GOLANG_SYS_BUILD" != "docker" ]]; then
echo 1>&2 "In the Docker based build system, mkerrors should not be called directly."
echo 1>&2 "See README.md"
exit 1
fi
# Choose the C compiler, honouring a caller-supplied CC: the fallback is gcc
# on aix and cc everywhere else.
if [[ "$GOOS" = "aix" ]]; then
CC=${CC:-gcc}
else
CC=${CC:-cc}
fi
if [[ "$GOOS" = "solaris" ]]; then
# Assumes GNU versions of utilities in PATH.
export PATH=/usr/gnu/bin:$PATH
fi
# Host system name as reported by uname(1).
uname=$(uname)
# C preprocessor text for AIX, stored as a single-quoted string (presumably
# selected by the uname value captured earlier; confirm against the rest of
# the script).
# NOTE(review): every '#include' line below has no header name in this copy
# of the file; the names look to have been lost in transit. Restore them from
# the upstream script before relying on this block.
includes_AIX='
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define AF_LOCAL AF_UNIX
'
# C preprocessor text for Darwin, stored as a single-quoted string: feature
# macros first, then the header list, then a fallback definition of
# TIOCREMOTE.
# NOTE(review): every '#include' line below has no header name in this copy
# of the file; the names look to have been lost in transit. Restore them from
# the upstream script before relying on this block.
includes_Darwin='
#define _DARWIN_C_SOURCE
#define KERNEL 1
#define _DARWIN_USE_64_BIT_INODE
#define __APPLE_USE_RFC_3542
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// for backwards compatibility because moved TIOCREMOTE to Kernel.framework after MacOSX12.0.sdk.
#define TIOCREMOTE 0x80047469
'
# C preprocessor text for DragonFly, stored as a single-quoted string.
# NOTE(review): every '#include' line below has no header name in this copy
# of the file; the names look to have been lost in transit. Restore them from
# the upstream script before relying on this block.
includes_DragonFly='
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
'
includes_FreeBSD='
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include