Merge pull request #2293 from TomSweeneyRedHat/dev/tsweeney/cve-jose-1.9

[release-1.9] Bump ocicrypt and go-jose CVE-2024-28180
This commit is contained in:
Miloslav Trmač 2024-05-02 21:14:39 +02:00 committed by GitHub
commit 76adb508ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
739 changed files with 63563 additions and 22122 deletions

View File

@ -74,33 +74,8 @@ doccheck_task:
"${SKOPEO_PATH}/${SCRIPT_BASE}/runner.sh" build "${SKOPEO_PATH}/${SCRIPT_BASE}/runner.sh" build
"${SKOPEO_PATH}/${SCRIPT_BASE}/runner.sh" doccheck "${SKOPEO_PATH}/${SCRIPT_BASE}/runner.sh" doccheck
osx_task:
# Run for regular PRs and those with [CI:BUILD] but not [CI:DOCS]
only_if: &not_docs_multiarch >-
$CIRRUS_CHANGE_TITLE !=~ '.*CI:DOCS.*' &&
$CIRRUS_CRON != 'multiarch'
depends_on:
- validate
macos_instance:
image: catalina-xcode
setup_script: |
# /usr/local/opt/go@1.18 will be populated by (brew install go@1.18) below
export PATH=$GOPATH/bin:/usr/local/opt/go@1.18/bin:$PATH
brew update
brew install gpgme go@1.18 go-md2man
go install golang.org/x/lint/golint@latest
test_script: |
export PATH=$GOPATH/bin:/usr/local/opt/go@1.18/bin:$PATH
go version
go env
make validate-local test-unit-local bin/skopeo
sudo make install
/usr/local/bin/skopeo -v
cross_task: cross_task:
alias: cross alias: cross
only_if: *not_docs_multiarch
depends_on: depends_on:
- validate - validate
gce_instance: &standardvm gce_instance: &standardvm
@ -241,7 +216,6 @@ success_task:
depends_on: depends_on:
- validate - validate
- doccheck - doccheck
- osx
- cross - cross
- test_skopeo - test_skopeo
- image_build - image_build

View File

@ -242,12 +242,12 @@ test-unit-local: bin/skopeo
$(GO) test $(MOD_VENDOR) -tags "$(BUILDTAGS)" $$($(GO) list $(MOD_VENDOR) -tags "$(BUILDTAGS)" -e ./... | grep -v '^github\.com/containers/skopeo/\(integration\|vendor/.*\)$$') $(GO) test $(MOD_VENDOR) -tags "$(BUILDTAGS)" $$($(GO) list $(MOD_VENDOR) -tags "$(BUILDTAGS)" -e ./... | grep -v '^github\.com/containers/skopeo/\(integration\|vendor/.*\)$$')
vendor: vendor:
$(GO) mod tidy $(GO) mod tidy -compat=1.17
$(GO) mod vendor $(GO) mod vendor
$(GO) mod verify $(GO) mod verify
vendor-in-container: vendor-in-container:
podman run --privileged --rm --env HOME=/root -v $(CURDIR):/src -w /src quay.io/libpod/golang:1.16 $(MAKE) vendor podman run --privileged --rm --env HOME=/root -v $(CURDIR):/src -w /src golang $(MAKE) vendor
# CAUTION: This is not a replacement for RPMs provided by your distro. # CAUTION: This is not a replacement for RPMs provided by your distro.
# Only intended to build and test the latest unreleased changes. # Only intended to build and test the latest unreleased changes.

62
go.mod
View File

@ -7,57 +7,57 @@ require (
github.com/containers/image/v5 v5.22.1 github.com/containers/image/v5 v5.22.1
github.com/containers/ocicrypt v1.1.5 github.com/containers/ocicrypt v1.1.5
github.com/containers/storage v1.42.0 github.com/containers/storage v1.42.0
github.com/docker/docker v20.10.17+incompatible github.com/docker/docker v20.10.20+incompatible
github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 github.com/opencontainers/image-spec v1.1.0-rc2
github.com/opencontainers/image-tools v1.0.0-rc3 github.com/opencontainers/image-tools v1.0.0-rc3
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.5.0 github.com/spf13/cobra v1.6.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.1
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 golang.org/x/term v0.5.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
) )
require ( require (
github.com/BurntSushi/toml v1.2.0 // indirect github.com/BurntSushi/toml v1.2.0 // indirect
github.com/Microsoft/go-winio v0.5.2 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/Microsoft/hcsshim v0.9.3 // indirect github.com/Microsoft/hcsshim v0.9.3 // indirect
github.com/VividCortex/ewma v1.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/containerd/cgroups v1.0.3 // indirect github.com/containerd/cgroups v1.0.3 // indirect
github.com/containerd/stargz-snapshotter/estargz v0.12.0 // indirect github.com/containerd/stargz-snapshotter/estargz v0.12.1 // indirect
github.com/containers/libtrust v0.0.0-20200511145503-9c3a6c22cd9a // indirect github.com/containers/libtrust v0.0.0-20200511145503-9c3a6c22cd9a // indirect
github.com/cyphar/filepath-securejoin v0.2.3 // indirect github.com/cyphar/filepath-securejoin v0.2.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/distribution v2.8.1+incompatible // indirect github.com/docker/distribution v2.8.1+incompatible // indirect
github.com/docker/docker-credential-helpers v0.6.4 // indirect github.com/docker/docker-credential-helpers v0.7.0 // indirect
github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-metrics v0.0.1 // indirect github.com/docker/go-metrics v0.0.1 // indirect
github.com/docker/go-units v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
github.com/ghodss/yaml v1.0.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-containerregistry v0.10.0 // indirect github.com/google/go-containerregistry v0.13.0 // indirect
github.com/google/go-intervals v0.0.2 // indirect github.com/google/go-intervals v0.0.2 // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect github.com/gorilla/mux v1.8.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/imdario/mergo v0.3.13 // indirect github.com/imdario/mergo v0.3.13 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.9 // indirect github.com/klauspost/compress v1.15.11 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect github.com/klauspost/pgzip v1.2.5 // indirect
github.com/kr/pretty v0.2.1 // indirect github.com/kr/pretty v0.2.1 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/letsencrypt/boulder v0.0.0-20220331220046-b23ab962616e // indirect github.com/letsencrypt/boulder v0.0.0-20230130200452-c091e64aa391 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/mattn/go-shellwords v1.0.12 // indirect github.com/mattn/go-shellwords v1.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
@ -73,17 +73,17 @@ require (
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/proglottis/gpgme v0.1.3 // indirect github.com/proglottis/gpgme v0.1.3 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect github.com/prometheus/client_golang v1.13.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday v2.0.0+incompatible // indirect github.com/russross/blackfriday v2.0.0+incompatible // indirect
github.com/sigstore/sigstore v1.3.1-0.20220629021053-b95fc0d626c1 // indirect github.com/sigstore/sigstore v1.5.2 // indirect
github.com/stefanberger/go-pkcs11uri v0.0.0-20201008174630-78d3cae3a980 // indirect github.com/stefanberger/go-pkcs11uri v0.0.0-20201008174630-78d3cae3a980 // indirect
github.com/sylabs/sif/v2 v2.7.1 // indirect github.com/sylabs/sif/v2 v2.7.1 // indirect
github.com/tchap/go-patricia v2.3.0+incompatible // indirect github.com/tchap/go-patricia v2.3.0+incompatible // indirect
github.com/theupdateframework/go-tuf v0.3.1 // indirect github.com/theupdateframework/go-tuf v0.5.2-0.20220930112810-3890c1e7ace4 // indirect
github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect github.com/ulikunitz/xz v0.5.10 // indirect
github.com/vbatts/tar-split v0.11.2 // indirect github.com/vbatts/tar-split v0.11.2 // indirect
@ -93,16 +93,18 @@ require (
github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect
go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/bbolt v1.3.6 // indirect
go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 // indirect go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 // indirect
go.opencensus.io v0.23.0 // indirect go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel/trace v1.3.0 // indirect golang.org/x/crypto v0.6.0 // indirect
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838 // indirect golang.org/x/mod v0.6.0 // indirect
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect golang.org/x/net v0.7.0 // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect golang.org/x/sys v0.5.0 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.7.0 // indirect
google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f // indirect golang.org/x/tools v0.2.0 // indirect
google.golang.org/grpc v1.47.0 // indirect google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc // indirect
google.golang.org/protobuf v1.28.0 // indirect google.golang.org/grpc v1.53.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

1270
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -31,7 +31,8 @@ const (
v2DockerRegistryURL = "localhost:5555" // Update also policy.json v2DockerRegistryURL = "localhost:5555" // Update also policy.json
v2s1DockerRegistryURL = "localhost:5556" v2s1DockerRegistryURL = "localhost:5556"
knownWindowsOnlyImage = "docker://mcr.microsoft.com/windows/nanoserver:1909" knownWindowsOnlyImage = "docker://mcr.microsoft.com/windows/nanoserver:1909"
knownListImage = "docker://registry.fedoraproject.org/fedora-minimal" // could have either ":latest" or "@sha256:..." appended knownListImageRepo = "docker://registry.fedoraproject.org/fedora-minimal"
knownListImage = knownListImageRepo + ":38"
) )
type CopySuite struct { type CopySuite struct {
@ -196,8 +197,8 @@ func (s *CopySuite) TestCopyWithManifestListDigest(c *check.C) {
manifestDigest, err := manifest.Digest([]byte(m)) manifestDigest, err := manifest.Digest([]byte(m))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
digest := manifestDigest.String() digest := manifestDigest.String()
assertSkopeoSucceeds(c, "", "copy", knownListImage+"@"+digest, "dir:"+dir1) assertSkopeoSucceeds(c, "", "copy", knownListImageRepo+"@"+digest, "dir:"+dir1)
assertSkopeoSucceeds(c, "", "copy", "--multi-arch=all", knownListImage+"@"+digest, "dir:"+dir2) assertSkopeoSucceeds(c, "", "copy", "--multi-arch=all", knownListImageRepo+"@"+digest, "dir:"+dir2)
assertSkopeoSucceeds(c, "", "copy", "dir:"+dir1, "oci:"+oci1) assertSkopeoSucceeds(c, "", "copy", "dir:"+dir1, "oci:"+oci1)
assertSkopeoSucceeds(c, "", "copy", "dir:"+dir2, "oci:"+oci2) assertSkopeoSucceeds(c, "", "copy", "dir:"+dir2, "oci:"+oci2)
out := combinedOutputOfCommand(c, "diff", "-urN", oci1, oci2) out := combinedOutputOfCommand(c, "diff", "-urN", oci1, oci2)
@ -224,9 +225,9 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigest(c *check.C) {
manifestDigest, err := manifest.Digest([]byte(m)) manifestDigest, err := manifest.Digest([]byte(m))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
digest := manifestDigest.String() digest := manifestDigest.String()
assertSkopeoSucceeds(c, "", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoSucceeds(c, "", "copy", "containers-storage:"+storage+"test@"+digest, "dir:"+dir1) assertSkopeoSucceeds(c, "", "copy", "containers-storage:"+storage+"test@"+digest, "dir:"+dir1)
assertSkopeoSucceeds(c, "", "copy", knownListImage+"@"+digest, "dir:"+dir2) assertSkopeoSucceeds(c, "", "copy", knownListImageRepo+"@"+digest, "dir:"+dir2)
runDecompressDirs(c, "", dir1, dir2) runDecompressDirs(c, "", dir1, dir2)
assertDirImagesAreEqual(c, dir1, dir2) assertDirImagesAreEqual(c, dir1, dir2)
} }
@ -240,9 +241,9 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArches(c *check
manifestDigest, err := manifest.Digest([]byte(m)) manifestDigest, err := manifest.Digest([]byte(m))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
digest := manifestDigest.String() digest := manifestDigest.String()
assertSkopeoSucceeds(c, "", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoSucceeds(c, "", "copy", "containers-storage:"+storage+"test@"+digest, "dir:"+dir1) assertSkopeoSucceeds(c, "", "copy", "containers-storage:"+storage+"test@"+digest, "dir:"+dir1)
assertSkopeoSucceeds(c, "", "copy", knownListImage+"@"+digest, "dir:"+dir2) assertSkopeoSucceeds(c, "", "copy", knownListImageRepo+"@"+digest, "dir:"+dir2)
runDecompressDirs(c, "", dir1, dir2) runDecompressDirs(c, "", dir1, dir2)
assertDirImagesAreEqual(c, dir1, dir2) assertDirImagesAreEqual(c, dir1, dir2)
} }
@ -256,8 +257,8 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArchesBothUseLi
digest := manifestDigest.String() digest := manifestDigest.String()
_, err = manifest.ListFromBlob([]byte(m), manifest.GuessMIMEType([]byte(m))) _, err = manifest.ListFromBlob([]byte(m), manifest.GuessMIMEType([]byte(m)))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "containers-storage:"+storage+"test@"+digest) assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "containers-storage:"+storage+"test@"+digest)
assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest) assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest)
i2 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=arm64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest) i2 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=arm64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest)
@ -280,8 +281,8 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArchesFirstUses
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"}) arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImage+"@"+arm64Instance.String(), "containers-storage:"+storage+"test@"+arm64Instance.String()) assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImageRepo+"@"+arm64Instance.String(), "containers-storage:"+storage+"test@"+arm64Instance.String())
i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest) i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest)
var image1 imgspecv1.Image var image1 imgspecv1.Image
err = json.Unmarshal([]byte(i1), &image1) err = json.Unmarshal([]byte(i1), &image1)
@ -314,8 +315,8 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArchesSecondUse
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"}) arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage+"@"+amd64Instance.String(), "containers-storage:"+storage+"test@"+amd64Instance.String()) assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImageRepo+"@"+amd64Instance.String(), "containers-storage:"+storage+"test@"+amd64Instance.String())
assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+amd64Instance.String()) i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+amd64Instance.String())
var image1 imgspecv1.Image var image1 imgspecv1.Image
err = json.Unmarshal([]byte(i1), &image1) err = json.Unmarshal([]byte(i1), &image1)
@ -348,9 +349,9 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArchesThirdUses
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"}) arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage+"@"+amd64Instance.String(), "containers-storage:"+storage+"test@"+amd64Instance.String()) assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImageRepo+"@"+amd64Instance.String(), "containers-storage:"+storage+"test@"+amd64Instance.String())
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest) assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest)
i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+amd64Instance.String()) i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+amd64Instance.String())
var image1 imgspecv1.Image var image1 imgspecv1.Image
@ -383,7 +384,7 @@ func (s *CopySuite) TestCopyWithManifestListStorageDigestMultipleArchesTagAndDig
arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"}) arm64Instance, err := list.ChooseInstance(&types.SystemContext{ArchitectureChoice: "arm64"})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage, "containers-storage:"+storage+"test:latest") assertSkopeoSucceeds(c, "", "--override-arch=amd64", "copy", knownListImage, "containers-storage:"+storage+"test:latest")
assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImage+"@"+digest, "containers-storage:"+storage+"test@"+digest) assertSkopeoSucceeds(c, "", "--override-arch=arm64", "copy", knownListImageRepo+"@"+digest, "containers-storage:"+storage+"test@"+digest)
assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest) assertSkopeoFails(c, `.*reading manifest for image instance.*does not exist.*`, "--override-arch=amd64", "inspect", "--config", "containers-storage:"+storage+"test@"+digest)
i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=arm64", "inspect", "--config", "containers-storage:"+storage+"test:latest") i1 := combinedOutputOfCommand(c, skopeoBinary, "--override-arch=arm64", "inspect", "--config", "containers-storage:"+storage+"test:latest")
var image1 imgspecv1.Image var image1 imgspecv1.Image

1
vendor/github.com/Microsoft/go-winio/.gitattributes generated vendored Normal file
View File

@ -0,0 +1 @@
* text=auto eol=lf

View File

@ -1 +1,10 @@
.vscode/
*.exe *.exe
# testing
testdata
# go workspaces
go.work
go.work.sum

144
vendor/github.com/Microsoft/go-winio/.golangci.yml generated vendored Normal file
View File

@ -0,0 +1,144 @@
run:
skip-dirs:
- pkg/etw/sample
linters:
enable:
# style
- containedctx # struct contains a context
- dupl # duplicate code
- errname # erorrs are named correctly
- goconst # strings that should be constants
- godot # comments end in a period
- misspell
- nolintlint # "//nolint" directives are properly explained
- revive # golint replacement
- stylecheck # golint replacement, less configurable than revive
- unconvert # unnecessary conversions
- wastedassign
# bugs, performance, unused, etc ...
- contextcheck # function uses a non-inherited context
- errorlint # errors not wrapped for 1.13
- exhaustive # check exhaustiveness of enum switch statements
- gofmt # files are gofmt'ed
- gosec # security
- nestif # deeply nested ifs
- nilerr # returns nil even with non-nil error
- prealloc # slices that can be pre-allocated
- structcheck # unused struct fields
- unparam # unused function params
issues:
exclude-rules:
# err is very often shadowed in nested scopes
- linters:
- govet
text: '^shadow: declaration of "err" shadows declaration'
# ignore long lines for skip autogen directives
- linters:
- revive
text: "^line-length-limit: "
source: "^//(go:generate|sys) "
# allow unjustified ignores of error checks in defer statements
- linters:
- nolintlint
text: "^directive `//nolint:errcheck` should provide explanation"
source: '^\s*defer '
# allow unjustified ignores of error lints for io.EOF
- linters:
- nolintlint
text: "^directive `//nolint:errorlint` should provide explanation"
source: '[=|!]= io.EOF'
linters-settings:
govet:
enable-all: true
disable:
# struct order is often for Win32 compat
# also, ignore pointer bytes/GC issues for now until performance becomes an issue
- fieldalignment
check-shadowing: true
nolintlint:
allow-leading-space: false
require-explanation: true
require-specific: true
revive:
# revive is more configurable than static check, so likely the preferred alternative to static-check
# (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997)
enable-all-rules:
true
# https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md
rules:
# rules with required arguments
- name: argument-limit
disabled: true
- name: banned-characters
disabled: true
- name: cognitive-complexity
disabled: true
- name: cyclomatic
disabled: true
- name: file-header
disabled: true
- name: function-length
disabled: true
- name: function-result-limit
disabled: true
- name: max-public-structs
disabled: true
# geneally annoying rules
- name: add-constant # complains about any and all strings and integers
disabled: true
- name: confusing-naming # we frequently use "Foo()" and "foo()" together
disabled: true
- name: flag-parameter # excessive, and a common idiom we use
disabled: true
# general config
- name: line-length-limit
arguments:
- 140
- name: var-naming
arguments:
- []
- - CID
- CRI
- CTRD
- DACL
- DLL
- DOS
- ETW
- FSCTL
- GCS
- GMSA
- HCS
- HV
- IO
- LCOW
- LDAP
- LPAC
- LTSC
- MMIO
- NT
- OCI
- PMEM
- PWSH
- RX
- SACl
- SID
- SMB
- TX
- VHD
- VHDX
- VMID
- VPCI
- WCOW
- WIM
stylecheck:
checks:
- "all"
- "-ST1003" # use revive's var naming

View File

@ -13,16 +13,60 @@ Please see the LICENSE file for licensing information.
## Contributing ## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) This project welcomes contributions and suggestions.
declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that
you have the right to, and actually do, grant us the rights to use your contribution.
For details, visit [Microsoft CLA](https://cla.microsoft.com).
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR When you submit a pull request, a CLA-bot will automatically determine whether you need to
appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. provide a CLA and decorate the PR appropriately (e.g., label, comment).
Simply follow the instructions provided by the bot.
You will only need to do this once across all repos using our CLA.
We also require that contributors sign their commits using git commit -s or git commit --signoff to certify they either authored the work themselves Additionally, the pull request pipeline requires the following steps to be performed before
or otherwise have permission to use it in this project. Please see https://developercertificate.org/ for more info, as well as to make sure that you can mergining.
attest to the rules listed. Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
### Code Sign-Off
We require that contributors sign their commits using [`git commit --signoff`][git-commit-s]
to certify they either authored the work themselves or otherwise have permission to use it in this project.
A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s].
Please see [the developer certificate](https://developercertificate.org) for more info,
as well as to make sure that you can attest to the rules listed.
Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
### Linting
Code must pass a linting stage, which uses [`golangci-lint`][lint].
The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run
automatically with VSCode by adding the following to your workspace or folder settings:
```json
"go.lintTool": "golangci-lint",
"go.lintOnSave": "package",
```
Additional editor [integrations options are also available][lint-ide].
Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root:
```shell
# use . or specify a path to only lint a package
# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0"
> golangci-lint run ./...
```
### Go Generate
The pipeline checks that auto-generated code, via `go generate`, are up to date.
This can be done for the entire repo:
```shell
> go generate ./...
```
## Code of Conduct ## Code of Conduct
@ -30,8 +74,16 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Special Thanks ## Special Thanks
Thanks to natefinch for the inspiration for this library. See https://github.com/natefinch/npipe
for another named pipe implementation. Thanks to [natefinch][natefinch] for the inspiration for this library.
See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation.
[lint]: https://golangci-lint.run/
[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration
[lint-install]: https://golangci-lint.run/usage/install/#local-installation
[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s
[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff
[natefinch]: https://github.com/natefinch

41
vendor/github.com/Microsoft/go-winio/SECURITY.md generated vendored Normal file
View File

@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package winio package winio
@ -7,11 +8,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"runtime" "runtime"
"syscall" "syscall"
"unicode/utf16" "unicode/utf16"
"golang.org/x/sys/windows"
) )
//sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead //sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
@ -24,7 +26,7 @@ const (
BackupAlternateData BackupAlternateData
BackupLink BackupLink
BackupPropertyData BackupPropertyData
BackupObjectId BackupObjectId //revive:disable-line:var-naming ID, not Id
BackupReparseData BackupReparseData
BackupSparseBlock BackupSparseBlock
BackupTxfsData BackupTxfsData
@ -34,14 +36,16 @@ const (
StreamSparseAttributes = uint32(8) StreamSparseAttributes = uint32(8)
) )
//nolint:revive // var-naming: ALL_CAPS
const ( const (
WRITE_DAC = 0x40000 WRITE_DAC = windows.WRITE_DAC
WRITE_OWNER = 0x80000 WRITE_OWNER = windows.WRITE_OWNER
ACCESS_SYSTEM_SECURITY = 0x1000000 ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY
) )
// BackupHeader represents a backup stream of a file. // BackupHeader represents a backup stream of a file.
type BackupHeader struct { type BackupHeader struct {
//revive:disable-next-line:var-naming ID, not Id
Id uint32 // The backup stream ID Id uint32 // The backup stream ID
Attributes uint32 // Stream attributes Attributes uint32 // Stream attributes
Size int64 // The size of the stream in bytes Size int64 // The size of the stream in bytes
@ -49,8 +53,8 @@ type BackupHeader struct {
Offset int64 // The offset of the stream in the file (for BackupSparseBlock only). Offset int64 // The offset of the stream in the file (for BackupSparseBlock only).
} }
type win32StreamId struct { type win32StreamID struct {
StreamId uint32 StreamID uint32
Attributes uint32 Attributes uint32
Size uint64 Size uint64
NameSize uint32 NameSize uint32
@ -71,7 +75,7 @@ func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if // Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
// it was not completely read. // it was not completely read.
func (r *BackupStreamReader) Next() (*BackupHeader, error) { func (r *BackupStreamReader) Next() (*BackupHeader, error) {
if r.bytesLeft > 0 { if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this
if s, ok := r.r.(io.Seeker); ok { if s, ok := r.r.(io.Seeker); ok {
// Make sure Seek on io.SeekCurrent sometimes succeeds // Make sure Seek on io.SeekCurrent sometimes succeeds
// before trying the actual seek. // before trying the actual seek.
@ -82,16 +86,16 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) {
r.bytesLeft = 0 r.bytesLeft = 0
} }
} }
if _, err := io.Copy(ioutil.Discard, r); err != nil { if _, err := io.Copy(io.Discard, r); err != nil {
return nil, err return nil, err
} }
} }
var wsi win32StreamId var wsi win32StreamID
if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil { if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
return nil, err return nil, err
} }
hdr := &BackupHeader{ hdr := &BackupHeader{
Id: wsi.StreamId, Id: wsi.StreamID,
Attributes: wsi.Attributes, Attributes: wsi.Attributes,
Size: int64(wsi.Size), Size: int64(wsi.Size),
} }
@ -102,7 +106,7 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) {
} }
hdr.Name = syscall.UTF16ToString(name) hdr.Name = syscall.UTF16ToString(name)
} }
if wsi.StreamId == BackupSparseBlock { if wsi.StreamID == BackupSparseBlock {
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil { if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
return nil, err return nil, err
} }
@ -147,8 +151,8 @@ func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
return fmt.Errorf("missing %d bytes", w.bytesLeft) return fmt.Errorf("missing %d bytes", w.bytesLeft)
} }
name := utf16.Encode([]rune(hdr.Name)) name := utf16.Encode([]rune(hdr.Name))
wsi := win32StreamId{ wsi := win32StreamID{
StreamId: hdr.Id, StreamID: hdr.Id,
Attributes: hdr.Attributes, Attributes: hdr.Attributes,
Size: uint64(hdr.Size), Size: uint64(hdr.Size),
NameSize: uint32(len(name) * 2), NameSize: uint32(len(name) * 2),
@ -203,7 +207,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) {
var bytesRead uint32 var bytesRead uint32
err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx) err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
if err != nil { if err != nil {
return 0, &os.PathError{"BackupRead", r.f.Name(), err} return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
} }
runtime.KeepAlive(r.f) runtime.KeepAlive(r.f)
if bytesRead == 0 { if bytesRead == 0 {
@ -216,7 +220,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) {
// the underlying file. // the underlying file.
func (r *BackupFileReader) Close() error { func (r *BackupFileReader) Close() error {
if r.ctx != 0 { if r.ctx != 0 {
backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx) _ = backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
runtime.KeepAlive(r.f) runtime.KeepAlive(r.f)
r.ctx = 0 r.ctx = 0
} }
@ -242,7 +246,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) {
var bytesWritten uint32 var bytesWritten uint32
err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx) err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
if err != nil { if err != nil {
return 0, &os.PathError{"BackupWrite", w.f.Name(), err} return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
} }
runtime.KeepAlive(w.f) runtime.KeepAlive(w.f)
if int(bytesWritten) != len(b) { if int(bytesWritten) != len(b) {
@ -255,7 +259,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) {
// close the underlying file. // close the underlying file.
func (w *BackupFileWriter) Close() error { func (w *BackupFileWriter) Close() error {
if w.ctx != 0 { if w.ctx != 0 {
backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx) _ = backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
runtime.KeepAlive(w.f) runtime.KeepAlive(w.f)
w.ctx = 0 w.ctx = 0
} }
@ -271,7 +275,13 @@ func OpenForBackup(path string, access uint32, share uint32, createmode uint32)
if err != nil { if err != nil {
return nil, err return nil, err
} }
h, err := syscall.CreateFile(&winPath[0], access, share, nil, createmode, syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT, 0) h, err := syscall.CreateFile(&winPath[0],
access,
share,
nil,
createmode,
syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT,
0)
if err != nil { if err != nil {
err = &os.PathError{Op: "open", Path: path, Err: err} err = &os.PathError{Op: "open", Path: path, Err: err}
return nil, err return nil, err

View File

@ -1,4 +1,3 @@
// +build !windows
// This file only exists to allow go get on non-Windows platforms. // This file only exists to allow go get on non-Windows platforms.
package backuptar package backuptar

View File

@ -1,3 +1,5 @@
//go:build windows
package backuptar package backuptar
import ( import (

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package backuptar package backuptar
@ -7,7 +8,6 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -18,17 +18,18 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//nolint:deadcode,varcheck // keep unused constants for potential future use
const ( const (
c_ISUID = 04000 // Set uid cISUID = 0004000 // Set uid
c_ISGID = 02000 // Set gid cISGID = 0002000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit) cISVTX = 0001000 // Save text (sticky bit)
c_ISDIR = 040000 // Directory cISDIR = 0040000 // Directory
c_ISFIFO = 010000 // FIFO cISFIFO = 0010000 // FIFO
c_ISREG = 0100000 // Regular file cISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link cISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file cISBLK = 0060000 // Block special file
c_ISCHR = 020000 // Character special file cISCHR = 0020000 // Character special file
c_ISSOCK = 0140000 // Socket cISSOCK = 0140000 // Socket
) )
const ( const (
@ -44,7 +45,7 @@ const (
// zeroReader is an io.Reader that always returns 0s. // zeroReader is an io.Reader that always returns 0s.
type zeroReader struct{} type zeroReader struct{}
func (zr zeroReader) Read(b []byte) (int, error) { func (zeroReader) Read(b []byte) (int, error) {
for i := range b { for i := range b {
b[i] = 0 b[i] = 0
} }
@ -55,7 +56,7 @@ func copySparse(t *tar.Writer, br *winio.BackupStreamReader) error {
curOffset := int64(0) curOffset := int64(0)
for { for {
bhdr, err := br.Next() bhdr, err := br.Next()
if err == io.EOF { if err == io.EOF { //nolint:errorlint
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
if err != nil { if err != nil {
@ -71,8 +72,8 @@ func copySparse(t *tar.Writer, br *winio.BackupStreamReader) error {
} }
// archive/tar does not support writing sparse files // archive/tar does not support writing sparse files
// so just write zeroes to catch up to the current offset. // so just write zeroes to catch up to the current offset.
if _, err := io.CopyN(t, zeroReader{}, bhdr.Offset-curOffset); err != nil { if _, err = io.CopyN(t, zeroReader{}, bhdr.Offset-curOffset); err != nil {
return fmt.Errorf("seek to offset %d: %s", bhdr.Offset, err) return fmt.Errorf("seek to offset %d: %w", bhdr.Offset, err)
} }
if bhdr.Size == 0 { if bhdr.Size == 0 {
// A sparse block with size = 0 is used to mark the end of the sparse blocks. // A sparse block with size = 0 is used to mark the end of the sparse blocks.
@ -106,7 +107,7 @@ func BasicInfoHeader(name string, size int64, fileInfo *winio.FileBasicInfo) *ta
hdr.PAXRecords[hdrCreationTime] = formatPAXTime(time.Unix(0, fileInfo.CreationTime.Nanoseconds())) hdr.PAXRecords[hdrCreationTime] = formatPAXTime(time.Unix(0, fileInfo.CreationTime.Nanoseconds()))
if (fileInfo.FileAttributes & syscall.FILE_ATTRIBUTE_DIRECTORY) != 0 { if (fileInfo.FileAttributes & syscall.FILE_ATTRIBUTE_DIRECTORY) != 0 {
hdr.Mode |= c_ISDIR hdr.Mode |= cISDIR
hdr.Size = 0 hdr.Size = 0
hdr.Typeflag = tar.TypeDir hdr.Typeflag = tar.TypeDir
} }
@ -116,32 +117,29 @@ func BasicInfoHeader(name string, size int64, fileInfo *winio.FileBasicInfo) *ta
// SecurityDescriptorFromTarHeader reads the SDDL associated with the header of the current file // SecurityDescriptorFromTarHeader reads the SDDL associated with the header of the current file
// from the tar header and returns the security descriptor into a byte slice. // from the tar header and returns the security descriptor into a byte slice.
func SecurityDescriptorFromTarHeader(hdr *tar.Header) ([]byte, error) { func SecurityDescriptorFromTarHeader(hdr *tar.Header) ([]byte, error) {
// Maintaining old SDDL-based behavior for backward
// compatibility. All new tar headers written by this library
// will have raw binary for the security descriptor.
var sd []byte
var err error
if sddl, ok := hdr.PAXRecords[hdrSecurityDescriptor]; ok {
sd, err = winio.SddlToSecurityDescriptor(sddl)
if err != nil {
return nil, err
}
}
if sdraw, ok := hdr.PAXRecords[hdrRawSecurityDescriptor]; ok { if sdraw, ok := hdr.PAXRecords[hdrRawSecurityDescriptor]; ok {
sd, err = base64.StdEncoding.DecodeString(sdraw) sd, err := base64.StdEncoding.DecodeString(sdraw)
if err != nil { if err != nil {
// Not returning sd as-is in the error-case, as base64.DecodeString
// may return partially decoded data (not nil or empty slice) in case
// of a failure: https://github.com/golang/go/blob/go1.17.7/src/encoding/base64/base64.go#L382-L387
return nil, err return nil, err
} }
}
return sd, nil return sd, nil
} }
// Maintaining old SDDL-based behavior for backward compatibility. All new
// tar headers written by this library will have raw binary for the security
// descriptor.
if sddl, ok := hdr.PAXRecords[hdrSecurityDescriptor]; ok {
return winio.SddlToSecurityDescriptor(sddl)
}
return nil, nil
}
// ExtendedAttributesFromTarHeader reads the EAs associated with the header of the // ExtendedAttributesFromTarHeader reads the EAs associated with the header of the
// current file from the tar header and returns it as a byte slice. // current file from the tar header and returns it as a byte slice.
func ExtendedAttributesFromTarHeader(hdr *tar.Header) ([]byte, error) { func ExtendedAttributesFromTarHeader(hdr *tar.Header) ([]byte, error) {
var eas []winio.ExtendedAttribute var eas []winio.ExtendedAttribute //nolint:prealloc // len(eas) <= len(hdr.PAXRecords); prealloc is wasteful
var eadata []byte
var err error
for k, v := range hdr.PAXRecords { for k, v := range hdr.PAXRecords {
if !strings.HasPrefix(k, hdrEaPrefix) { if !strings.HasPrefix(k, hdrEaPrefix) {
continue continue
@ -155,13 +153,15 @@ func ExtendedAttributesFromTarHeader(hdr *tar.Header) ([]byte, error) {
Value: data, Value: data,
}) })
} }
var eaData []byte
var err error
if len(eas) != 0 { if len(eas) != 0 {
eadata, err = winio.EncodeExtendedAttributes(eas) eaData, err = winio.EncodeExtendedAttributes(eas)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return eadata, nil return eaData, nil
} }
// EncodeReparsePointFromTarHeader reads the ReparsePoint structure from the tar header // EncodeReparsePointFromTarHeader reads the ReparsePoint structure from the tar header
@ -182,11 +182,9 @@ func EncodeReparsePointFromTarHeader(hdr *tar.Header) []byte {
// //
// The additional Win32 metadata is: // The additional Win32 metadata is:
// //
// MSWINDOWS.fileattr: The Win32 file attributes, as a decimal value // - MSWINDOWS.fileattr: The Win32 file attributes, as a decimal value
// // - MSWINDOWS.rawsd: The Win32 security descriptor, in raw binary format
// MSWINDOWS.rawsd: The Win32 security descriptor, in raw binary format // - MSWINDOWS.mountpoint: If present, this is a mount point and not a symlink, even though the type is '2' (symlink)
//
// MSWINDOWS.mountpoint: If present, this is a mount point and not a symlink, even though the type is '2' (symlink)
func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size int64, fileInfo *winio.FileBasicInfo) error { func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size int64, fileInfo *winio.FileBasicInfo) error {
name = filepath.ToSlash(name) name = filepath.ToSlash(name)
hdr := BasicInfoHeader(name, size, fileInfo) hdr := BasicInfoHeader(name, size, fileInfo)
@ -209,7 +207,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
var dataHdr *winio.BackupHeader var dataHdr *winio.BackupHeader
for dataHdr == nil { for dataHdr == nil {
bhdr, err := br.Next() bhdr, err := br.Next()
if err == io.EOF { if err == io.EOF { //nolint:errorlint
break break
} }
if err != nil { if err != nil {
@ -217,21 +215,21 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
} }
switch bhdr.Id { switch bhdr.Id {
case winio.BackupData: case winio.BackupData:
hdr.Mode |= c_ISREG hdr.Mode |= cISREG
if !readTwice { if !readTwice {
dataHdr = bhdr dataHdr = bhdr
} }
case winio.BackupSecurity: case winio.BackupSecurity:
sd, err := ioutil.ReadAll(br) sd, err := io.ReadAll(br)
if err != nil { if err != nil {
return err return err
} }
hdr.PAXRecords[hdrRawSecurityDescriptor] = base64.StdEncoding.EncodeToString(sd) hdr.PAXRecords[hdrRawSecurityDescriptor] = base64.StdEncoding.EncodeToString(sd)
case winio.BackupReparseData: case winio.BackupReparseData:
hdr.Mode |= c_ISLNK hdr.Mode |= cISLNK
hdr.Typeflag = tar.TypeSymlink hdr.Typeflag = tar.TypeSymlink
reparseBuffer, err := ioutil.ReadAll(br) reparseBuffer, _ := io.ReadAll(br)
rp, err := winio.DecodeReparsePoint(reparseBuffer) rp, err := winio.DecodeReparsePoint(reparseBuffer)
if err != nil { if err != nil {
return err return err
@ -242,7 +240,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
hdr.Linkname = rp.Target hdr.Linkname = rp.Target
case winio.BackupEaData: case winio.BackupEaData:
eab, err := ioutil.ReadAll(br) eab, err := io.ReadAll(br)
if err != nil { if err != nil {
return err return err
} }
@ -276,7 +274,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
} }
for dataHdr == nil { for dataHdr == nil {
bhdr, err := br.Next() bhdr, err := br.Next()
if err == io.EOF { if err == io.EOF { //nolint:errorlint
break break
} }
if err != nil { if err != nil {
@ -311,7 +309,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
// range of the file containing the range contents. Finally there is a sparse block stream with // range of the file containing the range contents. Finally there is a sparse block stream with
// size = 0 and offset = <file size>. // size = 0 and offset = <file size>.
if dataHdr != nil { if dataHdr != nil { //nolint:nestif // todo: reduce nesting complexity
// A data stream was found. Copy the data. // A data stream was found. Copy the data.
// We assume that we will either have a data stream size > 0 XOR have sparse block streams. // We assume that we will either have a data stream size > 0 XOR have sparse block streams.
if dataHdr.Size > 0 || (dataHdr.Attributes&winio.StreamSparseAttributes) == 0 { if dataHdr.Size > 0 || (dataHdr.Attributes&winio.StreamSparseAttributes) == 0 {
@ -319,13 +317,13 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
return fmt.Errorf("%s: mismatch between file size %d and header size %d", name, size, dataHdr.Size) return fmt.Errorf("%s: mismatch between file size %d and header size %d", name, size, dataHdr.Size)
} }
if _, err = io.Copy(t, br); err != nil { if _, err = io.Copy(t, br); err != nil {
return fmt.Errorf("%s: copying contents from data stream: %s", name, err) return fmt.Errorf("%s: copying contents from data stream: %w", name, err)
} }
} else if size > 0 { } else if size > 0 {
// As of a recent OS change, BackupRead now returns a data stream for empty sparse files. // As of a recent OS change, BackupRead now returns a data stream for empty sparse files.
// These files have no sparse block streams, so skip the copySparse call if file size = 0. // These files have no sparse block streams, so skip the copySparse call if file size = 0.
if err = copySparse(t, br); err != nil { if err = copySparse(t, br); err != nil {
return fmt.Errorf("%s: copying contents from sparse block stream: %s", name, err) return fmt.Errorf("%s: copying contents from sparse block stream: %w", name, err)
} }
} }
} }
@ -335,7 +333,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
// been written. In practice, this means that we don't get EA or TXF metadata. // been written. In practice, this means that we don't get EA or TXF metadata.
for { for {
bhdr, err := br.Next() bhdr, err := br.Next()
if err == io.EOF { if err == io.EOF { //nolint:errorlint
break break
} }
if err != nil { if err != nil {
@ -343,11 +341,12 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
} }
switch bhdr.Id { switch bhdr.Id {
case winio.BackupAlternateData: case winio.BackupAlternateData:
altName := bhdr.Name if (bhdr.Attributes & winio.StreamSparseAttributes) != 0 {
if strings.HasSuffix(altName, ":$DATA") { // Unsupported for now, since the size of the alternate stream is not present
altName = altName[:len(altName)-len(":$DATA")] // in the backup stream until after the data has been read.
return fmt.Errorf("%s: tar of sparse alternate data streams is unsupported", name)
} }
if (bhdr.Attributes & winio.StreamSparseAttributes) == 0 { altName := strings.TrimSuffix(bhdr.Name, ":$DATA")
hdr = &tar.Header{ hdr = &tar.Header{
Format: hdr.Format, Format: hdr.Format,
Name: name + altName, Name: name + altName,
@ -366,12 +365,6 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size
if err != nil { if err != nil {
return err return err
} }
} else {
// Unsupported for now, since the size of the alternate stream is not present
// in the backup stream until after the data has been read.
return fmt.Errorf("%s: tar of sparse alternate data streams is unsupported", name)
}
case winio.BackupEaData, winio.BackupLink, winio.BackupPropertyData, winio.BackupObjectId, winio.BackupTxfsData: case winio.BackupEaData, winio.BackupLink, winio.BackupPropertyData, winio.BackupObjectId, winio.BackupTxfsData:
// ignore these streams // ignore these streams
default: default:
@ -413,7 +406,7 @@ func FileInfoFromHeader(hdr *tar.Header) (name string, size int64, fileInfo *win
} }
fileInfo.CreationTime = windows.NsecToFiletime(creationTime.UnixNano()) fileInfo.CreationTime = windows.NsecToFiletime(creationTime.UnixNano())
} }
return return name, size, fileInfo, err
} }
// WriteBackupStreamFromTarFile writes a Win32 backup stream from the current tar file. Since this function may process multiple // WriteBackupStreamFromTarFile writes a Win32 backup stream from the current tar file. Since this function may process multiple
@ -474,7 +467,6 @@ func WriteBackupStreamFromTarFile(w io.Writer, t *tar.Reader, hdr *tar.Header) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if hdr.Typeflag == tar.TypeReg || hdr.Typeflag == tar.TypeRegA { if hdr.Typeflag == tar.TypeReg || hdr.Typeflag == tar.TypeRegA {

22
vendor/github.com/Microsoft/go-winio/doc.go generated vendored Normal file
View File

@ -0,0 +1,22 @@
// This package provides utilities for efficiently performing Win32 IO operations in Go.
// Currently, this package is provides support for genreal IO and management of
// - named pipes
// - files
// - [Hyper-V sockets]
//
// This code is similar to Go's [net] package, and uses IO completion ports to avoid
// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines.
//
// This limits support to Windows Vista and newer operating systems.
//
// Additionally, this package provides support for:
// - creating and managing GUIDs
// - writing to [ETW]
// - opening and manageing VHDs
// - parsing [Windows Image files]
// - auto-generating Win32 API code
//
// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw-
// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images
package winio

View File

@ -33,7 +33,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info) err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
if err != nil { if err != nil {
err = errInvalidEaBuffer err = errInvalidEaBuffer
return return ea, nb, err
} }
nameOffset := fileFullEaInformationSize nameOffset := fileFullEaInformationSize
@ -43,7 +43,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
nextOffset := int(info.NextEntryOffset) nextOffset := int(info.NextEntryOffset)
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) { if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
err = errInvalidEaBuffer err = errInvalidEaBuffer
return return ea, nb, err
} }
ea.Name = string(b[nameOffset : nameOffset+nameLen]) ea.Name = string(b[nameOffset : nameOffset+nameLen])
@ -52,7 +52,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
if info.NextEntryOffset != 0 { if info.NextEntryOffset != 0 {
nb = b[info.NextEntryOffset:] nb = b[info.NextEntryOffset:]
} }
return return ea, nb, err
} }
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION // DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
@ -67,7 +67,7 @@ func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
eas = append(eas, ea) eas = append(eas, ea)
b = nb b = nb
} }
return return eas, err
} }
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error { func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {

View File

@ -11,6 +11,8 @@ import (
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
"golang.org/x/sys/windows"
) )
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx //sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
@ -24,6 +26,8 @@ type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
//revive:disable-next-line:predeclared Keep "new" to maintain consistency with "atomic" pkg
func (b *atomicBool) swap(new bool) bool { func (b *atomicBool) swap(new bool) bool {
var newInt int32 var newInt int32
if new { if new {
@ -32,11 +36,6 @@ func (b *atomicBool) swap(new bool) bool {
return atomic.SwapInt32((*int32)(b), newInt) == 1 return atomic.SwapInt32((*int32)(b), newInt) == 1
} }
const (
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
)
var ( var (
ErrFileClosed = errors.New("file has already been closed") ErrFileClosed = errors.New("file has already been closed")
ErrTimeout = &timeoutError{} ErrTimeout = &timeoutError{}
@ -44,28 +43,28 @@ var (
type timeoutError struct{} type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" } func (*timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true } func (*timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true } func (*timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle var ioCompletionPort syscall.Handle
// ioResult contains the result of an asynchronous IO operation // ioResult contains the result of an asynchronous IO operation.
type ioResult struct { type ioResult struct {
bytes uint32 bytes uint32
err error err error
} }
// ioOperation represents an outstanding asynchronous Win32 IO // ioOperation represents an outstanding asynchronous Win32 IO.
type ioOperation struct { type ioOperation struct {
o syscall.Overlapped o syscall.Overlapped
ch chan ioResult ch chan ioResult
} }
func initIo() { func initIO() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff) h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
if err != nil { if err != nil {
panic(err) panic(err)
@ -94,15 +93,15 @@ type deadlineHandler struct {
timedout atomicBool timedout atomicBool
} }
// makeWin32File makes a new win32File from an existing file handle // makeWin32File makes a new win32File from an existing file handle.
func makeWin32File(h syscall.Handle) (*win32File, error) { func makeWin32File(h syscall.Handle) (*win32File, error) {
f := &win32File{handle: h} f := &win32File{handle: h}
ioInitOnce.Do(initIo) ioInitOnce.Do(initIO)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,14 +120,14 @@ func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return f, nil return f, nil
} }
// closeHandle closes the resources associated with a Win32 handle // closeHandle closes the resources associated with a Win32 handle.
func (f *win32File) closeHandle() { func (f *win32File) closeHandle() {
f.wgLock.Lock() f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once. // Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) { if !f.closing.swap(true) {
f.wgLock.Unlock() f.wgLock.Unlock()
// cancel all IO and wait for it to complete // cancel all IO and wait for it to complete
cancelIoEx(f.handle, nil) _ = cancelIoEx(f.handle, nil)
f.wg.Wait() f.wg.Wait()
// at this point, no new IO can start // at this point, no new IO can start
syscall.Close(f.handle) syscall.Close(f.handle)
@ -144,14 +143,14 @@ func (f *win32File) Close() error {
return nil return nil
} }
// IsClosed checks if the file has been closed // IsClosed checks if the file has been closed.
func (f *win32File) IsClosed() bool { func (f *win32File) IsClosed() bool {
return f.closing.isSet() return f.closing.isSet()
} }
// prepareIo prepares for a new IO operation. // prepareIO prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) { func (f *win32File) prepareIO() (*ioOperation, error) {
f.wgLock.RLock() f.wgLock.RLock()
if f.closing.isSet() { if f.closing.isSet() {
f.wgLock.RUnlock() f.wgLock.RUnlock()
@ -164,7 +163,7 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
return c, nil return c, nil
} }
// ioCompletionProcessor processes completed async IOs forever // ioCompletionProcessor processes completed async IOs forever.
func ioCompletionProcessor(h syscall.Handle) { func ioCompletionProcessor(h syscall.Handle) {
for { for {
var bytes uint32 var bytes uint32
@ -178,15 +177,17 @@ func ioCompletionProcessor(h syscall.Handle) {
} }
} }
// asyncIo processes the return value from ReadFile or WriteFile, blocking until // todo: helsaawy - create an asyncIO version that takes a context
// asyncIO processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed. // the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING { if err != syscall.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
return int(bytes), err return int(bytes), err
} }
if f.closing.isSet() { if f.closing.isSet() {
cancelIoEx(f.handle, &c.o) _ = cancelIoEx(f.handle, &c.o)
} }
var timeout timeoutChan var timeout timeoutChan
@ -200,7 +201,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
select { select {
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
if f.closing.isSet() { if f.closing.isSet() {
err = ErrFileClosed err = ErrFileClosed
} }
@ -210,10 +211,10 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
} }
case <-timeout: case <-timeout:
cancelIoEx(f.handle, &c.o) _ = cancelIoEx(f.handle, &c.o)
r = <-c.ch r = <-c.ch
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
err = ErrTimeout err = ErrTimeout
} }
} }
@ -221,13 +222,14 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
// runtime.KeepAlive is needed, as c is passed via native // runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive // code to ioCompletionProcessor, c must remain alive
// until the channel read is complete. // until the channel read is complete.
// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
runtime.KeepAlive(c) runtime.KeepAlive(c)
return int(r.bytes), err return int(r.bytes), err
} }
// Read reads from a file handle. // Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) { func (f *win32File) Read(b []byte) (int, error) {
c, err := f.prepareIo() c, err := f.prepareIO()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -239,13 +241,13 @@ func (f *win32File) Read(b []byte) (int, error) {
var bytes uint32 var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o) err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err) n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
// Handle EOF conditions. // Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 { if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE { } else if err == syscall.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
return 0, io.EOF return 0, io.EOF
} else { } else {
return n, err return n, err
@ -254,7 +256,7 @@ func (f *win32File) Read(b []byte) (int, error) {
// Write writes to a file handle. // Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) { func (f *win32File) Write(b []byte) (int, error) {
c, err := f.prepareIo() c, err := f.prepareIO()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -266,7 +268,7 @@ func (f *win32File) Write(b []byte) (int, error) {
var bytes uint32 var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o) err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
return n, err return n, err
} }

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package winio package winio
@ -14,13 +15,18 @@ import (
type FileBasicInfo struct { type FileBasicInfo struct {
CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
FileAttributes uint32 FileAttributes uint32
pad uint32 // padding _ uint32 // padding
} }
// GetFileBasicInfo retrieves times and attributes for a file. // GetFileBasicInfo retrieves times and attributes for a file.
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) { func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
bi := &FileBasicInfo{} bi := &FileBasicInfo{}
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil { if err := windows.GetFileInformationByHandleEx(
windows.Handle(f.Fd()),
windows.FileBasicInfo,
(*byte)(unsafe.Pointer(bi)),
uint32(unsafe.Sizeof(*bi)),
); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
} }
runtime.KeepAlive(f) runtime.KeepAlive(f)
@ -29,7 +35,12 @@ func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
// SetFileBasicInfo sets times and attributes for a file. // SetFileBasicInfo sets times and attributes for a file.
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error { func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
if err := windows.SetFileInformationByHandle(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil { if err := windows.SetFileInformationByHandle(
windows.Handle(f.Fd()),
windows.FileBasicInfo,
(*byte)(unsafe.Pointer(bi)),
uint32(unsafe.Sizeof(*bi)),
); err != nil {
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err} return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
} }
runtime.KeepAlive(f) runtime.KeepAlive(f)
@ -48,7 +59,10 @@ type FileStandardInfo struct {
// GetFileStandardInfo retrieves ended information for the file. // GetFileStandardInfo retrieves ended information for the file.
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) { func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
si := &FileStandardInfo{} si := &FileStandardInfo{}
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileStandardInfo, (*byte)(unsafe.Pointer(si)), uint32(unsafe.Sizeof(*si))); err != nil { if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()),
windows.FileStandardInfo,
(*byte)(unsafe.Pointer(si)),
uint32(unsafe.Sizeof(*si))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
} }
runtime.KeepAlive(f) runtime.KeepAlive(f)
@ -65,7 +79,12 @@ type FileIDInfo struct {
// GetFileID retrieves the unique (volume, file ID) pair for a file. // GetFileID retrieves the unique (volume, file ID) pair for a file.
func GetFileID(f *os.File) (*FileIDInfo, error) { func GetFileID(f *os.File) (*FileIDInfo, error) {
fileID := &FileIDInfo{} fileID := &FileIDInfo{}
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileIdInfo, (*byte)(unsafe.Pointer(fileID)), uint32(unsafe.Sizeof(*fileID))); err != nil { if err := windows.GetFileInformationByHandleEx(
windows.Handle(f.Fd()),
windows.FileIdInfo,
(*byte)(unsafe.Pointer(fileID)),
uint32(unsafe.Sizeof(*fileID)),
); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
} }
runtime.KeepAlive(f) runtime.KeepAlive(f)

View File

@ -4,6 +4,8 @@
package winio package winio
import ( import (
"context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -12,16 +14,87 @@ import (
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
"github.com/Microsoft/go-winio/internal/socket"
"github.com/Microsoft/go-winio/pkg/guid" "github.com/Microsoft/go-winio/pkg/guid"
) )
//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind const afHVSock = 34 // AF_HYPERV
const ( // Well known Service and VM IDs
afHvSock = 34 // AF_HYPERV //https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
socketError = ^uintptr(0) // HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
) func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
return guid.GUID{}
}
// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
func HvsockGUIDBroadcast() guid.GUID { //ffffffff-ffff-ffff-ffff-ffffffffffff
return guid.GUID{
Data1: 0xffffffff,
Data2: 0xffff,
Data3: 0xffff,
Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}
}
// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
return guid.GUID{
Data1: 0xe0e16197,
Data2: 0xdd56,
Data3: 0x4a10,
Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
}
}
// HvsockGUIDSiloHost is the address of a silo's host partition:
// - The silo host of a hosted silo is the utility VM.
// - The silo host of a silo on a physical host is the physical host.
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
return guid.GUID{
Data1: 0x36bd0c5c,
Data2: 0x7276,
Data3: 0x4223,
Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
}
}
// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
return guid.GUID{
Data1: 0x90db8b89,
Data2: 0xd35,
Data3: 0x4f79,
Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
}
}
// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
// Listening on this VmId accepts connection from:
// - Inside silos: silo host partition.
// - Inside hosted silo: host of the VM.
// - Inside VM: VM host.
// - Physical host: Not supported.
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
return guid.GUID{
Data1: 0xa42e7cda,
Data2: 0xd03f,
Data3: 0x480c,
Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
}
}
// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
return guid.GUID{
Data2: 0xfacb,
Data3: 0x11e6,
Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
}
}
// An HvsockAddr is an address for a AF_HYPERV socket. // An HvsockAddr is an address for a AF_HYPERV socket.
type HvsockAddr struct { type HvsockAddr struct {
@ -36,8 +109,10 @@ type rawHvsockAddr struct {
ServiceID guid.GUID ServiceID guid.GUID
} }
var _ socket.RawSockaddr = &rawHvsockAddr{}
// Network returns the address's network name, "hvsock". // Network returns the address's network name, "hvsock".
func (addr *HvsockAddr) Network() string { func (*HvsockAddr) Network() string {
return "hvsock" return "hvsock"
} }
@ -47,14 +122,14 @@ func (addr *HvsockAddr) String() string {
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
func VsockServiceID(port uint32) guid.GUID { func VsockServiceID(port uint32) guid.GUID {
g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3") g := hvsockVsockServiceTemplate() // make a copy
g.Data1 = port g.Data1 = port
return g return g
} }
func (addr *HvsockAddr) raw() rawHvsockAddr { func (addr *HvsockAddr) raw() rawHvsockAddr {
return rawHvsockAddr{ return rawHvsockAddr{
Family: afHvSock, Family: afHVSock,
VMID: addr.VMID, VMID: addr.VMID,
ServiceID: addr.ServiceID, ServiceID: addr.ServiceID,
} }
@ -65,20 +140,48 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
addr.ServiceID = raw.ServiceID addr.ServiceID = raw.ServiceID
} }
// Sockaddr returns a pointer to and the size of this struct.
//
// Implements the [socket.RawSockaddr] interface, and allows use in
// [socket.Bind] and [socket.ConnectEx].
func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
}
// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
func (r *rawHvsockAddr) FromBytes(b []byte) error {
n := int(unsafe.Sizeof(rawHvsockAddr{}))
if len(b) < n {
return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
if r.Family != afHVSock {
return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
}
return nil
}
// HvsockListener is a socket listener for the AF_HYPERV address family. // HvsockListener is a socket listener for the AF_HYPERV address family.
type HvsockListener struct { type HvsockListener struct {
sock *win32File sock *win32File
addr HvsockAddr addr HvsockAddr
} }
var _ net.Listener = &HvsockListener{}
// HvsockConn is a connected socket of the AF_HYPERV address family. // HvsockConn is a connected socket of the AF_HYPERV address family.
type HvsockConn struct { type HvsockConn struct {
sock *win32File sock *win32File
local, remote HvsockAddr local, remote HvsockAddr
} }
func newHvSocket() (*win32File, error) { var _ net.Conn = &HvsockConn{}
fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
func newHVSocket() (*win32File, error) {
fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1)
if err != nil { if err != nil {
return nil, os.NewSyscallError("socket", err) return nil, os.NewSyscallError("socket", err)
} }
@ -94,12 +197,12 @@ func newHvSocket() (*win32File, error) {
// ListenHvsock listens for connections on the specified hvsock address. // ListenHvsock listens for connections on the specified hvsock address.
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
l := &HvsockListener{addr: *addr} l := &HvsockListener{addr: *addr}
sock, err := newHvSocket() sock, err := newHVSocket()
if err != nil { if err != nil {
return nil, l.opErr("listen", err) return nil, l.opErr("listen", err)
} }
sa := addr.raw() sa := addr.raw()
err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa))) err = socket.Bind(windows.Handle(sock.handle), &sa)
if err != nil { if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("socket", err)) return nil, l.opErr("listen", os.NewSyscallError("socket", err))
} }
@ -121,7 +224,7 @@ func (l *HvsockListener) Addr() net.Addr {
// Accept waits for the next connection and returns it. // Accept waits for the next connection and returns it.
func (l *HvsockListener) Accept() (_ net.Conn, err error) { func (l *HvsockListener) Accept() (_ net.Conn, err error) {
sock, err := newHvSocket() sock, err := newHVSocket()
if err != nil { if err != nil {
return nil, l.opErr("accept", err) return nil, l.opErr("accept", err)
} }
@ -130,27 +233,42 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) {
sock.Close() sock.Close()
} }
}() }()
c, err := l.sock.prepareIo() c, err := l.sock.prepareIO()
if err != nil { if err != nil {
return nil, l.opErr("accept", err) return nil, l.opErr("accept", err)
} }
defer l.sock.wg.Done() defer l.sock.wg.Done()
// AcceptEx, per documentation, requires an extra 16 bytes per address. // AcceptEx, per documentation, requires an extra 16 bytes per address.
//
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
var addrbuf [addrlen * 2]byte var addrbuf [addrlen * 2]byte
var bytes uint32 var bytes uint32
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o) err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o)
_, err = l.sock.asyncIo(c, nil, bytes, err) if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
if err != nil {
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
} }
conn := &HvsockConn{ conn := &HvsockConn{
sock: sock, sock: sock,
} }
// The local address returned in the AcceptEx buffer is the same as the Listener socket's
// address. However, the service GUID reported by GetSockName is different from the Listeners
// socket, and is sometimes the same as the local address of the socket that dialed the
// address, with the service GUID.Data1 incremented, but othertimes is different.
// todo: does the local address matter? is the listener's address or the actual address appropriate?
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
// initialize the accepted socket and update its properties with those of the listening socket
if err = windows.Setsockopt(windows.Handle(sock.handle),
windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
}
sock = nil sock = nil
return conn, nil return conn, nil
} }
@ -160,43 +278,171 @@ func (l *HvsockListener) Close() error {
return l.sock.Close() return l.sock.Close()
} }
/* Need to finish ConnectEx handling // HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) { type HvsockDialer struct {
sock, err := newHvSocket() // Deadline is the time the Dial operation must connect before erroring.
Deadline time.Time
// Retries is the number of additional connects to try if the connection times out, is refused,
// or the host is unreachable
Retries uint
// RetryWait is the time to wait after a connection error to retry
RetryWait time.Duration
rt *time.Timer // redial wait timer
}
// Dial the Hyper-V socket at addr.
//
// See [HvsockDialer.Dial] for more information.
func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
return (&HvsockDialer{}).Dial(ctx, addr)
}
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
// retries.
//
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
op := "dial"
// create the conn early to use opErr()
conn = &HvsockConn{
remote: *addr,
}
if !d.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, d.Deadline)
defer cancel()
}
// preemptive timeout/cancellation check
if err = ctx.Err(); err != nil {
return nil, conn.opErr(op, err)
}
sock, err := newHVSocket()
if err != nil { if err != nil {
return nil, err return nil, conn.opErr(op, err)
} }
defer func() { defer func() {
if sock != nil { if sock != nil {
sock.Close() sock.Close()
} }
}() }()
c, err := sock.prepareIo()
sa := addr.raw()
err = socket.Bind(windows.Handle(sock.handle), &sa)
if err != nil { if err != nil {
return nil, err return nil, conn.opErr(op, os.NewSyscallError("bind", err))
}
c, err := sock.prepareIO()
if err != nil {
return nil, conn.opErr(op, err)
} }
defer sock.wg.Done() defer sock.wg.Done()
var bytes uint32 var bytes uint32
err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o) for i := uint(0); i <= d.Retries; i++ {
_, err = sock.asyncIo(ctx, c, nil, bytes, err) err = socket.ConnectEx(
windows.Handle(sock.handle),
&sa,
nil, // sendBuf
0, // sendDataLen
&bytes,
(*windows.Overlapped)(unsafe.Pointer(&c.o)))
_, err = sock.asyncIO(c, nil, bytes, err)
if i < d.Retries && canRedial(err) {
if err = d.redialWait(ctx); err == nil {
continue
}
}
break
}
if err != nil { if err != nil {
return nil, err return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
} }
conn := &HvsockConn{
sock: sock, // update the connection properties, so shutdown can be used
remote: *addr, if err = windows.Setsockopt(
windows.Handle(sock.handle),
windows.SOL_SOCKET,
windows.SO_UPDATE_CONNECT_CONTEXT,
nil, // optvalue
0, // optlen
); err != nil {
return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
} }
// get the local name
var sal rawHvsockAddr
err = socket.GetSockName(windows.Handle(sock.handle), &sal)
if err != nil {
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
}
conn.local.fromRaw(&sal)
// one last check for timeout, since asyncIO doesn't check the context
if err = ctx.Err(); err != nil {
return nil, conn.opErr(op, err)
}
conn.sock = sock
sock = nil sock = nil
return conn, nil return conn, nil
} }
*/
// redialWait waits before attempting to redial, resetting the timer as appropriate.
func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
if d.RetryWait == 0 {
return nil
}
if d.rt == nil {
d.rt = time.NewTimer(d.RetryWait)
} else {
// should already be stopped and drained
d.rt.Reset(d.RetryWait)
}
select {
case <-ctx.Done():
case <-d.rt.C:
return nil
}
// stop and drain the timer
if !d.rt.Stop() {
<-d.rt.C
}
return ctx.Err()
}
// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall.
func canRedial(err error) bool {
//nolint:errorlint // guaranteed to be an Errno
switch err {
case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
return true
default:
return false
}
}
func (conn *HvsockConn) opErr(op string, err error) error { func (conn *HvsockConn) opErr(op string, err error) error {
// translate from "file closed" to "socket closed"
if errors.Is(err, ErrFileClosed) {
err = socket.ErrSocketClosed
}
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
} }
func (conn *HvsockConn) Read(b []byte) (int, error) { func (conn *HvsockConn) Read(b []byte) (int, error) {
c, err := conn.sock.prepareIo() c, err := conn.sock.prepareIO()
if err != nil { if err != nil {
return 0, conn.opErr("read", err) return 0, conn.opErr("read", err)
} }
@ -204,10 +450,11 @@ func (conn *HvsockConn) Read(b []byte) (int, error) {
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var flags, bytes uint32 var flags, bytes uint32
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
if err != nil { if err != nil {
if _, ok := err.(syscall.Errno); ok { var eno windows.Errno
err = os.NewSyscallError("wsarecv", err) if errors.As(err, &eno) {
err = os.NewSyscallError("wsarecv", eno)
} }
return 0, conn.opErr("read", err) return 0, conn.opErr("read", err)
} else if n == 0 { } else if n == 0 {
@ -230,7 +477,7 @@ func (conn *HvsockConn) Write(b []byte) (int, error) {
} }
func (conn *HvsockConn) write(b []byte) (int, error) { func (conn *HvsockConn) write(b []byte) (int, error) {
c, err := conn.sock.prepareIo() c, err := conn.sock.prepareIO()
if err != nil { if err != nil {
return 0, conn.opErr("write", err) return 0, conn.opErr("write", err)
} }
@ -238,10 +485,11 @@ func (conn *HvsockConn) write(b []byte) (int, error) {
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var bytes uint32 var bytes uint32
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
if err != nil { if err != nil {
if _, ok := err.(syscall.Errno); ok { var eno windows.Errno
err = os.NewSyscallError("wsasend", err) if errors.As(err, &eno) {
err = os.NewSyscallError("wsasend", eno)
} }
return 0, conn.opErr("write", err) return 0, conn.opErr("write", err)
} }
@ -257,13 +505,19 @@ func (conn *HvsockConn) IsClosed() bool {
return conn.sock.IsClosed() return conn.sock.IsClosed()
} }
// shutdown disables sending or receiving on a socket.
func (conn *HvsockConn) shutdown(how int) error { func (conn *HvsockConn) shutdown(how int) error {
if conn.IsClosed() { if conn.IsClosed() {
return ErrFileClosed return socket.ErrSocketClosed
} }
err := syscall.Shutdown(conn.sock.handle, how) err := syscall.Shutdown(conn.sock.handle, how)
if err != nil { if err != nil {
// If the connection was closed, shutdowns fail with "not connected"
if errors.Is(err, windows.WSAENOTCONN) ||
errors.Is(err, windows.WSAESHUTDOWN) {
err = socket.ErrSocketClosed
}
return os.NewSyscallError("shutdown", err) return os.NewSyscallError("shutdown", err)
} }
return nil return nil
@ -273,7 +527,7 @@ func (conn *HvsockConn) shutdown(how int) error {
func (conn *HvsockConn) CloseRead() error { func (conn *HvsockConn) CloseRead() error {
err := conn.shutdown(syscall.SHUT_RD) err := conn.shutdown(syscall.SHUT_RD)
if err != nil { if err != nil {
return conn.opErr("close", err) return conn.opErr("closeread", err)
} }
return nil return nil
} }
@ -283,7 +537,7 @@ func (conn *HvsockConn) CloseRead() error {
func (conn *HvsockConn) CloseWrite() error { func (conn *HvsockConn) CloseWrite() error {
err := conn.shutdown(syscall.SHUT_WR) err := conn.shutdown(syscall.SHUT_WR)
if err != nil { if err != nil {
return conn.opErr("close", err) return conn.opErr("closewrite", err)
} }
return nil return nil
} }
@ -300,8 +554,13 @@ func (conn *HvsockConn) RemoteAddr() net.Addr {
// SetDeadline implements the net.Conn SetDeadline method. // SetDeadline implements the net.Conn SetDeadline method.
func (conn *HvsockConn) SetDeadline(t time.Time) error { func (conn *HvsockConn) SetDeadline(t time.Time) error {
conn.SetReadDeadline(t) // todo: implement `SetDeadline` for `win32File`
conn.SetWriteDeadline(t) if err := conn.SetReadDeadline(t); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
if err := conn.SetWriteDeadline(t); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
return nil return nil
} }

View File

@ -0,0 +1,20 @@
package socket
import (
"unsafe"
)
// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
// struct must meet the Win32 sockaddr requirements specified here:
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
//
// Specifically, the struct size must be least larger than an int16 (unsigned short)
// for the address family.
type RawSockaddr interface {
// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
//
// It is the callers responsibility to validate that the values are valid; invalid
// pointers or size can cause a panic.
Sockaddr() (unsafe.Pointer, int32, error)
}

View File

@ -0,0 +1,179 @@
//go:build windows
package socket
import (
"errors"
"fmt"
"net"
"sync"
"syscall"
"unsafe"
"github.com/Microsoft/go-winio/pkg/guid"
"golang.org/x/sys/windows"
)
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
const socketError = uintptr(^uint32(0))
var (
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
ErrBufferSize = errors.New("buffer size")
ErrAddrFamily = errors.New("address family")
ErrInvalidPointer = errors.New("invalid pointer")
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
return getsockname(s, ptr, &l)
}
// GetPeerName returns the remote address the socket is connected to.
//
// See [GetSockName] for more information.
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}
return getpeername(s, ptr, &l)
}
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}
return bind(s, ptr, l)
}
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the
// their sockaddr interface, so they cannot be used with HvsockAddr
// Replicate functionality here from
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
// runtime via a WSAIoctl call:
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
type runtimeFunc struct {
id guid.GUID
once sync.Once
addr uintptr
err error
}
func (f *runtimeFunc) Load() error {
f.once.Do(func() {
var s windows.Handle
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
if f.err != nil {
return
}
defer windows.CloseHandle(s) //nolint:errcheck
var n uint32
f.err = windows.WSAIoctl(s,
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&f.id)),
uint32(unsafe.Sizeof(f.id)),
(*byte)(unsafe.Pointer(&f.addr)),
uint32(unsafe.Sizeof(f.addr)),
&n,
nil, //overlapped
0, //completionRoutine
)
})
return f.err
}
var (
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
Data1: 0x25a207b9,
Data2: 0xddf3,
Data3: 0x4660,
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
}
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
)
func ConnectEx(
fd windows.Handle,
rsa RawSockaddr,
sendBuf *byte,
sendDataLen uint32,
bytesSent *uint32,
overlapped *windows.Overlapped,
) error {
if err := connectExFunc.Load(); err != nil {
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
}
ptr, n, err := rsa.Sockaddr()
if err != nil {
return err
}
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
}
// BOOL LpfnConnectex(
// [in] SOCKET s,
// [in] const sockaddr *name,
// [in] int namelen,
// [in, optional] PVOID lpSendBuffer,
// [in] DWORD dwSendDataLength,
// [out] LPDWORD lpdwBytesSent,
// [in] LPOVERLAPPED lpOverlapped
// )
func connectEx(
s windows.Handle,
name unsafe.Pointer,
namelen int32,
sendBuf *byte,
sendDataLen uint32,
bytesSent *uint32,
overlapped *windows.Overlapped,
) (err error) {
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
7,
uintptr(s),
uintptr(name),
uintptr(namelen),
uintptr(unsafe.Pointer(sendBuf)),
uintptr(sendDataLen),
uintptr(unsafe.Pointer(bytesSent)),
uintptr(unsafe.Pointer(overlapped)),
0,
0)
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return err
}

View File

@ -0,0 +1,72 @@
//go:build windows
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
package socket
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procbind = modws2_32.NewProc("bind")
procgetpeername = modws2_32.NewProc("getpeername")
procgetsockname = modws2_32.NewProc("getsockname")
)
func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
if r1 == socketError {
err = errnoErr(e1)
}
return
}
func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
if r1 == socketError {
err = errnoErr(e1)
}
return
}
func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
if r1 == socketError {
err = errnoErr(e1)
}
return
}

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package winio package winio
@ -13,6 +14,8 @@ import (
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe //sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
@ -21,10 +24,10 @@ import (
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc //sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile //sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb //sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U //sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl //sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl
type ioStatusBlock struct { type ioStatusBlock struct {
Status, Information uintptr Status, Information uintptr
@ -51,45 +54,22 @@ type securityDescriptor struct {
Control uint16 Control uint16
Owner uintptr Owner uintptr
Group uintptr Group uintptr
Sacl uintptr Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl
Dacl uintptr Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl
} }
type ntstatus int32 type ntStatus int32
func (status ntstatus) Err() error { func (status ntStatus) Err() error {
if status >= 0 { if status >= 0 {
return nil return nil
} }
return rtlNtStatusToDosError(status) return rtlNtStatusToDosError(status)
} }
const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_TYPE_MESSAGE = 4
cPIPE_READMODE_MESSAGE = 2
cFILE_OPEN = 1
cFILE_CREATE = 2
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
)
var ( var (
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
// This error should match net.errClosing since docker takes a dependency on its text. ErrPipeListenerClosed = net.ErrClosed
ErrPipeListenerClosed = errors.New("use of closed network connection")
errPipeWriteClosed = errors.New("pipe has been closed for write") errPipeWriteClosed = errors.New("pipe has been closed for write")
) )
@ -116,9 +96,10 @@ func (f *win32Pipe) RemoteAddr() net.Addr {
} }
func (f *win32Pipe) SetDeadline(t time.Time) error { func (f *win32Pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t) if err := f.SetReadDeadline(t); err != nil {
f.SetWriteDeadline(t) return err
return nil }
return f.SetWriteDeadline(t)
} }
// CloseWrite closes the write side of a message pipe in byte mode. // CloseWrite closes the write side of a message pipe in byte mode.
@ -157,14 +138,14 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
} }
n, err := f.win32File.Read(b) n, err := f.win32File.Read(b)
if err == io.EOF { if err == io.EOF { //nolint:errorlint
// If this was the result of a zero-byte read, then // If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size // it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a // message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read() calls // zero-byte message, ensure that all future Read() calls
// also return EOF. // also return EOF.
f.readEOF = true f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA { } else if err == syscall.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since // and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams. // this package presents all named pipes as byte streams.
@ -173,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
return n, err return n, err
} }
func (s pipeAddress) Network() string { func (pipeAddress) Network() string {
return "pipe" return "pipe"
} }
@ -184,16 +165,21 @@ func (s pipeAddress) String() string {
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. // tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string, access uint32) (syscall.Handle, error) { func tryDialPipe(ctx context.Context, path *string, access uint32) (syscall.Handle, error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return syscall.Handle(0), ctx.Err() return syscall.Handle(0), ctx.Err()
default: default:
h, err := createFile(*path, access, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) h, err := createFile(*path,
access,
0,
nil,
syscall.OPEN_EXISTING,
windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS,
0)
if err == nil { if err == nil {
return h, nil return h, nil
} }
if err != cERROR_PIPE_BUSY { if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno
return h, &os.PathError{Err: err, Op: "open", Path: *path} return h, &os.PathError{Err: err, Op: "open", Path: *path}
} }
// Wait 10 msec and try again. This is a rather simplistic // Wait 10 msec and try again. This is a rather simplistic
@ -213,9 +199,10 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
} else { } else {
absTimeout = time.Now().Add(2 * time.Second) absTimeout = time.Now().Add(2 * time.Second)
} }
ctx, _ := context.WithDeadline(context.Background(), absTimeout) ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
defer cancel()
conn, err := DialPipeContext(ctx, path) conn, err := DialPipeContext(ctx, path)
if err == context.DeadlineExceeded { if errors.Is(err, context.DeadlineExceeded) {
return nil, ErrTimeout return nil, ErrTimeout
} }
return conn, err return conn, err
@ -251,7 +238,7 @@ func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn,
// If the pipe is in message mode, return a message byte pipe, which // If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite(). // supports CloseWrite().
if flags&cPIPE_TYPE_MESSAGE != 0 { if flags&windows.PIPE_TYPE_MESSAGE != 0 {
return &win32MessageBytePipe{ return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: f, path: path}, win32Pipe: win32Pipe{win32File: f, path: path},
}, nil }, nil
@ -283,7 +270,11 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
oa.Length = unsafe.Sizeof(oa) oa.Length = unsafe.Sizeof(oa)
var ntPath unicodeString var ntPath unicodeString
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { if err := rtlDosPathNameToNtPathName(&path16[0],
&ntPath,
0,
0,
).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
defer localFree(ntPath.Buffer) defer localFree(ntPath.Buffer)
@ -292,8 +283,8 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if first {
if sd != nil { if sd != nil {
len := uint32(len(sd)) l := uint32(len(sd))
sdb := localAlloc(0, len) sdb := localAlloc(0, l)
defer localFree(sdb) defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd) copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb)) oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
@ -301,28 +292,28 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
// Construct the default named pipe security descriptor. // Construct the default named pipe security descriptor.
var dacl uintptr var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err) return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
} }
defer localFree(dacl) defer localFree(dacl)
sdb := &securityDescriptor{ sdb := &securityDescriptor{
Revision: 1, Revision: 1,
Control: cSE_DACL_PRESENT, Control: windows.SE_DACL_PRESENT,
Dacl: dacl, Dacl: dacl,
} }
oa.SecurityDescriptor = sdb oa.SecurityDescriptor = sdb
} }
} }
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode { if c.MessageMode {
typ |= cFILE_PIPE_MESSAGE_TYPE typ |= windows.FILE_PIPE_MESSAGE_TYPE
} }
disposition := uint32(cFILE_OPEN) disposition := uint32(windows.FILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE) access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
if first { if first {
disposition = cFILE_CREATE disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with first == false.
@ -335,7 +326,20 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
h syscall.Handle h syscall.Handle
iosb ioStatusBlock iosb ioStatusBlock
) )
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() err = ntCreateNamedPipeFile(&h,
access,
&oa,
&iosb,
syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE,
disposition,
0,
typ,
0,
0,
0xffffffff,
uint32(c.InputBufferSize),
uint32(c.OutputBufferSize),
&timeout).Err()
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@ -380,7 +384,7 @@ func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
p.Close() p.Close()
p = nil p = nil
err = <-ch err = <-ch
if err == nil || err == ErrFileClosed { if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno
err = ErrPipeListenerClosed err = ErrPipeListenerClosed
} }
} }
@ -402,12 +406,12 @@ func (l *win32PipeListener) listenerRoutine() {
p, err = l.makeConnectedServerPipe() p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try // If the connection was immediately closed by the client, try
// again. // again.
if err != cERROR_NO_DATA { if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno
break break
} }
} }
responseCh <- acceptResponse{p, err} responseCh <- acceptResponse{p, err}
closed = err == ErrPipeListenerClosed closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
} }
} }
syscall.Close(l.firstHandle) syscall.Close(l.firstHandle)
@ -469,15 +473,15 @@ func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
} }
func connectPipe(p *win32File) error { func connectPipe(p *win32File) error {
c, err := p.prepareIo() c, err := p.prepareIO()
if err != nil { if err != nil {
return err return err
} }
defer p.wg.Done() defer p.wg.Done()
err = connectNamedPipe(p.handle, &c.o) err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err) _, err = p.asyncIO(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED { if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno
return err return err
} }
return nil return nil

View File

@ -1,5 +1,3 @@
// +build windows
// Package guid provides a GUID type. The backing structure for a GUID is // Package guid provides a GUID type. The backing structure for a GUID is
// identical to that used by the golang.org/x/sys/windows GUID type. // identical to that used by the golang.org/x/sys/windows GUID type.
// There are two main binary encodings used for a GUID, the big-endian encoding, // There are two main binary encodings used for a GUID, the big-endian encoding,
@ -9,24 +7,26 @@ package guid
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1" //nolint:gosec // not used for secure application
"encoding" "encoding"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"strconv" "strconv"
) )
//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment
// Variant specifies which GUID variant (or "type") of the GUID. It determines // Variant specifies which GUID variant (or "type") of the GUID. It determines
// how the entirety of the rest of the GUID is interpreted. // how the entirety of the rest of the GUID is interpreted.
type Variant uint8 type Variant uint8
// The variants specified by RFC 4122. // The variants specified by RFC 4122 section 4.1.1.
const ( const (
// VariantUnknown specifies a GUID variant which does not conform to one of // VariantUnknown specifies a GUID variant which does not conform to one of
// the variant encodings specified in RFC 4122. // the variant encodings specified in RFC 4122.
VariantUnknown Variant = iota VariantUnknown Variant = iota
VariantNCS VariantNCS
VariantRFC4122 VariantRFC4122 // RFC 4122
VariantMicrosoft VariantMicrosoft
VariantFuture VariantFuture
) )
@ -36,6 +36,10 @@ const (
// hash of an input string. // hash of an input string.
type Version uint8 type Version uint8
func (v Version) String() string {
return strconv.FormatUint(uint64(v), 10)
}
var _ = (encoding.TextMarshaler)(GUID{}) var _ = (encoding.TextMarshaler)(GUID{})
var _ = (encoding.TextUnmarshaler)(&GUID{}) var _ = (encoding.TextUnmarshaler)(&GUID{})
@ -61,7 +65,7 @@ func NewV4() (GUID, error) {
// big-endian UTF16 stream of bytes. If that is desired, the string can be // big-endian UTF16 stream of bytes. If that is desired, the string can be
// encoded as such before being passed to this function. // encoded as such before being passed to this function.
func NewV5(namespace GUID, name []byte) (GUID, error) { func NewV5(namespace GUID, name []byte) (GUID, error) {
b := sha1.New() b := sha1.New() //nolint:gosec // not used for secure application
namespaceBytes := namespace.ToArray() namespaceBytes := namespace.ToArray()
b.Write(namespaceBytes[:]) b.Write(namespaceBytes[:])
b.Write(name) b.Write(name)

View File

@ -1,3 +1,4 @@
//go:build !windows
// +build !windows // +build !windows
package guid package guid

View File

@ -1,3 +1,6 @@
//go:build windows
// +build windows
package guid package guid
import "golang.org/x/sys/windows" import "golang.org/x/sys/windows"

View File

@ -0,0 +1,27 @@
// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT.
package guid
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[VariantUnknown-0]
_ = x[VariantNCS-1]
_ = x[VariantRFC4122-2]
_ = x[VariantMicrosoft-3]
_ = x[VariantFuture-4]
}
const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture"
var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33}
func (i Variant) String() string {
if i >= Variant(len(_Variant_index)-1) {
return "Variant(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Variant_name[_Variant_index[i]:_Variant_index[i+1]]
}

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package security package security
@ -20,6 +21,7 @@ type (
trusteeForm uint32 trusteeForm uint32
trusteeType uint32 trusteeType uint32
//nolint:structcheck // structcheck thinks fields are unused, but the are used to pass data to OS
explicitAccess struct { explicitAccess struct {
accessPermissions accessMask accessPermissions accessMask
accessMode accessMode accessMode accessMode
@ -27,6 +29,7 @@ type (
trustee trustee trustee trustee
} }
//nolint:structcheck,unused // structcheck thinks fields are unused, but the are used to pass data to OS
trustee struct { trustee struct {
multipleTrustee *trustee multipleTrustee *trustee
multipleTrusteeOperation int32 multipleTrusteeOperation int32
@ -44,6 +47,7 @@ const (
desiredAccessReadControl desiredAccess = 0x20000 desiredAccessReadControl desiredAccess = 0x20000
desiredAccessWriteDac desiredAccess = 0x40000 desiredAccessWriteDac desiredAccess = 0x40000
//cspell:disable-next-line
gvmga = "GrantVmGroupAccess:" gvmga = "GrantVmGroupAccess:"
inheritModeNoInheritance inheritMode = 0x0 inheritModeNoInheritance inheritMode = 0x0
@ -56,9 +60,9 @@ const (
shareModeRead shareMode = 0x1 shareModeRead shareMode = 0x1
shareModeWrite shareMode = 0x2 shareModeWrite shareMode = 0x2
sidVmGroup = "S-1-5-83-0" sidVMGroup = "S-1-5-83-0"
trusteeFormIsSid trusteeForm = 0 trusteeFormIsSID trusteeForm = 0
trusteeTypeWellKnownGroup trusteeType = 5 trusteeTypeWellKnownGroup trusteeType = 5
) )
@ -67,6 +71,8 @@ const (
// include Grant ACE entries for the VM Group SID. This is a golang re- // include Grant ACE entries for the VM Group SID. This is a golang re-
// implementation of the same function in vmcompute, just not exported in // implementation of the same function in vmcompute, just not exported in
// RS5. Which kind of sucks. Sucks a lot :/ // RS5. Which kind of sucks. Sucks a lot :/
//
//revive:disable-next-line:var-naming VM, not Vm
func GrantVmGroupAccess(name string) error { func GrantVmGroupAccess(name string) error {
// Stat (to determine if `name` is a directory). // Stat (to determine if `name` is a directory).
s, err := os.Stat(name) s, err := os.Stat(name)
@ -79,7 +85,7 @@ func GrantVmGroupAccess(name string) error {
if err != nil { if err != nil {
return err // Already wrapped return err // Already wrapped
} }
defer syscall.CloseHandle(fd) defer syscall.CloseHandle(fd) //nolint:errcheck
// Get the current DACL and Security Descriptor. Must defer LocalFree on success. // Get the current DACL and Security Descriptor. Must defer LocalFree on success.
ot := objectTypeFileObject ot := objectTypeFileObject
@ -89,7 +95,7 @@ func GrantVmGroupAccess(name string) error {
if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil {
return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err)
} }
defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) //nolint:errcheck
// Generate a new DACL which is the current DACL with the required ACEs added. // Generate a new DACL which is the current DACL with the required ACEs added.
// Must defer LocalFree on success. // Must defer LocalFree on success.
@ -97,7 +103,7 @@ func GrantVmGroupAccess(name string) error {
if err != nil { if err != nil {
return err // Already wrapped return err // Already wrapped
} }
defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) //nolint:errcheck
// And finally use SetSecurityInfo to apply the updated DACL. // And finally use SetSecurityInfo to apply the updated DACL.
if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil {
@ -110,16 +116,19 @@ func GrantVmGroupAccess(name string) error {
// createFile is a helper function to call [Nt]CreateFile to get a handle to // createFile is a helper function to call [Nt]CreateFile to get a handle to
// the file or directory. // the file or directory.
func createFile(name string, isDir bool) (syscall.Handle, error) { func createFile(name string, isDir bool) (syscall.Handle, error) {
namep := syscall.StringToUTF16(name) namep, err := syscall.UTF16FromString(name)
if err != nil {
return syscall.InvalidHandle, fmt.Errorf("could not convernt name to UTF-16: %w", err)
}
da := uint32(desiredAccessReadControl | desiredAccessWriteDac) da := uint32(desiredAccessReadControl | desiredAccessWriteDac)
sm := uint32(shareModeRead | shareModeWrite) sm := uint32(shareModeRead | shareModeWrite)
fa := uint32(syscall.FILE_ATTRIBUTE_NORMAL) fa := uint32(syscall.FILE_ATTRIBUTE_NORMAL)
if isDir { if isDir {
fa = uint32(fa | syscall.FILE_FLAG_BACKUP_SEMANTICS) fa |= syscall.FILE_FLAG_BACKUP_SEMANTICS
} }
fd, err := syscall.CreateFile(&namep[0], da, sm, nil, syscall.OPEN_EXISTING, fa, 0) fd, err := syscall.CreateFile(&namep[0], da, sm, nil, syscall.OPEN_EXISTING, fa, 0)
if err != nil { if err != nil {
return 0, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err) return syscall.InvalidHandle, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err)
} }
return fd, nil return fd, nil
} }
@ -128,9 +137,9 @@ func createFile(name string, isDir bool) (syscall.Handle, error) {
// The caller is responsible for LocalFree of the returned DACL on success. // The caller is responsible for LocalFree of the returned DACL on success.
func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintptr, error) { func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintptr, error) {
// Generate pointers to the SIDs based on the string SIDs // Generate pointers to the SIDs based on the string SIDs
sid, err := syscall.StringToSid(sidVmGroup) sid, err := syscall.StringToSid(sidVMGroup)
if err != nil { if err != nil {
return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVmGroup, err) return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVMGroup, err)
} }
inheritance := inheritModeNoInheritance inheritance := inheritModeNoInheritance
@ -139,12 +148,12 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp
} }
eaArray := []explicitAccess{ eaArray := []explicitAccess{
explicitAccess{ {
accessPermissions: accessMaskDesiredPermission, accessPermissions: accessMaskDesiredPermission,
accessMode: accessModeGrant, accessMode: accessModeGrant,
inheritance: inheritance, inheritance: inheritance,
trustee: trustee{ trustee: trustee{
trusteeForm: trusteeFormIsSid, trusteeForm: trusteeFormIsSID,
trusteeType: trusteeTypeWellKnownGroup, trusteeType: trusteeTypeWellKnownGroup,
name: uintptr(unsafe.Pointer(sid)), name: uintptr(unsafe.Pointer(sid)),
}, },

View File

@ -1,6 +1,6 @@
package security package security
//go:generate go run mksyscall_windows.go -output zsyscall_windows.go syscall_windows.go //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo //sys getSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, ppsidOwner **uintptr, ppsidGroup **uintptr, ppDacl *uintptr, ppSacl *uintptr, ppSecurityDescriptor *uintptr) (win32err error) = advapi32.GetSecurityInfo
//sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo //sys setSecurityInfo(handle syscall.Handle, objectType uint32, si uint32, psidOwner uintptr, psidGroup uintptr, pDacl uintptr, pSacl uintptr) (win32err error) = advapi32.SetSecurityInfo

View File

@ -1,4 +1,6 @@
// Code generated by 'go generate'; DO NOT EDIT. //go:build windows
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
package security package security

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows // +build windows
package winio package winio
@ -24,22 +25,17 @@ import (
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW //sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
const ( const (
SE_PRIVILEGE_ENABLED = 2 //revive:disable-next-line:var-naming ALL_CAPS
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
ERROR_NOT_ALL_ASSIGNED syscall.Errno = 1300 //revive:disable-next-line:var-naming ALL_CAPS
ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED
SeBackupPrivilege = "SeBackupPrivilege" SeBackupPrivilege = "SeBackupPrivilege"
SeRestorePrivilege = "SeRestorePrivilege" SeRestorePrivilege = "SeRestorePrivilege"
SeSecurityPrivilege = "SeSecurityPrivilege" SeSecurityPrivilege = "SeSecurityPrivilege"
) )
const (
securityAnonymous = iota
securityIdentification
securityImpersonation
securityDelegation
)
var ( var (
privNames = make(map[string]uint64) privNames = make(map[string]uint64)
privNameMutex sync.Mutex privNameMutex sync.Mutex
@ -51,11 +47,9 @@ type PrivilegeError struct {
} }
func (e *PrivilegeError) Error() string { func (e *PrivilegeError) Error() string {
s := "" s := "Could not enable privilege "
if len(e.privileges) > 1 { if len(e.privileges) > 1 {
s = "Could not enable privileges " s = "Could not enable privileges "
} else {
s = "Could not enable privilege "
} }
for i, p := range e.privileges { for i, p := range e.privileges {
if i != 0 { if i != 0 {
@ -94,7 +88,7 @@ func RunWithPrivileges(names []string, fn func() error) error {
} }
func mapPrivileges(names []string) ([]uint64, error) { func mapPrivileges(names []string) ([]uint64, error) {
var privileges []uint64 privileges := make([]uint64, 0, len(names))
privNameMutex.Lock() privNameMutex.Lock()
defer privNameMutex.Unlock() defer privNameMutex.Unlock()
for _, name := range names { for _, name := range names {
@ -127,7 +121,7 @@ func enableDisableProcessPrivilege(names []string, action uint32) error {
return err return err
} }
p, _ := windows.GetCurrentProcess() p := windows.CurrentProcess()
var token windows.Token var token windows.Token
err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token) err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
if err != nil { if err != nil {
@ -140,10 +134,10 @@ func enableDisableProcessPrivilege(names []string, action uint32) error {
func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error { func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
var b bytes.Buffer var b bytes.Buffer
binary.Write(&b, binary.LittleEndian, uint32(len(privileges))) _ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
for _, p := range privileges { for _, p := range privileges {
binary.Write(&b, binary.LittleEndian, p) _ = binary.Write(&b, binary.LittleEndian, p)
binary.Write(&b, binary.LittleEndian, action) _ = binary.Write(&b, binary.LittleEndian, action)
} }
prevState := make([]byte, b.Len()) prevState := make([]byte, b.Len())
reqSize := uint32(0) reqSize := uint32(0)
@ -151,7 +145,7 @@ func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) e
if !success { if !success {
return err return err
} }
if err == ERROR_NOT_ALL_ASSIGNED { if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
return &PrivilegeError{privileges} return &PrivilegeError{privileges}
} }
return nil return nil
@ -177,7 +171,7 @@ func getPrivilegeName(luid uint64) string {
} }
func newThreadToken() (windows.Token, error) { func newThreadToken() (windows.Token, error) {
err := impersonateSelf(securityImpersonation) err := impersonateSelf(windows.SecurityImpersonation)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -1,3 +1,6 @@
//go:build windows
// +build windows
package winio package winio
import ( import (
@ -113,16 +116,16 @@ func EncodeReparsePoint(rp *ReparsePoint) []byte {
} }
var b bytes.Buffer var b bytes.Buffer
binary.Write(&b, binary.LittleEndian, &data) _ = binary.Write(&b, binary.LittleEndian, &data)
if !rp.IsMountPoint { if !rp.IsMountPoint {
flags := uint32(0) flags := uint32(0)
if relative { if relative {
flags |= 1 flags |= 1
} }
binary.Write(&b, binary.LittleEndian, flags) _ = binary.Write(&b, binary.LittleEndian, flags)
} }
binary.Write(&b, binary.LittleEndian, ntTarget16) _ = binary.Write(&b, binary.LittleEndian, ntTarget16)
binary.Write(&b, binary.LittleEndian, target16) _ = binary.Write(&b, binary.LittleEndian, target16)
return b.Bytes() return b.Bytes()
} }

View File

@ -1,23 +1,25 @@
//go:build windows
// +build windows // +build windows
package winio package winio
import ( import (
"errors"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW //sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW //sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW //sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) = advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW //sys convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) = advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree //sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength //sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
const (
cERROR_NONE_MAPPED = syscall.Errno(1332)
)
type AccountLookupError struct { type AccountLookupError struct {
Name string Name string
Err error Err error
@ -28,8 +30,10 @@ func (e *AccountLookupError) Error() string {
return "lookup account: empty account name specified" return "lookup account: empty account name specified"
} }
var s string var s string
switch e.Err { switch {
case cERROR_NONE_MAPPED: case errors.Is(e.Err, windows.ERROR_INVALID_SID):
s = "the security ID structure is invalid"
case errors.Is(e.Err, windows.ERROR_NONE_MAPPED):
s = "not found" s = "not found"
default: default:
s = e.Err.Error() s = e.Err.Error()
@ -37,6 +41,8 @@ func (e *AccountLookupError) Error() string {
return "lookup account " + e.Name + ": " + s return "lookup account " + e.Name + ": " + s
} }
func (e *AccountLookupError) Unwrap() error { return e.Err }
type SddlConversionError struct { type SddlConversionError struct {
Sddl string Sddl string
Err error Err error
@ -46,15 +52,19 @@ func (e *SddlConversionError) Error() string {
return "convert " + e.Sddl + ": " + e.Err.Error() return "convert " + e.Sddl + ": " + e.Err.Error()
} }
func (e *SddlConversionError) Unwrap() error { return e.Err }
// LookupSidByName looks up the SID of an account by name // LookupSidByName looks up the SID of an account by name
//
//revive:disable-next-line:var-naming SID, not Sid
func LookupSidByName(name string) (sid string, err error) { func LookupSidByName(name string) (sid string, err error) {
if name == "" { if name == "" {
return "", &AccountLookupError{name, cERROR_NONE_MAPPED} return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED}
} }
var sidSize, sidNameUse, refDomainSize uint32 var sidSize, sidNameUse, refDomainSize uint32
err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse) err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
return "", &AccountLookupError{name, err} return "", &AccountLookupError{name, err}
} }
sidBuffer := make([]byte, sidSize) sidBuffer := make([]byte, sidSize)
@ -73,6 +83,42 @@ func LookupSidByName(name string) (sid string, err error) {
return sid, nil return sid, nil
} }
// LookupNameBySid looks up the name of an account by SID
//
//revive:disable-next-line:var-naming SID, not Sid
func LookupNameBySid(sid string) (name string, err error) {
if sid == "" {
return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED}
}
sidBuffer, err := windows.UTF16PtrFromString(sid)
if err != nil {
return "", &AccountLookupError{sid, err}
}
var sidPtr *byte
if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
return "", &AccountLookupError{sid, err}
}
defer localFree(uintptr(unsafe.Pointer(sidPtr)))
var nameSize, refDomainSize, sidNameUse uint32
err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
return "", &AccountLookupError{sid, err}
}
nameBuffer := make([]uint16, nameSize)
refDomainBuffer := make([]uint16, refDomainSize)
err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
if err != nil {
return "", &AccountLookupError{sid, err}
}
name = windows.UTF16ToString(nameBuffer)
return name, nil
}
func SddlToSecurityDescriptor(sddl string) ([]byte, error) { func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil) err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
@ -87,7 +133,7 @@ func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
func SecurityDescriptorToSddl(sd []byte) (string, error) { func SecurityDescriptorToSddl(sd []byte) (string, error) {
var sddl *uint16 var sddl *uint16
// The returned string length seems to including an aribtrary number of terminating NULs. // The returned string length seems to include an arbitrary number of terminating NULs.
// Don't use it. // Don't use it.
err := convertSecurityDescriptorToStringSecurityDescriptor(&sd[0], 1, 0xff, &sddl, nil) err := convertSecurityDescriptorToStringSecurityDescriptor(&sd[0], 1, 0xff, &sddl, nil)
if err != nil { if err != nil {

View File

@ -1,3 +1,5 @@
//go:build windows
package winio package winio
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go file.go pipe.go sd.go fileinfo.go privilege.go backup.go hvsock.go //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go

5
vendor/github.com/Microsoft/go-winio/tools.go generated vendored Normal file
View File

@ -0,0 +1,5 @@
//go:build tools
package winio
import _ "golang.org/x/tools/cmd/stringer"

View File

@ -11,7 +11,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//go:generate go run mksyscall_windows.go -output zvhd_windows.go vhd.go //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zvhd_windows.go vhd.go
//sys createVirtualDisk(virtualStorageType *VirtualStorageType, path string, virtualDiskAccessMask uint32, securityDescriptor *uintptr, createVirtualDiskFlags uint32, providerSpecificFlags uint32, parameters *CreateVirtualDiskParameters, overlapped *syscall.Overlapped, handle *syscall.Handle) (win32err error) = virtdisk.CreateVirtualDisk //sys createVirtualDisk(virtualStorageType *VirtualStorageType, path string, virtualDiskAccessMask uint32, securityDescriptor *uintptr, createVirtualDiskFlags uint32, providerSpecificFlags uint32, parameters *CreateVirtualDiskParameters, overlapped *syscall.Overlapped, handle *syscall.Handle) (win32err error) = virtdisk.CreateVirtualDisk
//sys openVirtualDisk(virtualStorageType *VirtualStorageType, path string, virtualDiskAccessMask uint32, openVirtualDiskFlags uint32, parameters *openVirtualDiskParameters, handle *syscall.Handle) (win32err error) = virtdisk.OpenVirtualDisk //sys openVirtualDisk(virtualStorageType *VirtualStorageType, path string, virtualDiskAccessMask uint32, openVirtualDiskFlags uint32, parameters *openVirtualDiskParameters, handle *syscall.Handle) (win32err error) = virtdisk.OpenVirtualDisk
@ -62,8 +62,8 @@ type OpenVirtualDiskParameters struct {
Version2 OpenVersion2 Version2 OpenVersion2
} }
// The higher level `OpenVersion2` struct uses bools to refer to `GetInfoOnly` and `ReadOnly` for ease of use. However, // The higher level `OpenVersion2` struct uses `bool`s to refer to `GetInfoOnly` and `ReadOnly` for ease of use. However,
// the internal windows structure uses `BOOLS` aka int32s for these types. `openVersion2` is used for translating // the internal windows structure uses `BOOL`s aka int32s for these types. `openVersion2` is used for translating
// `OpenVersion2` fields to the correct windows internal field types on the `Open____` methods. // `OpenVersion2` fields to the correct windows internal field types on the `Open____` methods.
type openVersion2 struct { type openVersion2 struct {
getInfoOnly int32 getInfoOnly int32
@ -87,9 +87,10 @@ type AttachVirtualDiskParameters struct {
} }
const ( const (
//revive:disable-next-line:var-naming ALL_CAPS
VIRTUAL_STORAGE_TYPE_DEVICE_VHDX = 0x3 VIRTUAL_STORAGE_TYPE_DEVICE_VHDX = 0x3
// Access Mask for opening a VHD // Access Mask for opening a VHD.
VirtualDiskAccessNone VirtualDiskAccessMask = 0x00000000 VirtualDiskAccessNone VirtualDiskAccessMask = 0x00000000
VirtualDiskAccessAttachRO VirtualDiskAccessMask = 0x00010000 VirtualDiskAccessAttachRO VirtualDiskAccessMask = 0x00010000
VirtualDiskAccessAttachRW VirtualDiskAccessMask = 0x00020000 VirtualDiskAccessAttachRW VirtualDiskAccessMask = 0x00020000
@ -101,7 +102,7 @@ const (
VirtualDiskAccessAll VirtualDiskAccessMask = 0x003f0000 VirtualDiskAccessAll VirtualDiskAccessMask = 0x003f0000
VirtualDiskAccessWritable VirtualDiskAccessMask = 0x00320000 VirtualDiskAccessWritable VirtualDiskAccessMask = 0x00320000
// Flags for creating a VHD // Flags for creating a VHD.
CreateVirtualDiskFlagNone CreateVirtualDiskFlag = 0x0 CreateVirtualDiskFlagNone CreateVirtualDiskFlag = 0x0
CreateVirtualDiskFlagFullPhysicalAllocation CreateVirtualDiskFlag = 0x1 CreateVirtualDiskFlagFullPhysicalAllocation CreateVirtualDiskFlag = 0x1
CreateVirtualDiskFlagPreventWritesToSourceDisk CreateVirtualDiskFlag = 0x2 CreateVirtualDiskFlagPreventWritesToSourceDisk CreateVirtualDiskFlag = 0x2
@ -109,12 +110,12 @@ const (
CreateVirtualDiskFlagCreateBackingStorage CreateVirtualDiskFlag = 0x8 CreateVirtualDiskFlagCreateBackingStorage CreateVirtualDiskFlag = 0x8
CreateVirtualDiskFlagUseChangeTrackingSourceLimit CreateVirtualDiskFlag = 0x10 CreateVirtualDiskFlagUseChangeTrackingSourceLimit CreateVirtualDiskFlag = 0x10
CreateVirtualDiskFlagPreserveParentChangeTrackingState CreateVirtualDiskFlag = 0x20 CreateVirtualDiskFlagPreserveParentChangeTrackingState CreateVirtualDiskFlag = 0x20
CreateVirtualDiskFlagVhdSetUseOriginalBackingStorage CreateVirtualDiskFlag = 0x40 CreateVirtualDiskFlagVhdSetUseOriginalBackingStorage CreateVirtualDiskFlag = 0x40 //revive:disable-line:var-naming VHD, not Vhd
CreateVirtualDiskFlagSparseFile CreateVirtualDiskFlag = 0x80 CreateVirtualDiskFlagSparseFile CreateVirtualDiskFlag = 0x80
CreateVirtualDiskFlagPmemCompatible CreateVirtualDiskFlag = 0x100 CreateVirtualDiskFlagPmemCompatible CreateVirtualDiskFlag = 0x100 //revive:disable-line:var-naming PMEM, not Pmem
CreateVirtualDiskFlagSupportCompressedVolumes CreateVirtualDiskFlag = 0x200 CreateVirtualDiskFlagSupportCompressedVolumes CreateVirtualDiskFlag = 0x200
// Flags for opening a VHD // Flags for opening a VHD.
OpenVirtualDiskFlagNone VirtualDiskFlag = 0x00000000 OpenVirtualDiskFlagNone VirtualDiskFlag = 0x00000000
OpenVirtualDiskFlagNoParents VirtualDiskFlag = 0x00000001 OpenVirtualDiskFlagNoParents VirtualDiskFlag = 0x00000001
OpenVirtualDiskFlagBlankFile VirtualDiskFlag = 0x00000002 OpenVirtualDiskFlagBlankFile VirtualDiskFlag = 0x00000002
@ -127,7 +128,7 @@ const (
OpenVirtualDiskFlagNoWriteHardening VirtualDiskFlag = 0x00000100 OpenVirtualDiskFlagNoWriteHardening VirtualDiskFlag = 0x00000100
OpenVirtualDiskFlagSupportCompressedVolumes VirtualDiskFlag = 0x00000200 OpenVirtualDiskFlagSupportCompressedVolumes VirtualDiskFlag = 0x00000200
// Flags for attaching a VHD // Flags for attaching a VHD.
AttachVirtualDiskFlagNone AttachVirtualDiskFlag = 0x00000000 AttachVirtualDiskFlagNone AttachVirtualDiskFlag = 0x00000000
AttachVirtualDiskFlagReadOnly AttachVirtualDiskFlag = 0x00000001 AttachVirtualDiskFlagReadOnly AttachVirtualDiskFlag = 0x00000001
AttachVirtualDiskFlagNoDriveLetter AttachVirtualDiskFlag = 0x00000002 AttachVirtualDiskFlagNoDriveLetter AttachVirtualDiskFlag = 0x00000002
@ -140,12 +141,14 @@ const (
AttachVirtualDiskFlagSinglePartition AttachVirtualDiskFlag = 0x00000100 AttachVirtualDiskFlagSinglePartition AttachVirtualDiskFlag = 0x00000100
AttachVirtualDiskFlagRegisterVolume AttachVirtualDiskFlag = 0x00000200 AttachVirtualDiskFlagRegisterVolume AttachVirtualDiskFlag = 0x00000200
// Flags for detaching a VHD // Flags for detaching a VHD.
DetachVirtualDiskFlagNone DetachVirtualDiskFlag = 0x0 DetachVirtualDiskFlagNone DetachVirtualDiskFlag = 0x0
) )
// CreateVhdx is a helper function to create a simple vhdx file at the given path using // CreateVhdx is a helper function to create a simple vhdx file at the given path using
// default values. // default values.
//
//revive:disable-next-line:var-naming VHDX, not Vhdx
func CreateVhdx(path string, maxSizeInGb, blockSizeInMb uint32) error { func CreateVhdx(path string, maxSizeInGb, blockSizeInMb uint32) error {
params := CreateVirtualDiskParameters{ params := CreateVirtualDiskParameters{
Version: 2, Version: 2,
@ -172,6 +175,8 @@ func DetachVirtualDisk(handle syscall.Handle) (err error) {
} }
// DetachVhd detaches a vhd found at `path`. // DetachVhd detaches a vhd found at `path`.
//
//revive:disable-next-line:var-naming VHD, not Vhd
func DetachVhd(path string) error { func DetachVhd(path string) error {
handle, err := OpenVirtualDisk( handle, err := OpenVirtualDisk(
path, path,
@ -181,12 +186,16 @@ func DetachVhd(path string) error {
if err != nil { if err != nil {
return err return err
} }
defer syscall.CloseHandle(handle) defer syscall.CloseHandle(handle) //nolint:errcheck
return DetachVirtualDisk(handle) return DetachVirtualDisk(handle)
} }
// AttachVirtualDisk attaches a virtual hard disk for use. // AttachVirtualDisk attaches a virtual hard disk for use.
func AttachVirtualDisk(handle syscall.Handle, attachVirtualDiskFlag AttachVirtualDiskFlag, parameters *AttachVirtualDiskParameters) (err error) { func AttachVirtualDisk(
handle syscall.Handle,
attachVirtualDiskFlag AttachVirtualDiskFlag,
parameters *AttachVirtualDiskParameters,
) (err error) {
// Supports both version 1 and 2 of the attach parameters as version 2 wasn't present in RS5. // Supports both version 1 and 2 of the attach parameters as version 2 wasn't present in RS5.
if err := attachVirtualDisk( if err := attachVirtualDisk(
handle, handle,
@ -203,6 +212,8 @@ func AttachVirtualDisk(handle syscall.Handle, attachVirtualDiskFlag AttachVirtua
// AttachVhd attaches a virtual hard disk at `path` for use. Attaches using version 2 // AttachVhd attaches a virtual hard disk at `path` for use. Attaches using version 2
// of the ATTACH_VIRTUAL_DISK_PARAMETERS. // of the ATTACH_VIRTUAL_DISK_PARAMETERS.
//
//revive:disable-next-line:var-naming VHD, not Vhd
func AttachVhd(path string) (err error) { func AttachVhd(path string) (err error) {
handle, err := OpenVirtualDisk( handle, err := OpenVirtualDisk(
path, path,
@ -213,7 +224,7 @@ func AttachVhd(path string) (err error) {
return err return err
} }
defer syscall.CloseHandle(handle) defer syscall.CloseHandle(handle) //nolint:errcheck
params := AttachVirtualDiskParameters{Version: 2} params := AttachVirtualDiskParameters{Version: 2}
if err := AttachVirtualDisk( if err := AttachVirtualDisk(
handle, handle,
@ -226,7 +237,11 @@ func AttachVhd(path string) (err error) {
} }
// OpenVirtualDisk obtains a handle to a VHD opened with supplied access mask and flags. // OpenVirtualDisk obtains a handle to a VHD opened with supplied access mask and flags.
func OpenVirtualDisk(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask, openVirtualDiskFlags VirtualDiskFlag) (syscall.Handle, error) { func OpenVirtualDisk(
vhdPath string,
virtualDiskAccessMask VirtualDiskAccessMask,
openVirtualDiskFlags VirtualDiskFlag,
) (syscall.Handle, error) {
parameters := OpenVirtualDiskParameters{Version: 2} parameters := OpenVirtualDiskParameters{Version: 2}
handle, err := OpenVirtualDiskWithParameters( handle, err := OpenVirtualDiskWithParameters(
vhdPath, vhdPath,
@ -241,7 +256,12 @@ func OpenVirtualDisk(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask
} }
// OpenVirtualDiskWithParameters obtains a handle to a VHD opened with supplied access mask, flags and parameters. // OpenVirtualDiskWithParameters obtains a handle to a VHD opened with supplied access mask, flags and parameters.
func OpenVirtualDiskWithParameters(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask, openVirtualDiskFlags VirtualDiskFlag, parameters *OpenVirtualDiskParameters) (syscall.Handle, error) { func OpenVirtualDiskWithParameters(
vhdPath string,
virtualDiskAccessMask VirtualDiskAccessMask,
openVirtualDiskFlags VirtualDiskFlag,
parameters *OpenVirtualDiskParameters,
) (syscall.Handle, error) {
var ( var (
handle syscall.Handle handle syscall.Handle
defaultType VirtualStorageType defaultType VirtualStorageType
@ -279,7 +299,12 @@ func OpenVirtualDiskWithParameters(vhdPath string, virtualDiskAccessMask Virtual
} }
// CreateVirtualDisk creates a virtual harddisk and returns a handle to the disk. // CreateVirtualDisk creates a virtual harddisk and returns a handle to the disk.
func CreateVirtualDisk(path string, virtualDiskAccessMask VirtualDiskAccessMask, createVirtualDiskFlags CreateVirtualDiskFlag, parameters *CreateVirtualDiskParameters) (syscall.Handle, error) { func CreateVirtualDisk(
path string,
virtualDiskAccessMask VirtualDiskAccessMask,
createVirtualDiskFlags CreateVirtualDiskFlag,
parameters *CreateVirtualDiskParameters,
) (syscall.Handle, error) {
var ( var (
handle syscall.Handle handle syscall.Handle
defaultType VirtualStorageType defaultType VirtualStorageType
@ -323,6 +348,8 @@ func GetVirtualDiskPhysicalPath(handle syscall.Handle) (_ string, err error) {
} }
// CreateDiffVhd is a helper function to create a differencing virtual disk. // CreateDiffVhd is a helper function to create a differencing virtual disk.
//
//revive:disable-next-line:var-naming VHD, not Vhd
func CreateDiffVhd(diffVhdPath, baseVhdPath string, blockSizeInMB uint32) error { func CreateDiffVhd(diffVhdPath, baseVhdPath string, blockSizeInMB uint32) error {
// Setting `ParentPath` is how to signal to create a differencing disk. // Setting `ParentPath` is how to signal to create a differencing disk.
createParams := &CreateVirtualDiskParameters{ createParams := &CreateVirtualDiskParameters{

View File

@ -1,4 +1,6 @@
// Code generated by 'go generate'; DO NOT EDIT. //go:build windows
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
package vhd package vhd

View File

@ -1,4 +1,6 @@
// Code generated by 'go generate'; DO NOT EDIT. //go:build windows
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
package winio package winio
@ -47,9 +49,11 @@ var (
procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW")
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW") procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW") procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW") procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
@ -74,7 +78,6 @@ var (
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
procbind = modws2_32.NewProc("bind")
) )
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
@ -123,6 +126,14 @@ func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision
return return
} }
func convertStringSidToSid(str *uint16, sid **byte) (err error) {
r1, _, e1 := syscall.Syscall(procConvertStringSidToSidW.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) { func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0) r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0) len = uint32(r0)
@ -154,6 +165,14 @@ func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidS
return return
} }
func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) { func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(systemName) _p0, err = syscall.UTF16PtrFromString(systemName)
@ -380,25 +399,25 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
return return
} }
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0) status = ntStatus(r0)
return return
} }
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
status = ntstatus(r0) status = ntStatus(r0)
return return
} }
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
status = ntstatus(r0) status = ntStatus(r0)
return return
} }
func rtlNtStatusToDosError(status ntstatus) (winerr error) { func rtlNtStatusToDosError(status ntStatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
if r0 != 0 { if r0 != 0 {
winerr = syscall.Errno(r0) winerr = syscall.Errno(r0)
@ -417,11 +436,3 @@ func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint
} }
return return
} }
func bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) {
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
if r1 == socketError {
err = errnoErr(e1)
}
return
}

View File

@ -3,8 +3,7 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2) [![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2)
[![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml) [![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml)
xxhash is a Go implementation of the 64-bit xxhash is a Go implementation of the 64-bit [xxHash] algorithm, XXH64. This is a
[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a
high-quality hashing algorithm that is much faster than anything in the Go high-quality hashing algorithm that is much faster than anything in the Go
standard library. standard library.
@ -25,8 +24,11 @@ func (*Digest) WriteString(string) (int, error)
func (*Digest) Sum64() uint64 func (*Digest) Sum64() uint64
``` ```
This implementation provides a fast pure-Go implementation and an even faster The package is written with optimized pure Go and also contains even faster
assembly implementation for amd64. assembly implementations for amd64 and arm64. If desired, the `purego` build tag
opts into using the Go code even on those architectures.
[xxHash]: http://cyan4973.github.io/xxHash/
## Compatibility ## Compatibility
@ -46,18 +48,19 @@ Here are some quick benchmarks comparing the pure-Go and assembly
implementations of Sum64. implementations of Sum64.
| input size | purego | asm | | input size | purego | asm |
| --- | --- | --- | | ---------- | --------- | --------- |
| 5 B | 979.66 MB/s | 1291.17 MB/s | | 4 B | 1.3 GB/s | 1.2 GB/s |
| 100 B | 7475.26 MB/s | 7973.40 MB/s | | 16 B | 2.9 GB/s | 3.5 GB/s |
| 4 KB | 17573.46 MB/s | 17602.65 MB/s | | 100 B | 6.9 GB/s | 8.1 GB/s |
| 10 MB | 17131.46 MB/s | 17142.16 MB/s | | 4 KB | 11.7 GB/s | 16.7 GB/s |
| 10 MB | 12.0 GB/s | 17.3 GB/s |
These numbers were generated on Ubuntu 18.04 with an Intel i7-8700K CPU using These numbers were generated on Ubuntu 20.04 with an Intel Xeon Platinum 8252C
the following commands under Go 1.11.2: CPU using the following commands under Go 1.19.2:
``` ```
$ go test -tags purego -benchtime 10s -bench '/xxhash,direct,bytes' benchstat <(go test -tags purego -benchtime 500ms -count 15 -bench 'Sum64$')
$ go test -benchtime 10s -bench '/xxhash,direct,bytes' benchstat <(go test -benchtime 500ms -count 15 -bench 'Sum64$')
``` ```
## Projects using this package ## Projects using this package

10
vendor/github.com/cespare/xxhash/v2/testall.sh generated vendored Normal file
View File

@ -0,0 +1,10 @@
#!/bin/bash
set -eu -o pipefail
# Small convenience script for running the tests with various combinations of
# arch/tags. This assumes we're running on amd64 and have qemu available.
go test ./...
go test -tags purego ./...
GOARCH=arm64 go test
GOARCH=arm64 go test -tags purego

View File

@ -16,19 +16,11 @@ const (
prime5 uint64 = 2870177450012600261 prime5 uint64 = 2870177450012600261
) )
// NOTE(caleb): I'm using both consts and vars of the primes. Using consts where // Store the primes in an array as well.
// possible in the Go code is worth a small (but measurable) performance boost //
// by avoiding some MOVQs. Vars are needed for the asm and also are useful for // The consts are used when possible in Go code to avoid MOVs but we need a
// convenience in the Go code in a few places where we need to intentionally // contiguous array of the assembly code.
// avoid constant arithmetic (e.g., v1 := prime1 + prime2 fails because the var primes = [...]uint64{prime1, prime2, prime3, prime4, prime5}
// result overflows a uint64).
var (
prime1v = prime1
prime2v = prime2
prime3v = prime3
prime4v = prime4
prime5v = prime5
)
// Digest implements hash.Hash64. // Digest implements hash.Hash64.
type Digest struct { type Digest struct {
@ -50,10 +42,10 @@ func New() *Digest {
// Reset clears the Digest's state so that it can be reused. // Reset clears the Digest's state so that it can be reused.
func (d *Digest) Reset() { func (d *Digest) Reset() {
d.v1 = prime1v + prime2 d.v1 = primes[0] + prime2
d.v2 = prime2 d.v2 = prime2
d.v3 = 0 d.v3 = 0
d.v4 = -prime1v d.v4 = -primes[0]
d.total = 0 d.total = 0
d.n = 0 d.n = 0
} }
@ -69,21 +61,23 @@ func (d *Digest) Write(b []byte) (n int, err error) {
n = len(b) n = len(b)
d.total += uint64(n) d.total += uint64(n)
memleft := d.mem[d.n&(len(d.mem)-1):]
if d.n+n < 32 { if d.n+n < 32 {
// This new data doesn't even fill the current block. // This new data doesn't even fill the current block.
copy(d.mem[d.n:], b) copy(memleft, b)
d.n += n d.n += n
return return
} }
if d.n > 0 { if d.n > 0 {
// Finish off the partial block. // Finish off the partial block.
copy(d.mem[d.n:], b) c := copy(memleft, b)
d.v1 = round(d.v1, u64(d.mem[0:8])) d.v1 = round(d.v1, u64(d.mem[0:8]))
d.v2 = round(d.v2, u64(d.mem[8:16])) d.v2 = round(d.v2, u64(d.mem[8:16]))
d.v3 = round(d.v3, u64(d.mem[16:24])) d.v3 = round(d.v3, u64(d.mem[16:24]))
d.v4 = round(d.v4, u64(d.mem[24:32])) d.v4 = round(d.v4, u64(d.mem[24:32]))
b = b[32-d.n:] b = b[c:]
d.n = 0 d.n = 0
} }
@ -133,21 +127,20 @@ func (d *Digest) Sum64() uint64 {
h += d.total h += d.total
i, end := 0, d.n b := d.mem[:d.n&(len(d.mem)-1)]
for ; i+8 <= end; i += 8 { for ; len(b) >= 8; b = b[8:] {
k1 := round(0, u64(d.mem[i:i+8])) k1 := round(0, u64(b[:8]))
h ^= k1 h ^= k1
h = rol27(h)*prime1 + prime4 h = rol27(h)*prime1 + prime4
} }
if i+4 <= end { if len(b) >= 4 {
h ^= uint64(u32(d.mem[i:i+4])) * prime1 h ^= uint64(u32(b[:4])) * prime1
h = rol23(h)*prime2 + prime3 h = rol23(h)*prime2 + prime3
i += 4 b = b[4:]
} }
for i < end { for ; len(b) > 0; b = b[1:] {
h ^= uint64(d.mem[i]) * prime5 h ^= uint64(b[0]) * prime5
h = rol11(h) * prime1 h = rol11(h) * prime1
i++
} }
h ^= h >> 33 h ^= h >> 33

View File

@ -1,215 +1,209 @@
//go:build !appengine && gc && !purego
// +build !appengine // +build !appengine
// +build gc // +build gc
// +build !purego // +build !purego
#include "textflag.h" #include "textflag.h"
// Register allocation: // Registers:
// AX h #define h AX
// SI pointer to advance through b #define d AX
// DX n #define p SI // pointer to advance through b
// BX loop end #define n DX
// R8 v1, k1 #define end BX // loop end
// R9 v2 #define v1 R8
// R10 v3 #define v2 R9
// R11 v4 #define v3 R10
// R12 tmp #define v4 R11
// R13 prime1v #define x R12
// R14 prime2v #define prime1 R13
// DI prime4v #define prime2 R14
#define prime4 DI
// round reads from and advances the buffer pointer in SI. #define round(acc, x) \
// It assumes that R13 has prime1v and R14 has prime2v. IMULQ prime2, x \
#define round(r) \ ADDQ x, acc \
MOVQ (SI), R12 \ ROLQ $31, acc \
ADDQ $8, SI \ IMULQ prime1, acc
IMULQ R14, R12 \
ADDQ R12, r \
ROLQ $31, r \
IMULQ R13, r
// mergeRound applies a merge round on the two registers acc and val. // round0 performs the operation x = round(0, x).
// It assumes that R13 has prime1v, R14 has prime2v, and DI has prime4v. #define round0(x) \
#define mergeRound(acc, val) \ IMULQ prime2, x \
IMULQ R14, val \ ROLQ $31, x \
ROLQ $31, val \ IMULQ prime1, x
IMULQ R13, val \
XORQ val, acc \ // mergeRound applies a merge round on the two registers acc and x.
IMULQ R13, acc \ // It assumes that prime1, prime2, and prime4 have been loaded.
ADDQ DI, acc #define mergeRound(acc, x) \
round0(x) \
XORQ x, acc \
IMULQ prime1, acc \
ADDQ prime4, acc
// blockLoop processes as many 32-byte blocks as possible,
// updating v1, v2, v3, and v4. It assumes that there is at least one block
// to process.
#define blockLoop() \
loop: \
MOVQ +0(p), x \
round(v1, x) \
MOVQ +8(p), x \
round(v2, x) \
MOVQ +16(p), x \
round(v3, x) \
MOVQ +24(p), x \
round(v4, x) \
ADDQ $32, p \
CMPQ p, end \
JLE loop
// func Sum64(b []byte) uint64 // func Sum64(b []byte) uint64
TEXT ·Sum64(SB), NOSPLIT, $0-32 TEXT ·Sum64(SB), NOSPLIT|NOFRAME, $0-32
// Load fixed primes. // Load fixed primes.
MOVQ ·prime1v(SB), R13 MOVQ ·primes+0(SB), prime1
MOVQ ·prime2v(SB), R14 MOVQ ·primes+8(SB), prime2
MOVQ ·prime4v(SB), DI MOVQ ·primes+24(SB), prime4
// Load slice. // Load slice.
MOVQ b_base+0(FP), SI MOVQ b_base+0(FP), p
MOVQ b_len+8(FP), DX MOVQ b_len+8(FP), n
LEAQ (SI)(DX*1), BX LEAQ (p)(n*1), end
// The first loop limit will be len(b)-32. // The first loop limit will be len(b)-32.
SUBQ $32, BX SUBQ $32, end
// Check whether we have at least one block. // Check whether we have at least one block.
CMPQ DX, $32 CMPQ n, $32
JLT noBlocks JLT noBlocks
// Set up initial state (v1, v2, v3, v4). // Set up initial state (v1, v2, v3, v4).
MOVQ R13, R8 MOVQ prime1, v1
ADDQ R14, R8 ADDQ prime2, v1
MOVQ R14, R9 MOVQ prime2, v2
XORQ R10, R10 XORQ v3, v3
XORQ R11, R11 XORQ v4, v4
SUBQ R13, R11 SUBQ prime1, v4
// Loop until SI > BX. blockLoop()
blockLoop:
round(R8)
round(R9)
round(R10)
round(R11)
CMPQ SI, BX MOVQ v1, h
JLE blockLoop ROLQ $1, h
MOVQ v2, x
ROLQ $7, x
ADDQ x, h
MOVQ v3, x
ROLQ $12, x
ADDQ x, h
MOVQ v4, x
ROLQ $18, x
ADDQ x, h
MOVQ R8, AX mergeRound(h, v1)
ROLQ $1, AX mergeRound(h, v2)
MOVQ R9, R12 mergeRound(h, v3)
ROLQ $7, R12 mergeRound(h, v4)
ADDQ R12, AX
MOVQ R10, R12
ROLQ $12, R12
ADDQ R12, AX
MOVQ R11, R12
ROLQ $18, R12
ADDQ R12, AX
mergeRound(AX, R8)
mergeRound(AX, R9)
mergeRound(AX, R10)
mergeRound(AX, R11)
JMP afterBlocks JMP afterBlocks
noBlocks: noBlocks:
MOVQ ·prime5v(SB), AX MOVQ ·primes+32(SB), h
afterBlocks: afterBlocks:
ADDQ DX, AX ADDQ n, h
// Right now BX has len(b)-32, and we want to loop until SI > len(b)-8. ADDQ $24, end
ADDQ $24, BX CMPQ p, end
JG try4
CMPQ SI, BX loop8:
JG fourByte MOVQ (p), x
ADDQ $8, p
round0(x)
XORQ x, h
ROLQ $27, h
IMULQ prime1, h
ADDQ prime4, h
wordLoop: CMPQ p, end
// Calculate k1. JLE loop8
MOVQ (SI), R8
ADDQ $8, SI
IMULQ R14, R8
ROLQ $31, R8
IMULQ R13, R8
XORQ R8, AX try4:
ROLQ $27, AX ADDQ $4, end
IMULQ R13, AX CMPQ p, end
ADDQ DI, AX JG try1
CMPQ SI, BX MOVL (p), x
JLE wordLoop ADDQ $4, p
IMULQ prime1, x
XORQ x, h
fourByte: ROLQ $23, h
ADDQ $4, BX IMULQ prime2, h
CMPQ SI, BX ADDQ ·primes+16(SB), h
JG singles
MOVL (SI), R8 try1:
ADDQ $4, SI ADDQ $4, end
IMULQ R13, R8 CMPQ p, end
XORQ R8, AX
ROLQ $23, AX
IMULQ R14, AX
ADDQ ·prime3v(SB), AX
singles:
ADDQ $4, BX
CMPQ SI, BX
JGE finalize JGE finalize
singlesLoop: loop1:
MOVBQZX (SI), R12 MOVBQZX (p), x
ADDQ $1, SI ADDQ $1, p
IMULQ ·prime5v(SB), R12 IMULQ ·primes+32(SB), x
XORQ R12, AX XORQ x, h
ROLQ $11, h
IMULQ prime1, h
ROLQ $11, AX CMPQ p, end
IMULQ R13, AX JL loop1
CMPQ SI, BX
JL singlesLoop
finalize: finalize:
MOVQ AX, R12 MOVQ h, x
SHRQ $33, R12 SHRQ $33, x
XORQ R12, AX XORQ x, h
IMULQ R14, AX IMULQ prime2, h
MOVQ AX, R12 MOVQ h, x
SHRQ $29, R12 SHRQ $29, x
XORQ R12, AX XORQ x, h
IMULQ ·prime3v(SB), AX IMULQ ·primes+16(SB), h
MOVQ AX, R12 MOVQ h, x
SHRQ $32, R12 SHRQ $32, x
XORQ R12, AX XORQ x, h
MOVQ AX, ret+24(FP) MOVQ h, ret+24(FP)
RET RET
// writeBlocks uses the same registers as above except that it uses AX to store
// the d pointer.
// func writeBlocks(d *Digest, b []byte) int // func writeBlocks(d *Digest, b []byte) int
TEXT ·writeBlocks(SB), NOSPLIT, $0-40 TEXT ·writeBlocks(SB), NOSPLIT|NOFRAME, $0-40
// Load fixed primes needed for round. // Load fixed primes needed for round.
MOVQ ·prime1v(SB), R13 MOVQ ·primes+0(SB), prime1
MOVQ ·prime2v(SB), R14 MOVQ ·primes+8(SB), prime2
// Load slice. // Load slice.
MOVQ b_base+8(FP), SI MOVQ b_base+8(FP), p
MOVQ b_len+16(FP), DX MOVQ b_len+16(FP), n
LEAQ (SI)(DX*1), BX LEAQ (p)(n*1), end
SUBQ $32, BX SUBQ $32, end
// Load vN from d. // Load vN from d.
MOVQ d+0(FP), AX MOVQ s+0(FP), d
MOVQ 0(AX), R8 // v1 MOVQ 0(d), v1
MOVQ 8(AX), R9 // v2 MOVQ 8(d), v2
MOVQ 16(AX), R10 // v3 MOVQ 16(d), v3
MOVQ 24(AX), R11 // v4 MOVQ 24(d), v4
// We don't need to check the loop condition here; this function is // We don't need to check the loop condition here; this function is
// always called with at least one block of data to process. // always called with at least one block of data to process.
blockLoop: blockLoop()
round(R8)
round(R9)
round(R10)
round(R11)
CMPQ SI, BX
JLE blockLoop
// Copy vN back to d. // Copy vN back to d.
MOVQ R8, 0(AX) MOVQ v1, 0(d)
MOVQ R9, 8(AX) MOVQ v2, 8(d)
MOVQ R10, 16(AX) MOVQ v3, 16(d)
MOVQ R11, 24(AX) MOVQ v4, 24(d)
// The number of bytes written is SI minus the old base pointer. // The number of bytes written is p minus the old base pointer.
SUBQ b_base+8(FP), SI SUBQ b_base+8(FP), p
MOVQ SI, ret+32(FP) MOVQ p, ret+32(FP)
RET RET

183
vendor/github.com/cespare/xxhash/v2/xxhash_arm64.s generated vendored Normal file
View File

@ -0,0 +1,183 @@
//go:build !appengine && gc && !purego
// +build !appengine
// +build gc
// +build !purego
#include "textflag.h"
// Registers:
#define digest R1
#define h R2 // return value
#define p R3 // input pointer
#define n R4 // input length
#define nblocks R5 // n / 32
#define prime1 R7
#define prime2 R8
#define prime3 R9
#define prime4 R10
#define prime5 R11
#define v1 R12
#define v2 R13
#define v3 R14
#define v4 R15
#define x1 R20
#define x2 R21
#define x3 R22
#define x4 R23
#define round(acc, x) \
MADD prime2, acc, x, acc \
ROR $64-31, acc \
MUL prime1, acc
// round0 performs the operation x = round(0, x).
#define round0(x) \
MUL prime2, x \
ROR $64-31, x \
MUL prime1, x
#define mergeRound(acc, x) \
round0(x) \
EOR x, acc \
MADD acc, prime4, prime1, acc
// blockLoop processes as many 32-byte blocks as possible,
// updating v1, v2, v3, and v4. It assumes that n >= 32.
#define blockLoop() \
LSR $5, n, nblocks \
PCALIGN $16 \
loop: \
LDP.P 16(p), (x1, x2) \
LDP.P 16(p), (x3, x4) \
round(v1, x1) \
round(v2, x2) \
round(v3, x3) \
round(v4, x4) \
SUB $1, nblocks \
CBNZ nblocks, loop
// func Sum64(b []byte) uint64
TEXT ·Sum64(SB), NOSPLIT|NOFRAME, $0-32
LDP b_base+0(FP), (p, n)
LDP ·primes+0(SB), (prime1, prime2)
LDP ·primes+16(SB), (prime3, prime4)
MOVD ·primes+32(SB), prime5
CMP $32, n
CSEL LT, prime5, ZR, h // if n < 32 { h = prime5 } else { h = 0 }
BLT afterLoop
ADD prime1, prime2, v1
MOVD prime2, v2
MOVD $0, v3
NEG prime1, v4
blockLoop()
ROR $64-1, v1, x1
ROR $64-7, v2, x2
ADD x1, x2
ROR $64-12, v3, x3
ROR $64-18, v4, x4
ADD x3, x4
ADD x2, x4, h
mergeRound(h, v1)
mergeRound(h, v2)
mergeRound(h, v3)
mergeRound(h, v4)
afterLoop:
ADD n, h
TBZ $4, n, try8
LDP.P 16(p), (x1, x2)
round0(x1)
// NOTE: here and below, sequencing the EOR after the ROR (using a
// rotated register) is worth a small but measurable speedup for small
// inputs.
ROR $64-27, h
EOR x1 @> 64-27, h, h
MADD h, prime4, prime1, h
round0(x2)
ROR $64-27, h
EOR x2 @> 64-27, h, h
MADD h, prime4, prime1, h
try8:
TBZ $3, n, try4
MOVD.P 8(p), x1
round0(x1)
ROR $64-27, h
EOR x1 @> 64-27, h, h
MADD h, prime4, prime1, h
try4:
TBZ $2, n, try2
MOVWU.P 4(p), x2
MUL prime1, x2
ROR $64-23, h
EOR x2 @> 64-23, h, h
MADD h, prime3, prime2, h
try2:
TBZ $1, n, try1
MOVHU.P 2(p), x3
AND $255, x3, x1
LSR $8, x3, x2
MUL prime5, x1
ROR $64-11, h
EOR x1 @> 64-11, h, h
MUL prime1, h
MUL prime5, x2
ROR $64-11, h
EOR x2 @> 64-11, h, h
MUL prime1, h
try1:
TBZ $0, n, finalize
MOVBU (p), x4
MUL prime5, x4
ROR $64-11, h
EOR x4 @> 64-11, h, h
MUL prime1, h
finalize:
EOR h >> 33, h
MUL prime2, h
EOR h >> 29, h
MUL prime3, h
EOR h >> 32, h
MOVD h, ret+24(FP)
RET
// func writeBlocks(d *Digest, b []byte) int
TEXT ·writeBlocks(SB), NOSPLIT|NOFRAME, $0-40
LDP ·primes+0(SB), (prime1, prime2)
// Load state. Assume v[1-4] are stored contiguously.
MOVD d+0(FP), digest
LDP 0(digest), (v1, v2)
LDP 16(digest), (v3, v4)
LDP b_base+8(FP), (p, n)
blockLoop()
// Store updated state.
STP (v1, v2), 0(digest)
STP (v3, v4), 16(digest)
BIC $31, n
MOVD n, ret+32(FP)
RET

View File

@ -1,3 +1,5 @@
//go:build (amd64 || arm64) && !appengine && gc && !purego
// +build amd64 arm64
// +build !appengine // +build !appengine
// +build gc // +build gc
// +build !purego // +build !purego

View File

@ -1,4 +1,5 @@
// +build !amd64 appengine !gc purego //go:build (!amd64 && !arm64) || appengine || !gc || purego
// +build !amd64,!arm64 appengine !gc purego
package xxhash package xxhash
@ -14,10 +15,10 @@ func Sum64(b []byte) uint64 {
var h uint64 var h uint64
if n >= 32 { if n >= 32 {
v1 := prime1v + prime2 v1 := primes[0] + prime2
v2 := prime2 v2 := prime2
v3 := uint64(0) v3 := uint64(0)
v4 := -prime1v v4 := -primes[0]
for len(b) >= 32 { for len(b) >= 32 {
v1 = round(v1, u64(b[0:8:len(b)])) v1 = round(v1, u64(b[0:8:len(b)]))
v2 = round(v2, u64(b[8:16:len(b)])) v2 = round(v2, u64(b[8:16:len(b)]))
@ -36,19 +37,18 @@ func Sum64(b []byte) uint64 {
h += uint64(n) h += uint64(n)
i, end := 0, len(b) for ; len(b) >= 8; b = b[8:] {
for ; i+8 <= end; i += 8 { k1 := round(0, u64(b[:8]))
k1 := round(0, u64(b[i:i+8:len(b)]))
h ^= k1 h ^= k1
h = rol27(h)*prime1 + prime4 h = rol27(h)*prime1 + prime4
} }
if i+4 <= end { if len(b) >= 4 {
h ^= uint64(u32(b[i:i+4:len(b)])) * prime1 h ^= uint64(u32(b[:4])) * prime1
h = rol23(h)*prime2 + prime3 h = rol23(h)*prime2 + prime3
i += 4 b = b[4:]
} }
for ; i < end; i++ { for ; len(b) > 0; b = b[1:] {
h ^= uint64(b[i]) * prime5 h ^= uint64(b[0]) * prime5
h = rol11(h) * prime1 h = rol11(h) * prime1
} }

View File

@ -1,3 +1,4 @@
//go:build appengine
// +build appengine // +build appengine
// This file contains the safe implementations of otherwise unsafe-using code. // This file contains the safe implementations of otherwise unsafe-using code.

View File

@ -1,3 +1,4 @@
//go:build !appengine
// +build !appengine // +build !appengine
// This file encapsulates usage of unsafe. // This file encapsulates usage of unsafe.
@ -11,7 +12,7 @@ import (
// In the future it's possible that compiler optimizations will make these // In the future it's possible that compiler optimizations will make these
// XxxString functions unnecessary by realizing that calls such as // XxxString functions unnecessary by realizing that calls such as
// Sum64([]byte(s)) don't need to copy s. See https://golang.org/issue/2205. // Sum64([]byte(s)) don't need to copy s. See https://go.dev/issue/2205.
// If that happens, even if we keep these functions they can be replaced with // If that happens, even if we keep these functions they can be replaced with
// the trivial safe code. // the trivial safe code.

View File

@ -181,6 +181,6 @@ func List(helper Helper, writer io.Writer) error {
// PrintVersion outputs the current version. // PrintVersion outputs the current version.
func PrintVersion(writer io.Writer) error { func PrintVersion(writer io.Writer) error {
fmt.Fprintln(writer, Version) fmt.Fprintf(writer, "%s (%s) %s\n", Name, Package, Version)
return nil return nil
} }

View File

@ -1,4 +1,16 @@
package credentials package credentials
// Version holds a string describing the current version var (
const Version = "0.6.4" // Name is filled at linking time
Name = ""
// Package is filled at linking time
Package = "github.com/docker/docker-credential-helpers"
// Version holds the complete version number. Filled in at linking time.
Version = "v0.0.0+unknown"
// Revision is filled with the VCS (e.g. git) revision being used to build
// the program at linking time.
Revision = ""
)

View File

@ -24,7 +24,7 @@ info:
title: "Docker Engine API" title: "Docker Engine API"
version: "1.41" version: "1.41"
x-logo: x-logo:
url: "https://docs.docker.com/images/logo-docker-main.png" url: "https://docs.docker.com/assets/images/logo-docker-main.png"
description: | description: |
The Engine API is an HTTP API served by Docker Engine. It is the API the The Engine API is an HTTP API served by Docker Engine. It is the API the
Docker client uses to communicate with the Engine, so everything the Docker Docker client uses to communicate with the Engine, so everything the Docker
@ -1891,23 +1891,52 @@ definitions:
BuildCache: BuildCache:
type: "object" type: "object"
description: |
BuildCache contains information about a build cache record.
properties: properties:
ID: ID:
type: "string" type: "string"
description: |
Unique ID of the build cache record.
example: "ndlpt0hhvkqcdfkputsk4cq9c"
Parent: Parent:
description: |
ID of the parent build cache record.
type: "string" type: "string"
example: "hw53o5aio51xtltp5xjp8v7fx"
Type: Type:
type: "string" type: "string"
description: |
Cache record type.
example: "regular"
# see https://github.com/moby/buildkit/blob/fce4a32258dc9d9664f71a4831d5de10f0670677/client/diskusage.go#L75-L84
enum:
- "internal"
- "frontend"
- "source.local"
- "source.git.checkout"
- "exec.cachemount"
- "regular"
Description: Description:
type: "string" type: "string"
description: |
Description of the build-step that produced the build cache.
example: "mount / from exec /bin/sh -c echo 'Binary::apt::APT::Keep-Downloaded-Packages \"true\";' > /etc/apt/apt.conf.d/keep-cache"
InUse: InUse:
type: "boolean" type: "boolean"
description: |
Indicates if the build cache is in use.
example: false
Shared: Shared:
type: "boolean" type: "boolean"
description: |
Indicates if the build cache is shared.
example: true
Size: Size:
description: | description: |
Amount of disk space used by the build cache (in bytes). Amount of disk space used by the build cache (in bytes).
type: "integer" type: "integer"
example: 51
CreatedAt: CreatedAt:
description: | description: |
Date and time at which the build cache was created in Date and time at which the build cache was created in
@ -1925,6 +1954,7 @@ definitions:
example: "2017-08-09T07:09:37.632105588Z" example: "2017-08-09T07:09:37.632105588Z"
UsageCount: UsageCount:
type: "integer" type: "integer"
example: 26
ImageID: ImageID:
type: "object" type: "object"
@ -5415,6 +5445,28 @@ paths:
`/?[a-zA-Z0-9][a-zA-Z0-9_.-]+`. `/?[a-zA-Z0-9][a-zA-Z0-9_.-]+`.
type: "string" type: "string"
pattern: "^/?[a-zA-Z0-9][a-zA-Z0-9_.-]+$" pattern: "^/?[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
- name: "platform"
in: "query"
description: |
Platform in the format `os[/arch[/variant]]` used for image lookup.
When specified, the daemon checks if the requested image is present
in the local image cache with the given OS and Architecture, and
otherwise returns a `404` status.
If the option is not set, the host's native OS and Architecture are
used to look up the image in the image cache. However, if no platform
is passed and the given image does exist in the local image cache,
but its OS or architecture does not match, the container is created
with the available image, and a warning is added to the `Warnings`
field in the response, for example;
WARNING: The requested image's platform (linux/arm64/v8) does not
match the detected host platform (linux/amd64) and no
specific platform was requested
type: "string"
default: ""
- name: "body" - name: "body"
in: "body" in: "body"
description: "Container to create" description: "Container to create"

View File

@ -135,9 +135,6 @@ func NewClientWithOpts(ops ...Opt) (*Client, error) {
} }
} }
if _, ok := c.client.Transport.(http.RoundTripper); !ok {
return nil, fmt.Errorf("unable to verify TLS configuration, invalid transport %v", c.client.Transport)
}
if c.scheme == "" { if c.scheme == "" {
c.scheme = "http" c.scheme = "http"

View File

@ -150,12 +150,10 @@ func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResp
if err.Timeout() { if err.Timeout() {
return serverResp, ErrorConnectionFailed(cli.host) return serverResp, ErrorConnectionFailed(cli.host)
} }
if !err.Temporary() {
if strings.Contains(err.Error(), "connection refused") || strings.Contains(err.Error(), "dial unix") { if strings.Contains(err.Error(), "connection refused") || strings.Contains(err.Error(), "dial unix") {
return serverResp, ErrorConnectionFailed(cli.host) return serverResp, ErrorConnectionFailed(cli.host)
} }
} }
}
// Although there's not a strongly typed error for this in go-winio, // Although there's not a strongly typed error for this in go-winio,
// lots of people are using the default configuration for the docker // lots of people are using the default configuration for the docker
@ -242,11 +240,9 @@ func (cli *Client) addHeaders(req *http.Request, headers headers) *http.Request
req.Header.Set(k, v) req.Header.Set(k, v)
} }
if headers != nil {
for k, v := range headers { for k, v := range headers {
req.Header[k] = v req.Header[k] = v
} }
}
return req return req
} }

View File

@ -2,7 +2,6 @@ package units
import ( import (
"fmt" "fmt"
"regexp"
"strconv" "strconv"
"strings" "strings"
) )
@ -26,16 +25,17 @@ const (
PiB = 1024 * TiB PiB = 1024 * TiB
) )
type unitMap map[string]int64 type unitMap map[byte]int64
var ( var (
decimalMap = unitMap{"k": KB, "m": MB, "g": GB, "t": TB, "p": PB} decimalMap = unitMap{'k': KB, 'm': MB, 'g': GB, 't': TB, 'p': PB}
binaryMap = unitMap{"k": KiB, "m": MiB, "g": GiB, "t": TiB, "p": PiB} binaryMap = unitMap{'k': KiB, 'm': MiB, 'g': GiB, 't': TiB, 'p': PiB}
sizeRegex = regexp.MustCompile(`^(\d+(\.\d+)*) ?([kKmMgGtTpP])?[iI]?[bB]?$`)
) )
var decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"} var (
var binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"} decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"}
)
func getSizeAndUnit(size float64, base float64, _map []string) (float64, string) { func getSizeAndUnit(size float64, base float64, _map []string) (float64, string) {
i := 0 i := 0
@ -89,20 +89,66 @@ func RAMInBytes(size string) (int64, error) {
// Parses the human-readable size string into the amount it represents. // Parses the human-readable size string into the amount it represents.
func parseSize(sizeStr string, uMap unitMap) (int64, error) { func parseSize(sizeStr string, uMap unitMap) (int64, error) {
matches := sizeRegex.FindStringSubmatch(sizeStr) // TODO: rewrite to use strings.Cut if there's a space
if len(matches) != 4 { // once Go < 1.18 is deprecated.
sep := strings.LastIndexAny(sizeStr, "01234567890. ")
if sep == -1 {
// There should be at least a digit.
return -1, fmt.Errorf("invalid size: '%s'", sizeStr) return -1, fmt.Errorf("invalid size: '%s'", sizeStr)
} }
var num, sfx string
if sizeStr[sep] != ' ' {
num = sizeStr[:sep+1]
sfx = sizeStr[sep+1:]
} else {
// Omit the space separator.
num = sizeStr[:sep]
sfx = sizeStr[sep+1:]
}
size, err := strconv.ParseFloat(matches[1], 64) size, err := strconv.ParseFloat(num, 64)
if err != nil { if err != nil {
return -1, err return -1, err
} }
// Backward compatibility: reject negative sizes.
if size < 0 {
return -1, fmt.Errorf("invalid size: '%s'", sizeStr)
}
unitPrefix := strings.ToLower(matches[3]) if len(sfx) == 0 {
if mul, ok := uMap[unitPrefix]; ok { return int64(size), nil
}
// Process the suffix.
if len(sfx) > 3 { // Too long.
goto badSuffix
}
sfx = strings.ToLower(sfx)
// Trivial case: b suffix.
if sfx[0] == 'b' {
if len(sfx) > 1 { // no extra characters allowed after b.
goto badSuffix
}
return int64(size), nil
}
// A suffix from the map.
if mul, ok := uMap[sfx[0]]; ok {
size *= float64(mul) size *= float64(mul)
} else {
goto badSuffix
}
// The suffix may have extra "b" or "ib" (e.g. KiB or MB).
switch {
case len(sfx) == 2 && sfx[1] != 'b':
goto badSuffix
case len(sfx) == 3 && sfx[1:] != "ib":
goto badSuffix
} }
return int64(size), nil return int64(size), nil
badSuffix:
return -1, fmt.Errorf("invalid suffix: '%s'", sfx)
} }

View File

@ -15,15 +15,13 @@
package name package name
import ( import (
_ "crypto/sha256" // Recommended by go-digest.
"strings" "strings"
"github.com/opencontainers/go-digest"
) )
const ( const digestDelim = "@"
// These have the form: sha256:<hex string>
// TODO(dekkagaijin): replace with opencontainers/go-digest or docker/distribution's validation.
digestChars = "sh:0123456789abcdef"
digestDelim = "@"
)
// Digest stores a digest name in a structured form. // Digest stores a digest name in a structured form.
type Digest struct { type Digest struct {
@ -60,10 +58,6 @@ func (d Digest) String() string {
return d.original return d.original
} }
func checkDigest(name string) error {
return checkElement("digest", name, digestChars, 7+64, 7+64)
}
// NewDigest returns a new Digest representing the given name. // NewDigest returns a new Digest representing the given name.
func NewDigest(name string, opts ...Option) (Digest, error) { func NewDigest(name string, opts ...Option) (Digest, error) {
// Split on "@" // Split on "@"
@ -72,10 +66,13 @@ func NewDigest(name string, opts ...Option) (Digest, error) {
return Digest{}, newErrBadName("a digest must contain exactly one '@' separator (e.g. registry/repository@digest) saw: %s", name) return Digest{}, newErrBadName("a digest must contain exactly one '@' separator (e.g. registry/repository@digest) saw: %s", name)
} }
base := parts[0] base := parts[0]
digest := parts[1] dig := parts[1]
prefix := digest.Canonical.String() + ":"
// Always check that the digest is valid. if !strings.HasPrefix(dig, prefix) {
if err := checkDigest(digest); err != nil { return Digest{}, newErrBadName("unsupported digest algorithm: %s", dig)
}
hex := strings.TrimPrefix(dig, prefix)
if err := digest.Canonical.Validate(hex); err != nil {
return Digest{}, err return Digest{}, err
} }
@ -90,7 +87,7 @@ func NewDigest(name string, opts ...Option) (Digest, error) {
} }
return Digest{ return Digest{
Repository: repo, Repository: repo,
digest: digest, digest: dig,
original: name, original: name,
}, nil }, nil
} }

View File

@ -28,8 +28,14 @@ func (e *ErrBadName) Error() string {
return e.info return e.info
} }
// Is reports whether target is an error of type ErrBadName
func (e *ErrBadName) Is(target error) bool {
var berr *ErrBadName
return errors.As(target, &berr)
}
// newErrBadName returns a ErrBadName which returns the given formatted string from Error(). // newErrBadName returns a ErrBadName which returns the given formatted string from Error().
func newErrBadName(fmtStr string, args ...interface{}) *ErrBadName { func newErrBadName(fmtStr string, args ...any) *ErrBadName {
return &ErrBadName{fmt.Sprintf(fmtStr, args...)} return &ErrBadName{fmt.Sprintf(fmtStr, args...)}
} }

View File

@ -1,4 +1,192 @@
Copyright 2014 Alan Shreve Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2022 Alan Shreve (@inconshreveable)
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View File

@ -17,6 +17,23 @@ This package provides various compression algorithms.
# changelog # changelog
* Sept 16, 2022 (v1.15.10)
* zstd: Add [WithDecodeAllCapLimit](https://pkg.go.dev/github.com/klauspost/compress@v1.15.10/zstd#WithDecodeAllCapLimit) https://github.com/klauspost/compress/pull/649
* Add Go 1.19 - deprecate Go 1.16 https://github.com/klauspost/compress/pull/651
* flate: Improve level 5+6 compression https://github.com/klauspost/compress/pull/656
* zstd: Improve "better" compresssion https://github.com/klauspost/compress/pull/657
* s2: Improve "best" compression https://github.com/klauspost/compress/pull/658
* s2: Improve "better" compression. https://github.com/klauspost/compress/pull/635
* s2: Slightly faster non-assembly decompression https://github.com/klauspost/compress/pull/646
* Use arrays for constant size copies https://github.com/klauspost/compress/pull/659
* July 21, 2022 (v1.15.9)
* zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645
* zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644
* zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643
* July 13, 2022 (v1.15.8) * July 13, 2022 (v1.15.8)
* gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641 * gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641
@ -91,14 +108,14 @@ This package provides various compression algorithms.
* gzhttp: Add zstd to transport by @klauspost in [#400](https://github.com/klauspost/compress/pull/400) * gzhttp: Add zstd to transport by @klauspost in [#400](https://github.com/klauspost/compress/pull/400)
* gzhttp: Make content-type optional by @klauspost in [#510](https://github.com/klauspost/compress/pull/510) * gzhttp: Make content-type optional by @klauspost in [#510](https://github.com/klauspost/compress/pull/510)
<details>
<summary>See Details</summary>
Both compression and decompression now supports "synchronous" stream operations. This means that whenever "concurrency" is set to 1, they will operate without spawning goroutines. Both compression and decompression now supports "synchronous" stream operations. This means that whenever "concurrency" is set to 1, they will operate without spawning goroutines.
Stream decompression is now faster on asynchronous, since the goroutine allocation much more effectively splits the workload. On typical streams this will typically use 2 cores fully for decompression. When a stream has finished decoding no goroutines will be left over, so decoders can now safely be pooled and still be garbage collected. Stream decompression is now faster on asynchronous, since the goroutine allocation much more effectively splits the workload. On typical streams this will typically use 2 cores fully for decompression. When a stream has finished decoding no goroutines will be left over, so decoders can now safely be pooled and still be garbage collected.
While the release has been extensively tested, it is recommended to testing when upgrading. While the release has been extensively tested, it is recommended to testing when upgrading.
</details>
<details>
<summary>See changes to v1.14.x</summary>
* Feb 22, 2022 (v1.14.4) * Feb 22, 2022 (v1.14.4)
* flate: Fix rare huffman only (-2) corruption. [#503](https://github.com/klauspost/compress/pull/503) * flate: Fix rare huffman only (-2) corruption. [#503](https://github.com/klauspost/compress/pull/503)
@ -125,6 +142,7 @@ While the release has been extensively tested, it is recommended to testing when
* zstd: Performance improvement in [#420]( https://github.com/klauspost/compress/pull/420) [#456](https://github.com/klauspost/compress/pull/456) [#437](https://github.com/klauspost/compress/pull/437) [#467](https://github.com/klauspost/compress/pull/467) [#468](https://github.com/klauspost/compress/pull/468) * zstd: Performance improvement in [#420]( https://github.com/klauspost/compress/pull/420) [#456](https://github.com/klauspost/compress/pull/456) [#437](https://github.com/klauspost/compress/pull/437) [#467](https://github.com/klauspost/compress/pull/467) [#468](https://github.com/klauspost/compress/pull/468)
* zstd: add arm64 xxhash assembly in [#464](https://github.com/klauspost/compress/pull/464) * zstd: add arm64 xxhash assembly in [#464](https://github.com/klauspost/compress/pull/464)
* Add garbled for binaries for s2 in [#445](https://github.com/klauspost/compress/pull/445) * Add garbled for binaries for s2 in [#445](https://github.com/klauspost/compress/pull/445)
</details>
<details> <details>
<summary>See changes to v1.13.x</summary> <summary>See changes to v1.13.x</summary>

View File

@ -131,7 +131,8 @@ func (d *compressor) fillDeflate(b []byte) int {
s := d.state s := d.state
if s.index >= 2*windowSize-(minMatchLength+maxMatchLength) { if s.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
// shift the window by windowSize // shift the window by windowSize
copy(d.window[:], d.window[windowSize:2*windowSize]) //copy(d.window[:], d.window[windowSize:2*windowSize])
*(*[windowSize]byte)(d.window) = *(*[windowSize]byte)(d.window[windowSize:])
s.index -= windowSize s.index -= windowSize
d.windowEnd -= windowSize d.windowEnd -= windowSize
if d.blockStart >= windowSize { if d.blockStart >= windowSize {
@ -373,6 +374,12 @@ func hash4(b []byte) uint32 {
return hash4u(binary.LittleEndian.Uint32(b), hashBits) return hash4u(binary.LittleEndian.Uint32(b), hashBits)
} }
// hash4 returns the hash of u to fit in a hash table with h bits.
// Preferably h should be a constant and should always be <32.
func hash4u(u uint32, h uint8) uint32 {
return (u * prime4bytes) >> (32 - h)
}
// bulkHash4 will compute hashes using the same // bulkHash4 will compute hashes using the same
// algorithm as hash4 // algorithm as hash4
func bulkHash4(b []byte, dst []uint32) { func bulkHash4(b []byte, dst []uint32) {

View File

@ -7,13 +7,13 @@ package flate
// dictDecoder implements the LZ77 sliding dictionary as used in decompression. // dictDecoder implements the LZ77 sliding dictionary as used in decompression.
// LZ77 decompresses data through sequences of two forms of commands: // LZ77 decompresses data through sequences of two forms of commands:
// //
// * Literal insertions: Runs of one or more symbols are inserted into the data // - Literal insertions: Runs of one or more symbols are inserted into the data
// stream as is. This is accomplished through the writeByte method for a // stream as is. This is accomplished through the writeByte method for a
// single symbol, or combinations of writeSlice/writeMark for multiple symbols. // single symbol, or combinations of writeSlice/writeMark for multiple symbols.
// Any valid stream must start with a literal insertion if no preset dictionary // Any valid stream must start with a literal insertion if no preset dictionary
// is used. // is used.
// //
// * Backward copies: Runs of one or more symbols are copied from previously // - Backward copies: Runs of one or more symbols are copied from previously
// emitted data. Backward copies come as the tuple (dist, length) where dist // emitted data. Backward copies come as the tuple (dist, length) where dist
// determines how far back in the stream to copy from and length determines how // determines how far back in the stream to copy from and length determines how
// many bytes to copy. Note that it is valid for the length to be greater than // many bytes to copy. Note that it is valid for the length to be greater than

View File

@ -58,17 +58,6 @@ const (
prime8bytes = 0xcf1bbcdcb7a56463 prime8bytes = 0xcf1bbcdcb7a56463
) )
func load32(b []byte, i int) uint32 {
// Help the compiler eliminate bounds checks on the read so it can be done in a single read.
b = b[i:]
b = b[:4]
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func load64(b []byte, i int) uint64 {
return binary.LittleEndian.Uint64(b[i:])
}
func load3232(b []byte, i int32) uint32 { func load3232(b []byte, i int32) uint32 {
return binary.LittleEndian.Uint32(b[i:]) return binary.LittleEndian.Uint32(b[i:])
} }
@ -77,10 +66,6 @@ func load6432(b []byte, i int32) uint64 {
return binary.LittleEndian.Uint64(b[i:]) return binary.LittleEndian.Uint64(b[i:])
} }
func hash(u uint32) uint32 {
return (u * 0x1e35a7bd) >> tableShift
}
type tableEntry struct { type tableEntry struct {
offset int32 offset int32
} }
@ -104,7 +89,8 @@ func (e *fastGen) addBlock(src []byte) int32 {
} }
// Move down // Move down
offset := int32(len(e.hist)) - maxMatchOffset offset := int32(len(e.hist)) - maxMatchOffset
copy(e.hist[0:maxMatchOffset], e.hist[offset:]) // copy(e.hist[0:maxMatchOffset], e.hist[offset:])
*(*[maxMatchOffset]byte)(e.hist) = *(*[maxMatchOffset]byte)(e.hist[offset:])
e.cur += offset e.cur += offset
e.hist = e.hist[:maxMatchOffset] e.hist = e.hist[:maxMatchOffset]
} }
@ -114,39 +100,36 @@ func (e *fastGen) addBlock(src []byte) int32 {
return s return s
} }
// hash4 returns the hash of u to fit in a hash table with h bits.
// Preferably h should be a constant and should always be <32.
func hash4u(u uint32, h uint8) uint32 {
return (u * prime4bytes) >> (32 - h)
}
type tableEntryPrev struct { type tableEntryPrev struct {
Cur tableEntry Cur tableEntry
Prev tableEntry Prev tableEntry
} }
// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits.
// Preferably h should be a constant and should always be <32.
func hash4x64(u uint64, h uint8) uint32 {
return (uint32(u) * prime4bytes) >> ((32 - h) & reg8SizeMask32)
}
// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits. // hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
// Preferably h should be a constant and should always be <64. // Preferably h should be a constant and should always be <64.
func hash7(u uint64, h uint8) uint32 { func hash7(u uint64, h uint8) uint32 {
return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & reg8SizeMask64)) return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & reg8SizeMask64))
} }
// hash8 returns the hash of u to fit in a hash table with h bits. // hashLen returns a hash of the lowest mls bytes of with length output bits.
// Preferably h should be a constant and should always be <64. // mls must be >=3 and <=8. Any other value will return hash for 4 bytes.
func hash8(u uint64, h uint8) uint32 { // length should always be < 32.
return uint32((u * prime8bytes) >> ((64 - h) & reg8SizeMask64)) // Preferably length and mls should be a constant for inlining.
func hashLen(u uint64, length, mls uint8) uint32 {
switch mls {
case 3:
return (uint32(u<<8) * prime3bytes) >> (32 - length)
case 5:
return uint32(((u << (64 - 40)) * prime5bytes) >> (64 - length))
case 6:
return uint32(((u << (64 - 48)) * prime6bytes) >> (64 - length))
case 7:
return uint32(((u << (64 - 56)) * prime7bytes) >> (64 - length))
case 8:
return uint32((u * prime8bytes) >> (64 - length))
default:
return (uint32(u) * prime4bytes) >> (32 - length)
} }
// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
// Preferably h should be a constant and should always be <64.
func hash6(u uint64, h uint8) uint32 {
return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & reg8SizeMask64))
} }
// matchlen will return the match length between offsets and t in src. // matchlen will return the match length between offsets and t in src.

View File

@ -790,9 +790,11 @@ func (w *huffmanBitWriter) fillTokens() {
// and offsetEncoding. // and offsetEncoding.
// The number of literal and offset tokens is returned. // The number of literal and offset tokens is returned.
func (w *huffmanBitWriter) indexTokens(t *tokens, filled bool) (numLiterals, numOffsets int) { func (w *huffmanBitWriter) indexTokens(t *tokens, filled bool) (numLiterals, numOffsets int) {
copy(w.literalFreq[:], t.litHist[:]) //copy(w.literalFreq[:], t.litHist[:])
copy(w.literalFreq[256:], t.extraHist[:]) *(*[256]uint16)(w.literalFreq[:]) = t.litHist
copy(w.offsetFreq[:], t.offHist[:offsetCodeCount]) //copy(w.literalFreq[256:], t.extraHist[:])
*(*[32]uint16)(w.literalFreq[256:]) = t.extraHist
w.offsetFreq = t.offHist
if t.n == 0 { if t.n == 0 {
return return

View File

@ -168,12 +168,17 @@ func (h *huffmanEncoder) canReuseBits(freq []uint16) int {
// The cases of 0, 1, and 2 literals are handled by special case code. // The cases of 0, 1, and 2 literals are handled by special case code.
// //
// list An array of the literals with non-zero frequencies // list An array of the literals with non-zero frequencies
//
// and their associated frequencies. The array is in order of increasing // and their associated frequencies. The array is in order of increasing
// frequency, and has as its last element a special element with frequency // frequency, and has as its last element a special element with frequency
// MaxInt32 // MaxInt32
//
// maxBits The maximum number of bits that should be used to encode any literal. // maxBits The maximum number of bits that should be used to encode any literal.
//
// Must be less than 16. // Must be less than 16.
//
// return An integer array in which array[i] indicates the number of literals // return An integer array in which array[i] indicates the number of literals
//
// that should be encoded in i bits. // that should be encoded in i bits.
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
if maxBits >= maxBitsLimit { if maxBits >= maxBitsLimit {

View File

@ -19,6 +19,7 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 12 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
hashBytes = 5
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
panic(fmt.Sprint("e.cur < 0: ", e.cur)) panic(fmt.Sprint("e.cur < 0: ", e.cur))
@ -68,7 +69,7 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
sLimit := int32(len(src) - inputMargin) sLimit := int32(len(src) - inputMargin)
// nextEmit is where in src the next emitLiteral should start from. // nextEmit is where in src the next emitLiteral should start from.
cv := load3232(src, s) cv := load6432(src, s)
for { for {
const skipLog = 5 const skipLog = 5
@ -77,7 +78,7 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
nextS := s nextS := s
var candidate tableEntry var candidate tableEntry
for { for {
nextHash := hash(cv) nextHash := hashLen(cv, tableBits, hashBytes)
candidate = e.table[nextHash] candidate = e.table[nextHash]
nextS = s + doEvery + (s-nextEmit)>>skipLog nextS = s + doEvery + (s-nextEmit)>>skipLog
if nextS > sLimit { if nextS > sLimit {
@ -86,16 +87,16 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
now := load6432(src, nextS) now := load6432(src, nextS)
e.table[nextHash] = tableEntry{offset: s + e.cur} e.table[nextHash] = tableEntry{offset: s + e.cur}
nextHash = hash(uint32(now)) nextHash = hashLen(now, tableBits, hashBytes)
offset := s - (candidate.offset - e.cur) offset := s - (candidate.offset - e.cur)
if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
e.table[nextHash] = tableEntry{offset: nextS + e.cur} e.table[nextHash] = tableEntry{offset: nextS + e.cur}
break break
} }
// Do one right away... // Do one right away...
cv = uint32(now) cv = now
s = nextS s = nextS
nextS++ nextS++
candidate = e.table[nextHash] candidate = e.table[nextHash]
@ -103,11 +104,11 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
e.table[nextHash] = tableEntry{offset: s + e.cur} e.table[nextHash] = tableEntry{offset: s + e.cur}
offset = s - (candidate.offset - e.cur) offset = s - (candidate.offset - e.cur)
if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
e.table[nextHash] = tableEntry{offset: nextS + e.cur} e.table[nextHash] = tableEntry{offset: nextS + e.cur}
break break
} }
cv = uint32(now) cv = now
s = nextS s = nextS
} }
@ -198,9 +199,9 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
} }
if s >= sLimit { if s >= sLimit {
// Index first pair after match end. // Index first pair after match end.
if int(s+l+4) < len(src) { if int(s+l+8) < len(src) {
cv := load3232(src, s) cv := load6432(src, s)
e.table[hash(cv)] = tableEntry{offset: s + e.cur} e.table[hashLen(cv, tableBits, hashBytes)] = tableEntry{offset: s + e.cur}
} }
goto emitRemainder goto emitRemainder
} }
@ -213,16 +214,16 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) {
// three load32 calls. // three load32 calls.
x := load6432(src, s-2) x := load6432(src, s-2)
o := e.cur + s - 2 o := e.cur + s - 2
prevHash := hash(uint32(x)) prevHash := hashLen(x, tableBits, hashBytes)
e.table[prevHash] = tableEntry{offset: o} e.table[prevHash] = tableEntry{offset: o}
x >>= 16 x >>= 16
currHash := hash(uint32(x)) currHash := hashLen(x, tableBits, hashBytes)
candidate = e.table[currHash] candidate = e.table[currHash]
e.table[currHash] = tableEntry{offset: o + 2} e.table[currHash] = tableEntry{offset: o + 2}
offset := s - (candidate.offset - e.cur) offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) { if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) {
cv = uint32(x >> 8) cv = x >> 8
s++ s++
break break
} }

View File

@ -16,6 +16,7 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 12 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
hashBytes = 5
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
@ -66,7 +67,7 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
sLimit := int32(len(src) - inputMargin) sLimit := int32(len(src) - inputMargin)
// nextEmit is where in src the next emitLiteral should start from. // nextEmit is where in src the next emitLiteral should start from.
cv := load3232(src, s) cv := load6432(src, s)
for { for {
// When should we start skipping if we haven't found matches in a long while. // When should we start skipping if we haven't found matches in a long while.
const skipLog = 5 const skipLog = 5
@ -75,7 +76,7 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
nextS := s nextS := s
var candidate tableEntry var candidate tableEntry
for { for {
nextHash := hash4u(cv, bTableBits) nextHash := hashLen(cv, bTableBits, hashBytes)
s = nextS s = nextS
nextS = s + doEvery + (s-nextEmit)>>skipLog nextS = s + doEvery + (s-nextEmit)>>skipLog
if nextS > sLimit { if nextS > sLimit {
@ -84,16 +85,16 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
candidate = e.table[nextHash] candidate = e.table[nextHash]
now := load6432(src, nextS) now := load6432(src, nextS)
e.table[nextHash] = tableEntry{offset: s + e.cur} e.table[nextHash] = tableEntry{offset: s + e.cur}
nextHash = hash4u(uint32(now), bTableBits) nextHash = hashLen(now, bTableBits, hashBytes)
offset := s - (candidate.offset - e.cur) offset := s - (candidate.offset - e.cur)
if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
e.table[nextHash] = tableEntry{offset: nextS + e.cur} e.table[nextHash] = tableEntry{offset: nextS + e.cur}
break break
} }
// Do one right away... // Do one right away...
cv = uint32(now) cv = now
s = nextS s = nextS
nextS++ nextS++
candidate = e.table[nextHash] candidate = e.table[nextHash]
@ -101,10 +102,10 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
e.table[nextHash] = tableEntry{offset: s + e.cur} e.table[nextHash] = tableEntry{offset: s + e.cur}
offset = s - (candidate.offset - e.cur) offset = s - (candidate.offset - e.cur)
if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
break break
} }
cv = uint32(now) cv = now
} }
// A 4-byte match has been found. We'll later see if more than 4 bytes // A 4-byte match has been found. We'll later see if more than 4 bytes
@ -154,9 +155,9 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
if s >= sLimit { if s >= sLimit {
// Index first pair after match end. // Index first pair after match end.
if int(s+l+4) < len(src) { if int(s+l+8) < len(src) {
cv := load3232(src, s) cv := load6432(src, s)
e.table[hash4u(cv, bTableBits)] = tableEntry{offset: s + e.cur} e.table[hashLen(cv, bTableBits, hashBytes)] = tableEntry{offset: s + e.cur}
} }
goto emitRemainder goto emitRemainder
} }
@ -164,15 +165,15 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
// Store every second hash in-between, but offset by 1. // Store every second hash in-between, but offset by 1.
for i := s - l + 2; i < s-5; i += 7 { for i := s - l + 2; i < s-5; i += 7 {
x := load6432(src, i) x := load6432(src, i)
nextHash := hash4u(uint32(x), bTableBits) nextHash := hashLen(x, bTableBits, hashBytes)
e.table[nextHash] = tableEntry{offset: e.cur + i} e.table[nextHash] = tableEntry{offset: e.cur + i}
// Skip one // Skip one
x >>= 16 x >>= 16
nextHash = hash4u(uint32(x), bTableBits) nextHash = hashLen(x, bTableBits, hashBytes)
e.table[nextHash] = tableEntry{offset: e.cur + i + 2} e.table[nextHash] = tableEntry{offset: e.cur + i + 2}
// Skip one // Skip one
x >>= 16 x >>= 16
nextHash = hash4u(uint32(x), bTableBits) nextHash = hashLen(x, bTableBits, hashBytes)
e.table[nextHash] = tableEntry{offset: e.cur + i + 4} e.table[nextHash] = tableEntry{offset: e.cur + i + 4}
} }
@ -184,17 +185,17 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) {
// three load32 calls. // three load32 calls.
x := load6432(src, s-2) x := load6432(src, s-2)
o := e.cur + s - 2 o := e.cur + s - 2
prevHash := hash4u(uint32(x), bTableBits) prevHash := hashLen(x, bTableBits, hashBytes)
prevHash2 := hash4u(uint32(x>>8), bTableBits) prevHash2 := hashLen(x>>8, bTableBits, hashBytes)
e.table[prevHash] = tableEntry{offset: o} e.table[prevHash] = tableEntry{offset: o}
e.table[prevHash2] = tableEntry{offset: o + 1} e.table[prevHash2] = tableEntry{offset: o + 1}
currHash := hash4u(uint32(x>>16), bTableBits) currHash := hashLen(x>>16, bTableBits, hashBytes)
candidate = e.table[currHash] candidate = e.table[currHash]
e.table[currHash] = tableEntry{offset: o + 2} e.table[currHash] = tableEntry{offset: o + 2}
offset := s - (candidate.offset - e.cur) offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || uint32(x>>16) != load3232(src, candidate.offset-e.cur) { if offset > maxMatchOffset || uint32(x>>16) != load3232(src, candidate.offset-e.cur) {
cv = uint32(x >> 24) cv = x >> 24
s++ s++
break break
} }

View File

@ -11,10 +11,11 @@ type fastEncL3 struct {
// Encode uses a similar algorithm to level 2, will check up to two candidates. // Encode uses a similar algorithm to level 2, will check up to two candidates.
func (e *fastEncL3) Encode(dst *tokens, src []byte) { func (e *fastEncL3) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 8 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
tableBits = 16 tableBits = 16
tableSize = 1 << tableBits tableSize = 1 << tableBits
hashBytes = 5
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
@ -69,20 +70,20 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
sLimit := int32(len(src) - inputMargin) sLimit := int32(len(src) - inputMargin)
// nextEmit is where in src the next emitLiteral should start from. // nextEmit is where in src the next emitLiteral should start from.
cv := load3232(src, s) cv := load6432(src, s)
for { for {
const skipLog = 6 const skipLog = 7
nextS := s nextS := s
var candidate tableEntry var candidate tableEntry
for { for {
nextHash := hash4u(cv, tableBits) nextHash := hashLen(cv, tableBits, hashBytes)
s = nextS s = nextS
nextS = s + 1 + (s-nextEmit)>>skipLog nextS = s + 1 + (s-nextEmit)>>skipLog
if nextS > sLimit { if nextS > sLimit {
goto emitRemainder goto emitRemainder
} }
candidates := e.table[nextHash] candidates := e.table[nextHash]
now := load3232(src, nextS) now := load6432(src, nextS)
// Safe offset distance until s + 4... // Safe offset distance until s + 4...
minOffset := e.cur + s - (maxMatchOffset - 4) minOffset := e.cur + s - (maxMatchOffset - 4)
@ -96,8 +97,8 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
continue continue
} }
if cv == load3232(src, candidate.offset-e.cur) { if uint32(cv) == load3232(src, candidate.offset-e.cur) {
if candidates.Prev.offset < minOffset || cv != load3232(src, candidates.Prev.offset-e.cur) { if candidates.Prev.offset < minOffset || uint32(cv) != load3232(src, candidates.Prev.offset-e.cur) {
break break
} }
// Both match and are valid, pick longest. // Both match and are valid, pick longest.
@ -112,7 +113,7 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
// We only check if value mismatches. // We only check if value mismatches.
// Offset will always be invalid in other cases. // Offset will always be invalid in other cases.
candidate = candidates.Prev candidate = candidates.Prev
if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { if candidate.offset > minOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
break break
} }
} }
@ -164,9 +165,9 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
if s >= sLimit { if s >= sLimit {
t += l t += l
// Index first pair after match end. // Index first pair after match end.
if int(t+4) < len(src) && t > 0 { if int(t+8) < len(src) && t > 0 {
cv := load3232(src, t) cv = load6432(src, t)
nextHash := hash4u(cv, tableBits) nextHash := hashLen(cv, tableBits, hashBytes)
e.table[nextHash] = tableEntryPrev{ e.table[nextHash] = tableEntryPrev{
Prev: e.table[nextHash].Cur, Prev: e.table[nextHash].Cur,
Cur: tableEntry{offset: e.cur + t}, Cur: tableEntry{offset: e.cur + t},
@ -176,8 +177,8 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
} }
// Store every 5th hash in-between. // Store every 5th hash in-between.
for i := s - l + 2; i < s-5; i += 5 { for i := s - l + 2; i < s-5; i += 6 {
nextHash := hash4u(load3232(src, i), tableBits) nextHash := hashLen(load6432(src, i), tableBits, hashBytes)
e.table[nextHash] = tableEntryPrev{ e.table[nextHash] = tableEntryPrev{
Prev: e.table[nextHash].Cur, Prev: e.table[nextHash].Cur,
Cur: tableEntry{offset: e.cur + i}} Cur: tableEntry{offset: e.cur + i}}
@ -185,23 +186,23 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
// We could immediately start working at s now, but to improve // We could immediately start working at s now, but to improve
// compression we first update the hash table at s-2 to s. // compression we first update the hash table at s-2 to s.
x := load6432(src, s-2) x := load6432(src, s-2)
prevHash := hash4u(uint32(x), tableBits) prevHash := hashLen(x, tableBits, hashBytes)
e.table[prevHash] = tableEntryPrev{ e.table[prevHash] = tableEntryPrev{
Prev: e.table[prevHash].Cur, Prev: e.table[prevHash].Cur,
Cur: tableEntry{offset: e.cur + s - 2}, Cur: tableEntry{offset: e.cur + s - 2},
} }
x >>= 8 x >>= 8
prevHash = hash4u(uint32(x), tableBits) prevHash = hashLen(x, tableBits, hashBytes)
e.table[prevHash] = tableEntryPrev{ e.table[prevHash] = tableEntryPrev{
Prev: e.table[prevHash].Cur, Prev: e.table[prevHash].Cur,
Cur: tableEntry{offset: e.cur + s - 1}, Cur: tableEntry{offset: e.cur + s - 1},
} }
x >>= 8 x >>= 8
currHash := hash4u(uint32(x), tableBits) currHash := hashLen(x, tableBits, hashBytes)
candidates := e.table[currHash] candidates := e.table[currHash]
cv = uint32(x) cv = x
e.table[currHash] = tableEntryPrev{ e.table[currHash] = tableEntryPrev{
Prev: candidates.Cur, Prev: candidates.Cur,
Cur: tableEntry{offset: s + e.cur}, Cur: tableEntry{offset: s + e.cur},
@ -212,17 +213,17 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) {
minOffset := e.cur + s - (maxMatchOffset - 4) minOffset := e.cur + s - (maxMatchOffset - 4)
if candidate.offset > minOffset { if candidate.offset > minOffset {
if cv == load3232(src, candidate.offset-e.cur) { if uint32(cv) == load3232(src, candidate.offset-e.cur) {
// Found a match... // Found a match...
continue continue
} }
candidate = candidates.Prev candidate = candidates.Prev
if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { if candidate.offset > minOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
// Match at prev... // Match at prev...
continue continue
} }
} }
cv = uint32(x >> 8) cv = x >> 8
s++ s++
break break
} }

View File

@ -12,6 +12,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 12 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
hashShortBytes = 4
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
panic(fmt.Sprint("e.cur < 0: ", e.cur)) panic(fmt.Sprint("e.cur < 0: ", e.cur))
@ -80,7 +81,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
nextS := s nextS := s
var t int32 var t int32
for { for {
nextHashS := hash4x64(cv, tableBits) nextHashS := hashLen(cv, tableBits, hashShortBytes)
nextHashL := hash7(cv, tableBits) nextHashL := hash7(cv, tableBits)
s = nextS s = nextS
@ -168,7 +169,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
// Index first pair after match end. // Index first pair after match end.
if int(s+8) < len(src) { if int(s+8) < len(src) {
cv := load6432(src, s) cv := load6432(src, s)
e.table[hash4x64(cv, tableBits)] = tableEntry{offset: s + e.cur} e.table[hashLen(cv, tableBits, hashShortBytes)] = tableEntry{offset: s + e.cur}
e.bTable[hash7(cv, tableBits)] = tableEntry{offset: s + e.cur} e.bTable[hash7(cv, tableBits)] = tableEntry{offset: s + e.cur}
} }
goto emitRemainder goto emitRemainder
@ -183,7 +184,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
t2 := tableEntry{offset: t.offset + 1} t2 := tableEntry{offset: t.offset + 1}
e.bTable[hash7(cv, tableBits)] = t e.bTable[hash7(cv, tableBits)] = t
e.bTable[hash7(cv>>8, tableBits)] = t2 e.bTable[hash7(cv>>8, tableBits)] = t2
e.table[hash4u(uint32(cv>>8), tableBits)] = t2 e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
i += 3 i += 3
for ; i < s-1; i += 3 { for ; i < s-1; i += 3 {
@ -192,7 +193,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
t2 := tableEntry{offset: t.offset + 1} t2 := tableEntry{offset: t.offset + 1}
e.bTable[hash7(cv, tableBits)] = t e.bTable[hash7(cv, tableBits)] = t
e.bTable[hash7(cv>>8, tableBits)] = t2 e.bTable[hash7(cv>>8, tableBits)] = t2
e.table[hash4u(uint32(cv>>8), tableBits)] = t2 e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
} }
} }
} }
@ -201,7 +202,7 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) {
// compression we first update the hash table at s-1 and at s. // compression we first update the hash table at s-1 and at s.
x := load6432(src, s-1) x := load6432(src, s-1)
o := e.cur + s - 1 o := e.cur + s - 1
prevHashS := hash4x64(x, tableBits) prevHashS := hashLen(x, tableBits, hashShortBytes)
prevHashL := hash7(x, tableBits) prevHashL := hash7(x, tableBits)
e.table[prevHashS] = tableEntry{offset: o} e.table[prevHashS] = tableEntry{offset: o}
e.bTable[prevHashL] = tableEntry{offset: o} e.bTable[prevHashL] = tableEntry{offset: o}

View File

@ -12,6 +12,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 12 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
hashShortBytes = 4
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
panic(fmt.Sprint("e.cur < 0: ", e.cur)) panic(fmt.Sprint("e.cur < 0: ", e.cur))
@ -88,7 +89,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
var l int32 var l int32
var t int32 var t int32
for { for {
nextHashS := hash4x64(cv, tableBits) nextHashS := hashLen(cv, tableBits, hashShortBytes)
nextHashL := hash7(cv, tableBits) nextHashL := hash7(cv, tableBits)
s = nextS s = nextS
@ -105,7 +106,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
eLong := &e.bTable[nextHashL] eLong := &e.bTable[nextHashL]
eLong.Cur, eLong.Prev = entry, eLong.Cur eLong.Cur, eLong.Prev = entry, eLong.Cur
nextHashS = hash4x64(next, tableBits) nextHashS = hashLen(next, tableBits, hashShortBytes)
nextHashL = hash7(next, tableBits) nextHashL = hash7(next, tableBits)
t = lCandidate.Cur.offset - e.cur t = lCandidate.Cur.offset - e.cur
@ -191,14 +192,21 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
// Try to locate a better match by checking the end of best match... // Try to locate a better match by checking the end of best match...
if sAt := s + l; l < 30 && sAt < sLimit { if sAt := s + l; l < 30 && sAt < sLimit {
// Allow some bytes at the beginning to mismatch.
// Sweet spot is 2/3 bytes depending on input.
// 3 is only a little better when it is but sometimes a lot worse.
// The skipped bytes are tested in Extend backwards,
// and still picked up as part of the match if they do.
const skipBeginning = 2
eLong := e.bTable[hash7(load6432(src, sAt), tableBits)].Cur.offset eLong := e.bTable[hash7(load6432(src, sAt), tableBits)].Cur.offset
// Test current t2 := eLong - e.cur - l + skipBeginning
t2 := eLong - e.cur - l s2 := s + skipBeginning
off := s - t2 off := s2 - t2
if t2 >= 0 && off < maxMatchOffset && off > 0 { if t2 >= 0 && off < maxMatchOffset && off > 0 {
if l2 := e.matchlenLong(s, t2, src); l2 > l { if l2 := e.matchlenLong(s2, t2, src); l2 > l {
t = t2 t = t2
l = l2 l = l2
s = s2
} }
} }
} }
@ -250,7 +258,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
if i < s-1 { if i < s-1 {
cv := load6432(src, i) cv := load6432(src, i)
t := tableEntry{offset: i + e.cur} t := tableEntry{offset: i + e.cur}
e.table[hash4x64(cv, tableBits)] = t e.table[hashLen(cv, tableBits, hashShortBytes)] = t
eLong := &e.bTable[hash7(cv, tableBits)] eLong := &e.bTable[hash7(cv, tableBits)]
eLong.Cur, eLong.Prev = t, eLong.Cur eLong.Cur, eLong.Prev = t, eLong.Cur
@ -263,7 +271,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
// We only have enough bits for a short entry at i+2 // We only have enough bits for a short entry at i+2
cv >>= 8 cv >>= 8
t = tableEntry{offset: t.offset + 1} t = tableEntry{offset: t.offset + 1}
e.table[hash4x64(cv, tableBits)] = t e.table[hashLen(cv, tableBits, hashShortBytes)] = t
// Skip one - otherwise we risk hitting 's' // Skip one - otherwise we risk hitting 's'
i += 4 i += 4
@ -273,7 +281,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
t2 := tableEntry{offset: t.offset + 1} t2 := tableEntry{offset: t.offset + 1}
eLong := &e.bTable[hash7(cv, tableBits)] eLong := &e.bTable[hash7(cv, tableBits)]
eLong.Cur, eLong.Prev = t, eLong.Cur eLong.Cur, eLong.Prev = t, eLong.Cur
e.table[hash4u(uint32(cv>>8), tableBits)] = t2 e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
} }
} }
} }
@ -282,7 +290,7 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) {
// compression we first update the hash table at s-1 and at s. // compression we first update the hash table at s-1 and at s.
x := load6432(src, s-1) x := load6432(src, s-1)
o := e.cur + s - 1 o := e.cur + s - 1
prevHashS := hash4x64(x, tableBits) prevHashS := hashLen(x, tableBits, hashShortBytes)
prevHashL := hash7(x, tableBits) prevHashL := hash7(x, tableBits)
e.table[prevHashS] = tableEntry{offset: o} e.table[prevHashS] = tableEntry{offset: o}
eLong := &e.bTable[prevHashL] eLong := &e.bTable[prevHashL]

View File

@ -12,6 +12,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
const ( const (
inputMargin = 12 - 1 inputMargin = 12 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin minNonLiteralBlockSize = 1 + 1 + inputMargin
hashShortBytes = 4
) )
if debugDeflate && e.cur < 0 { if debugDeflate && e.cur < 0 {
panic(fmt.Sprint("e.cur < 0: ", e.cur)) panic(fmt.Sprint("e.cur < 0: ", e.cur))
@ -90,7 +91,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
var l int32 var l int32
var t int32 var t int32
for { for {
nextHashS := hash4x64(cv, tableBits) nextHashS := hashLen(cv, tableBits, hashShortBytes)
nextHashL := hash7(cv, tableBits) nextHashL := hash7(cv, tableBits)
s = nextS s = nextS
nextS = s + doEvery + (s-nextEmit)>>skipLog nextS = s + doEvery + (s-nextEmit)>>skipLog
@ -107,7 +108,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
eLong.Cur, eLong.Prev = entry, eLong.Cur eLong.Cur, eLong.Prev = entry, eLong.Cur
// Calculate hashes of 'next' // Calculate hashes of 'next'
nextHashS = hash4x64(next, tableBits) nextHashS = hashLen(next, tableBits, hashShortBytes)
nextHashL = hash7(next, tableBits) nextHashL = hash7(next, tableBits)
t = lCandidate.Cur.offset - e.cur t = lCandidate.Cur.offset - e.cur
@ -213,24 +214,33 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
// Try to locate a better match by checking the end-of-match... // Try to locate a better match by checking the end-of-match...
if sAt := s + l; sAt < sLimit { if sAt := s + l; sAt < sLimit {
// Allow some bytes at the beginning to mismatch.
// Sweet spot is 2/3 bytes depending on input.
// 3 is only a little better when it is but sometimes a lot worse.
// The skipped bytes are tested in Extend backwards,
// and still picked up as part of the match if they do.
const skipBeginning = 2
eLong := &e.bTable[hash7(load6432(src, sAt), tableBits)] eLong := &e.bTable[hash7(load6432(src, sAt), tableBits)]
// Test current // Test current
t2 := eLong.Cur.offset - e.cur - l t2 := eLong.Cur.offset - e.cur - l + skipBeginning
off := s - t2 s2 := s + skipBeginning
off := s2 - t2
if off < maxMatchOffset { if off < maxMatchOffset {
if off > 0 && t2 >= 0 { if off > 0 && t2 >= 0 {
if l2 := e.matchlenLong(s, t2, src); l2 > l { if l2 := e.matchlenLong(s2, t2, src); l2 > l {
t = t2 t = t2
l = l2 l = l2
s = s2
} }
} }
// Test next: // Test next:
t2 = eLong.Prev.offset - e.cur - l t2 = eLong.Prev.offset - e.cur - l + skipBeginning
off := s - t2 off := s2 - t2
if off > 0 && off < maxMatchOffset && t2 >= 0 { if off > 0 && off < maxMatchOffset && t2 >= 0 {
if l2 := e.matchlenLong(s, t2, src); l2 > l { if l2 := e.matchlenLong(s2, t2, src); l2 > l {
t = t2 t = t2
l = l2 l = l2
s = s2
} }
} }
} }
@ -277,7 +287,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
// Index after match end. // Index after match end.
for i := nextS + 1; i < int32(len(src))-8; i += 2 { for i := nextS + 1; i < int32(len(src))-8; i += 2 {
cv := load6432(src, i) cv := load6432(src, i)
e.table[hash4x64(cv, tableBits)] = tableEntry{offset: i + e.cur} e.table[hashLen(cv, tableBits, hashShortBytes)] = tableEntry{offset: i + e.cur}
eLong := &e.bTable[hash7(cv, tableBits)] eLong := &e.bTable[hash7(cv, tableBits)]
eLong.Cur, eLong.Prev = tableEntry{offset: i + e.cur}, eLong.Cur eLong.Cur, eLong.Prev = tableEntry{offset: i + e.cur}, eLong.Cur
} }
@ -292,7 +302,7 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) {
t2 := tableEntry{offset: t.offset + 1} t2 := tableEntry{offset: t.offset + 1}
eLong := &e.bTable[hash7(cv, tableBits)] eLong := &e.bTable[hash7(cv, tableBits)]
eLong2 := &e.bTable[hash7(cv>>8, tableBits)] eLong2 := &e.bTable[hash7(cv>>8, tableBits)]
e.table[hash4x64(cv, tableBits)] = t e.table[hashLen(cv, tableBits, hashShortBytes)] = t
eLong.Cur, eLong.Prev = t, eLong.Cur eLong.Cur, eLong.Prev = t, eLong.Cur
eLong2.Cur, eLong2.Prev = t2, eLong2.Cur eLong2.Cur, eLong2.Prev = t2, eLong2.Cur
} }

View File

@ -763,17 +763,20 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {
@ -997,17 +1000,22 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
// copy(out[dstEvery*3:], buf[3][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {

View File

@ -14,12 +14,14 @@ import (
// decompress4x_main_loop_x86 is an x86 assembler implementation // decompress4x_main_loop_x86 is an x86 assembler implementation
// of Decompress4X when tablelog > 8. // of Decompress4X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress4x_main_loop_amd64(ctx *decompress4xContext) func decompress4x_main_loop_amd64(ctx *decompress4xContext)
// decompress4x_8b_loop_x86 is an x86 assembler implementation // decompress4x_8b_loop_x86 is an x86 assembler implementation
// of Decompress4X when tablelog <= 8 which decodes 4 entries // of Decompress4X when tablelog <= 8 which decodes 4 entries
// per loop. // per loop.
//
//go:noescape //go:noescape
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
@ -145,11 +147,13 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
// decompress4x_main_loop_x86 is an x86 assembler implementation // decompress4x_main_loop_x86 is an x86 assembler implementation
// of Decompress1X when tablelog > 8. // of Decompress1X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress1x_main_loop_amd64(ctx *decompress1xContext) func decompress1x_main_loop_amd64(ctx *decompress1xContext)
// decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation // decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation
// of Decompress1X when tablelog > 8. // of Decompress1X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress1x_main_loop_bmi2(ctx *decompress1xContext) func decompress1x_main_loop_bmi2(ctx *decompress1xContext)

View File

@ -1,7 +1,6 @@
// Code generated by command: go run gen.go -out ../decompress_amd64.s -pkg=huff0. DO NOT EDIT. // Code generated by command: go run gen.go -out ../decompress_amd64.s -pkg=huff0. DO NOT EDIT.
//go:build amd64 && !appengine && !noasm && gc //go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc
// func decompress4x_main_loop_amd64(ctx *decompress4xContext) // func decompress4x_main_loop_amd64(ctx *decompress4xContext)
TEXT ·decompress4x_main_loop_amd64(SB), $0-8 TEXT ·decompress4x_main_loop_amd64(SB), $0-8

View File

@ -122,17 +122,21 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
//copy(out[dstEvery*3:], buf[3][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {

View File

@ -18,6 +18,7 @@ func load64(b []byte, i int) uint64 {
// emitLiteral writes a literal chunk and returns the number of bytes written. // emitLiteral writes a literal chunk and returns the number of bytes written.
// //
// It assumes that: // It assumes that:
//
// dst is long enough to hold the encoded bytes // dst is long enough to hold the encoded bytes
// 1 <= len(lit) && len(lit) <= 65536 // 1 <= len(lit) && len(lit) <= 65536
func emitLiteral(dst, lit []byte) int { func emitLiteral(dst, lit []byte) int {
@ -42,6 +43,7 @@ func emitLiteral(dst, lit []byte) int {
// emitCopy writes a copy chunk and returns the number of bytes written. // emitCopy writes a copy chunk and returns the number of bytes written.
// //
// It assumes that: // It assumes that:
//
// dst is long enough to hold the encoded bytes // dst is long enough to hold the encoded bytes
// 1 <= offset && offset <= 65535 // 1 <= offset && offset <= 65535
// 4 <= length && length <= 65535 // 4 <= length && length <= 65535
@ -89,6 +91,7 @@ func emitCopy(dst []byte, offset, length int) int {
// src[i:i+k-j] and src[j:k] have the same contents. // src[i:i+k-j] and src[j:k] have the same contents.
// //
// It assumes that: // It assumes that:
//
// 0 <= i && i < j && j <= len(src) // 0 <= i && i < j && j <= len(src)
func extendMatch(src []byte, i, j int) int { func extendMatch(src []byte, i, j int) int {
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
@ -105,6 +108,7 @@ func hash(u, shift uint32) uint32 {
// been written. // been written.
// //
// It also assumes that: // It also assumes that:
//
// len(dst) >= MaxEncodedLen(len(src)) && // len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize // minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlock(dst, src []byte) (d int) { func encodeBlock(dst, src []byte) (d int) {

View File

@ -12,6 +12,8 @@ The `zstd` package is provided as open source software using a Go standard licen
Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors. Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors.
For seekable zstd streams, see [this excellent package](https://github.com/SaveTheRbtz/zstd-seekable-format-go).
## Installation ## Installation
Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`. Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`.

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -651,7 +650,7 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse)) fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse)) fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
buf.Write(in) buf.Write(in)
ioutil.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm) os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
} }
return nil return nil

View File

@ -7,7 +7,6 @@ package zstd
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
) )
type byteBuffer interface { type byteBuffer interface {
@ -124,7 +123,7 @@ func (r *readerWrapper) readByte() (byte, error) {
} }
func (r *readerWrapper) skipN(n int64) error { func (r *readerWrapper) skipN(n int64) error {
n2, err := io.CopyN(ioutil.Discard, r.r, n) n2, err := io.CopyN(io.Discard, r.r, n)
if n2 != n { if n2 != n {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }

View File

@ -35,6 +35,7 @@ type Decoder struct {
br readerWrapper br readerWrapper
enabled bool enabled bool
inFrame bool inFrame bool
dstBuf []byte
} }
frame *frameDec frame *frameDec
@ -187,21 +188,23 @@ func (d *Decoder) Reset(r io.Reader) error {
} }
// If bytes buffer and < 5MB, do sync decoding anyway. // If bytes buffer and < 5MB, do sync decoding anyway.
if bb, ok := r.(byter); ok && bb.Len() < 5<<20 { if bb, ok := r.(byter); ok && bb.Len() < d.o.decodeBufsBelow && !d.o.limitToCap {
bb2 := bb bb2 := bb
if debugDecoder { if debugDecoder {
println("*bytes.Buffer detected, doing sync decode, len:", bb.Len()) println("*bytes.Buffer detected, doing sync decode, len:", bb.Len())
} }
b := bb2.Bytes() b := bb2.Bytes()
var dst []byte var dst []byte
if cap(d.current.b) > 0 { if cap(d.syncStream.dstBuf) > 0 {
dst = d.current.b dst = d.syncStream.dstBuf[:0]
} }
dst, err := d.DecodeAll(b, dst[:0]) dst, err := d.DecodeAll(b, dst)
if err == nil { if err == nil {
err = io.EOF err = io.EOF
} }
// Save output buffer
d.syncStream.dstBuf = dst
d.current.b = dst d.current.b = dst
d.current.err = err d.current.err = err
d.current.flushed = true d.current.flushed = true
@ -216,6 +219,7 @@ func (d *Decoder) Reset(r io.Reader) error {
d.current.err = nil d.current.err = nil
d.current.flushed = false d.current.flushed = false
d.current.d = nil d.current.d = nil
d.syncStream.dstBuf = nil
// Ensure no-one else is still running... // Ensure no-one else is still running...
d.streamWg.Wait() d.streamWg.Wait()
@ -312,6 +316,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
// Grab a block decoder and frame decoder. // Grab a block decoder and frame decoder.
block := <-d.decoders block := <-d.decoders
frame := block.localFrame frame := block.localFrame
initialSize := len(dst)
defer func() { defer func() {
if debugDecoder { if debugDecoder {
printf("re-adding decoder: %p", block) printf("re-adding decoder: %p", block)
@ -354,7 +359,16 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
return dst, ErrWindowSizeExceeded return dst, ErrWindowSizeExceeded
} }
if frame.FrameContentSize != fcsUnknown { if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
}
return dst, ErrDecoderSizeExceeded return dst, ErrDecoderSizeExceeded
} }
if cap(dst)-len(dst) < int(frame.FrameContentSize) { if cap(dst)-len(dst) < int(frame.FrameContentSize) {
@ -364,7 +378,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
} }
} }
if cap(dst) == 0 { if cap(dst) == 0 && !d.o.limitToCap {
// Allocate len(input) * 2 by default if nothing is provided // Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size. // and we didn't get frame content size.
size := len(input) * 2 size := len(input) * 2
@ -382,6 +396,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
if err != nil { if err != nil {
return dst, err return dst, err
} }
if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
return dst, ErrDecoderSizeExceeded
}
if len(frame.bBuf) == 0 { if len(frame.bBuf) == 0 {
if debugDecoder { if debugDecoder {
println("frame dbuf empty") println("frame dbuf empty")
@ -667,6 +684,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder { if debugDecoder {
println("Async 1: new history, recent:", block.async.newHist.recentOffsets) println("Async 1: new history, recent:", block.async.newHist.recentOffsets)
} }
hist.reset()
hist.decoders = block.async.newHist.decoders hist.decoders = block.async.newHist.decoders
hist.recentOffsets = block.async.newHist.recentOffsets hist.recentOffsets = block.async.newHist.recentOffsets
hist.windowSize = block.async.newHist.windowSize hist.windowSize = block.async.newHist.windowSize
@ -698,6 +716,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
seqExecute <- block seqExecute <- block
} }
close(seqExecute) close(seqExecute)
hist.reset()
}() }()
var wg sync.WaitGroup var wg sync.WaitGroup
@ -721,6 +740,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder { if debugDecoder {
println("Async 2: new history") println("Async 2: new history")
} }
hist.reset()
hist.windowSize = block.async.newHist.windowSize hist.windowSize = block.async.newHist.windowSize
hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer
if block.async.newHist.dict != nil { if block.async.newHist.dict != nil {
@ -802,13 +822,14 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
if debugDecoder { if debugDecoder {
println("decoder goroutines finished") println("decoder goroutines finished")
} }
hist.reset()
}() }()
var hist history
decodeStream: decodeStream:
for { for {
var hist history
var hasErr bool var hasErr bool
hist.reset()
decodeBlock := func(block *blockDec) { decodeBlock := func(block *blockDec) {
if hasErr { if hasErr {
if block != nil { if block != nil {
@ -852,6 +873,10 @@ decodeStream:
} }
} }
if err == nil && d.frame.WindowSize > d.o.maxWindowSize { if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
}
err = ErrDecoderSizeExceeded err = ErrDecoderSizeExceeded
} }
if err != nil { if err != nil {
@ -920,5 +945,6 @@ decodeStream:
} }
close(seqDecode) close(seqDecode)
wg.Wait() wg.Wait()
hist.reset()
d.frame.history.b = frameHistCache d.frame.history.b = frameHistCache
} }

View File

@ -20,6 +20,8 @@ type decoderOptions struct {
maxWindowSize uint64 maxWindowSize uint64
dicts []dict dicts []dict
ignoreChecksum bool ignoreChecksum bool
limitToCap bool
decodeBufsBelow int
} }
func (o *decoderOptions) setDefault() { func (o *decoderOptions) setDefault() {
@ -28,6 +30,7 @@ func (o *decoderOptions) setDefault() {
lowMem: true, lowMem: true,
concurrent: runtime.GOMAXPROCS(0), concurrent: runtime.GOMAXPROCS(0),
maxWindowSize: MaxWindowSize, maxWindowSize: MaxWindowSize,
decodeBufsBelow: 128 << 10,
} }
if o.concurrent > 4 { if o.concurrent > 4 {
o.concurrent = 4 o.concurrent = 4
@ -114,6 +117,29 @@ func WithDecoderMaxWindow(size uint64) DOption {
} }
} }
// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
return func(o *decoderOptions) error {
o.limitToCap = b
return nil
}
}
// WithDecodeBuffersBelow will fully decode readers that have a
// `Bytes() []byte` and `Len() int` interface similar to bytes.Buffer.
// This typically uses less allocations but will have the full decompressed object in memory.
// Note that DecodeAllCapLimit will disable this, as well as giving a size of 0 or less.
// Default is 128KiB.
func WithDecodeBuffersBelow(size int) DOption {
return func(o *decoderOptions) error {
o.decodeBufsBelow = size
return nil
}
}
// IgnoreChecksum allows to forcibly ignore checksum checking. // IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption { func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error { return func(o *decoderOptions) error {

View File

@ -32,6 +32,7 @@ type match struct {
length int32 length int32
rep int32 rep int32
est int32 est int32
_ [12]byte // Aligned size to cache line: 4+4+4+4+4 bytes + 12 bytes padding = 32 bytes
} }
const highScore = 25000 const highScore = 25000

View File

@ -416,15 +416,23 @@ encodeLoop:
// Try to find a better match by searching for a long match at the end of the current best match // Try to find a better match by searching for a long match at the end of the current best match
if s+matched < sLimit { if s+matched < sLimit {
// Allow some bytes at the beginning to mismatch.
// Sweet spot is around 3 bytes, but depends on input.
// The skipped bytes are tested in Extend backwards,
// and still picked up as part of the match if they do.
const skipBeginning = 3
nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen) nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen)
cv := load3232(src, s) s2 := s + skipBeginning
cv := load3232(src, s2)
candidateL := e.longTable[nextHashL] candidateL := e.longTable[nextHashL]
coffsetL := candidateL.offset - e.cur - matched coffsetL := candidateL.offset - e.cur - matched + skipBeginning
if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
// Found a long match, at least 4 bytes. // Found a long match, at least 4 bytes.
matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
if matchedNext > matched { if matchedNext > matched {
t = coffsetL t = coffsetL
s = s2
matched = matchedNext matched = matchedNext
if debugMatches { if debugMatches {
println("long match at end-of-match") println("long match at end-of-match")
@ -434,12 +442,13 @@ encodeLoop:
// Check prev long... // Check prev long...
if true { if true {
coffsetL = candidateL.prev - e.cur - matched coffsetL = candidateL.prev - e.cur - matched + skipBeginning
if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
// Found a long match, at least 4 bytes. // Found a long match, at least 4 bytes.
matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
if matchedNext > matched { if matchedNext > matched {
t = coffsetL t = coffsetL
s = s2
matched = matchedNext matched = matchedNext
if debugMatches { if debugMatches {
println("prev long match at end-of-match") println("prev long match at end-of-match")

View File

@ -1103,7 +1103,8 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
} }
if allDirty || dirtyShardCnt > dLongTableShardCnt/2 { if allDirty || dirtyShardCnt > dLongTableShardCnt/2 {
copy(e.longTable[:], e.dictLongTable) //copy(e.longTable[:], e.dictLongTable)
e.longTable = *(*[dFastLongTableSize]tableEntry)(e.dictLongTable)
for i := range e.longTableShardDirty { for i := range e.longTableShardDirty {
e.longTableShardDirty[i] = false e.longTableShardDirty[i] = false
} }
@ -1114,7 +1115,9 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
continue continue
} }
copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize]) // copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize])
*(*[dLongTableShardSize]tableEntry)(e.longTable[i*dLongTableShardSize:]) = *(*[dLongTableShardSize]tableEntry)(e.dictLongTable[i*dLongTableShardSize:])
e.longTableShardDirty[i] = false e.longTableShardDirty[i] = false
} }
} }

View File

@ -871,7 +871,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
const shardCnt = tableShardCnt const shardCnt = tableShardCnt
const shardSize = tableShardSize const shardSize = tableShardSize
if e.allDirty || dirtyShardCnt > shardCnt*4/6 { if e.allDirty || dirtyShardCnt > shardCnt*4/6 {
copy(e.table[:], e.dictTable) //copy(e.table[:], e.dictTable)
e.table = *(*[tableSize]tableEntry)(e.dictTable)
for i := range e.tableShardDirty { for i := range e.tableShardDirty {
e.tableShardDirty[i] = false e.tableShardDirty[i] = false
} }
@ -883,7 +884,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
continue continue
} }
copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize]) //copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize])
*(*[shardSize]tableEntry)(e.table[i*shardSize:]) = *(*[shardSize]tableEntry)(e.dictTable[i*shardSize:])
e.tableShardDirty[i] = false e.tableShardDirty[i] = false
} }
e.allDirty = false e.allDirty = false

View File

@ -343,7 +343,7 @@ func (d *frameDec) consumeCRC() error {
return nil return nil
} }
// runDecoder will create a sync decoder that will decode a block of data. // runDecoder will run the decoder for the remainder of the frame.
func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
saved := d.history.b saved := d.history.b
@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
// Store input length, so we only check new data. // Store input length, so we only check new data.
crcStart := len(dst) crcStart := len(dst)
d.history.decoders.maxSyncLen = 0 d.history.decoders.maxSyncLen = 0
if d.o.limitToCap {
d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
}
if d.FrameContentSize != fcsUnknown { if d.FrameContentSize != fcsUnknown {
if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
}
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
}
return dst, ErrDecoderSizeExceeded return dst, ErrDecoderSizeExceeded
} }
if uint64(cap(dst)) < d.history.decoders.maxSyncLen { if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output // Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst) copy(dst2, dst)
@ -378,7 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if err != nil { if err != nil {
break break
} }
if uint64(len(d.history.b)) > d.o.maxDecodedSize { if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
err = ErrDecoderSizeExceeded
break
}
if d.o.limitToCap && len(d.history.b) > cap(dst) {
println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
err = ErrDecoderSizeExceeded err = ErrDecoderSizeExceeded
break break
} }

View File

@ -21,6 +21,7 @@ type buildDtableAsmContext struct {
// buildDtable_asm is an x86 assembly implementation of fseDecoder.buildDtable. // buildDtable_asm is an x86 assembly implementation of fseDecoder.buildDtable.
// Function returns non-zero exit code on error. // Function returns non-zero exit code on error.
//
//go:noescape //go:noescape
func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int

View File

@ -1,7 +1,6 @@
// Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT. // Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT.
//go:build !appengine && !noasm && gc && !noasm //go:build !appengine && !noasm && gc && !noasm
// +build !appengine,!noasm,gc,!noasm
// func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int // func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int
TEXT ·buildDtable_asm(SB), $0-24 TEXT ·buildDtable_asm(SB), $0-24

View File

@ -37,26 +37,23 @@ func (h *history) reset() {
h.ignoreBuffer = 0 h.ignoreBuffer = 0
h.error = false h.error = false
h.recentOffsets = [3]int{1, 4, 8} h.recentOffsets = [3]int{1, 4, 8}
if f := h.decoders.litLengths.fse; f != nil && !f.preDefined { h.decoders.freeDecoders()
fseDecoderPool.Put(f)
}
if f := h.decoders.offsets.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
}
if f := h.decoders.matchLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
}
h.decoders = sequenceDecs{br: h.decoders.br} h.decoders = sequenceDecs{br: h.decoders.br}
if h.huffTree != nil { h.freeHuffDecoder()
if h.dict == nil || h.dict.litEnc != h.huffTree {
huffDecoderPool.Put(h.huffTree)
}
}
h.huffTree = nil h.huffTree = nil
h.dict = nil h.dict = nil
//printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b)) //printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b))
} }
func (h *history) freeHuffDecoder() {
if h.huffTree != nil {
if h.dict == nil || h.dict.litEnc != h.huffTree {
huffDecoderPool.Put(h.huffTree)
h.huffTree = nil
}
}
}
func (h *history) setDict(dict *dict) { func (h *history) setDict(dict *dict) {
if dict == nil { if dict == nil {
return return

View File

@ -99,6 +99,21 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) erro
return nil return nil
} }
func (s *sequenceDecs) freeDecoders() {
if f := s.litLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.litLengths.fse = nil
}
if f := s.offsets.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.offsets.fse = nil
}
if f := s.matchLengths.fse; f != nil && !f.preDefined {
fseDecoderPool.Put(f)
s.matchLengths.fse = nil
}
}
// execute will execute the decoded sequence with the provided history. // execute will execute the decoded sequence with the provided history.
// The sequence must be evaluated before being sent. // The sequence must be evaluated before being sent.
func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
@ -299,7 +314,10 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
} }
size := ll + ml + len(out) size := ll + ml + len(out)
if size-startSize > maxBlockSize { if size-startSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", size-startSize, maxBlockSize) if size-startSize == 424242 {
panic("here")
}
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
if size > cap(out) { if size > cap(out) {
// Not enough size, which can happen under high volume block streaming conditions // Not enough size, which can happen under high volume block streaming conditions
@ -411,7 +429,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
// Check if space for literals // Check if space for literals
if size := len(s.literals) + len(s.out) - startSize; size > maxBlockSize { if size := len(s.literals) + len(s.out) - startSize; size > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", size, maxBlockSize) return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
// Add final literals // Add final literals

View File

@ -32,18 +32,22 @@ type decodeSyncAsmContext struct {
// sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm. // sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions. // sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer. // sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer. // sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
@ -135,7 +139,7 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
if debugDecoder { if debugDecoder {
println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize) println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize)
} }
return true, fmt.Errorf("output (%d) bigger than max block size (%d)", size-startSize, maxBlockSize) return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
default: default:
return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode) return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode)
@ -143,7 +147,8 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
s.seqSize += ctx.litRemain s.seqSize += ctx.litRemain
if s.seqSize > maxBlockSize { if s.seqSize > maxBlockSize {
return true, fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
err := br.close() err := br.close()
if err != nil { if err != nil {
@ -201,20 +206,24 @@ const errorNotEnoughSpace = 5
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
@ -281,7 +290,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
s.seqSize += ctx.litRemain s.seqSize += ctx.litRemain
if s.seqSize > maxBlockSize { if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
err := br.close() err := br.close()
if err != nil { if err != nil {
@ -308,10 +317,12 @@ type executeAsmContext struct {
// Returns false if a match offset is too big. // Returns false if a match offset is too big.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool
// Same as above, but with safe memcopies // Same as above, but with safe memcopies
//
//go:noescape //go:noescape
func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool

View File

@ -1,7 +1,6 @@
// Code generated by command: go run gen.go -out ../seqdec_amd64.s -pkg=zstd. DO NOT EDIT. // Code generated by command: go run gen.go -out ../seqdec_amd64.s -pkg=zstd. DO NOT EDIT.
//go:build !appengine && !noasm && gc && !noasm //go:build !appengine && !noasm && gc && !noasm
// +build !appengine,!noasm,gc,!noasm
// func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int // func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// Requires: CMOV // Requires: CMOV

View File

@ -111,7 +111,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
} }
s.seqSize += ll + ml s.seqSize += ll + ml
if s.seqSize > maxBlockSize { if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
litRemain -= ll litRemain -= ll
if litRemain < 0 { if litRemain < 0 {
@ -149,7 +149,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
} }
s.seqSize += litRemain s.seqSize += litRemain
if s.seqSize > maxBlockSize { if s.seqSize > maxBlockSize {
return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize) return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
} }
err := br.close() err := br.close()
if err != nil { if err != nil {

View File

@ -1,5 +1,7 @@
package core package core
import "fmt"
func newChallenge(challengeType AcmeChallenge, token string) Challenge { func newChallenge(challengeType AcmeChallenge, token string) Challenge {
return Challenge{ return Challenge{
Type: challengeType, Type: challengeType,
@ -25,3 +27,19 @@ func DNSChallenge01(token string) Challenge {
func TLSALPNChallenge01(token string) Challenge { func TLSALPNChallenge01(token string) Challenge {
return newChallenge(ChallengeTypeTLSALPN01, token) return newChallenge(ChallengeTypeTLSALPN01, token)
} }
// NewChallenge constructs a random challenge of the given kind. It returns an
// error if the challenge type is unrecognized. If token is empty a random token
// will be generated, otherwise the provided token is used.
func NewChallenge(kind AcmeChallenge, token string) (Challenge, error) {
switch kind {
case ChallengeTypeHTTP01:
return HTTPChallenge01(token), nil
case ChallengeTypeDNS01:
return DNSChallenge01(token), nil
case ChallengeTypeTLSALPN01:
return TLSALPNChallenge01(token), nil
default:
return Challenge{}, fmt.Errorf("unrecognized challenge type %q", kind)
}
}

View File

@ -7,8 +7,8 @@ import (
// PolicyAuthority defines the public interface for the Boulder PA // PolicyAuthority defines the public interface for the Boulder PA
// TODO(#5891): Move this interface to a more appropriate location. // TODO(#5891): Move this interface to a more appropriate location.
type PolicyAuthority interface { type PolicyAuthority interface {
WillingToIssue(domain identifier.ACMEIdentifier) error WillingToIssueWildcards([]identifier.ACMEIdentifier) error
WillingToIssueWildcards(identifiers []identifier.ACMEIdentifier) error ChallengesFor(identifier.ACMEIdentifier) ([]Challenge, error)
ChallengesFor(domain identifier.ACMEIdentifier) ([]Challenge, error) ChallengeTypeEnabled(AcmeChallenge) bool
ChallengeTypeEnabled(t AcmeChallenge) bool CheckAuthz(*Authorization) error
} }

View File

@ -2,7 +2,6 @@ package core
import ( import (
"crypto" "crypto"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -11,7 +10,8 @@ import (
"strings" "strings"
"time" "time"
"gopkg.in/square/go-jose.v2" "golang.org/x/crypto/ocsp"
"gopkg.in/go-jose/go-jose.v2"
"github.com/letsencrypt/boulder/identifier" "github.com/letsencrypt/boulder/identifier"
"github.com/letsencrypt/boulder/probs" "github.com/letsencrypt/boulder/probs"
@ -52,7 +52,6 @@ const (
type AcmeChallenge string type AcmeChallenge string
// These types are the available challenges // These types are the available challenges
// TODO(#5009): Make this a custom type as well.
const ( const (
ChallengeTypeHTTP01 = AcmeChallenge("http-01") ChallengeTypeHTTP01 = AcmeChallenge("http-01")
ChallengeTypeDNS01 = AcmeChallenge("dns-01") ChallengeTypeDNS01 = AcmeChallenge("dns-01")
@ -78,47 +77,18 @@ const (
OCSPStatusRevoked = OCSPStatus("revoked") OCSPStatusRevoked = OCSPStatus("revoked")
) )
var OCSPStatusToInt = map[OCSPStatus]int{
OCSPStatusGood: ocsp.Good,
OCSPStatusRevoked: ocsp.Revoked,
}
// DNSPrefix is attached to DNS names in DNS challenges // DNSPrefix is attached to DNS names in DNS challenges
const DNSPrefix = "_acme-challenge" const DNSPrefix = "_acme-challenge"
// CertificateRequest is just a CSR
//
// This data is unmarshalled from JSON by way of RawCertificateRequest, which
// represents the actual structure received from the client.
type CertificateRequest struct {
CSR *x509.CertificateRequest // The CSR
Bytes []byte // The original bytes of the CSR, for logging.
}
type RawCertificateRequest struct { type RawCertificateRequest struct {
CSR JSONBuffer `json:"csr"` // The encoded CSR CSR JSONBuffer `json:"csr"` // The encoded CSR
} }
// UnmarshalJSON provides an implementation for decoding CertificateRequest objects.
func (cr *CertificateRequest) UnmarshalJSON(data []byte) error {
var raw RawCertificateRequest
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
csr, err := x509.ParseCertificateRequest(raw.CSR)
if err != nil {
return err
}
cr.CSR = csr
cr.Bytes = raw.CSR
return nil
}
// MarshalJSON provides an implementation for encoding CertificateRequest objects.
func (cr CertificateRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(RawCertificateRequest{
CSR: cr.CSR.Raw,
})
}
// Registration objects represent non-public metadata attached // Registration objects represent non-public metadata attached
// to account keys. // to account keys.
type Registration struct { type Registration struct {
@ -169,11 +139,6 @@ type ValidationRecord struct {
// ... // ...
// } // }
AddressesTried []net.IP `json:"addressesTried,omitempty"` AddressesTried []net.IP `json:"addressesTried,omitempty"`
// OldTLS is true if any request in the validation chain used HTTPS and negotiated
// a TLS version lower than 1.2.
// TODO(#6011): Remove once TLS 1.0 and 1.1 support is gone.
OldTLS bool `json:"oldTLS,omitempty"`
} }
func looksLikeKeyAuthorization(str string) error { func looksLikeKeyAuthorization(str string) error {
@ -398,38 +363,25 @@ func (authz *Authorization) FindChallengeByStringID(id string) int {
// SolvedBy will look through the Authorizations challenges, returning the type // SolvedBy will look through the Authorizations challenges, returning the type
// of the *first* challenge it finds with Status: valid, or an error if no // of the *first* challenge it finds with Status: valid, or an error if no
// challenge is valid. // challenge is valid.
func (authz *Authorization) SolvedBy() (*AcmeChallenge, error) { func (authz *Authorization) SolvedBy() (AcmeChallenge, error) {
if len(authz.Challenges) == 0 { if len(authz.Challenges) == 0 {
return nil, fmt.Errorf("Authorization has no challenges") return "", fmt.Errorf("Authorization has no challenges")
} }
for _, chal := range authz.Challenges { for _, chal := range authz.Challenges {
if chal.Status == StatusValid { if chal.Status == StatusValid {
return &chal.Type, nil return chal.Type, nil
} }
} }
return nil, fmt.Errorf("Authorization not solved by any challenge") return "", fmt.Errorf("Authorization not solved by any challenge")
} }
// JSONBuffer fields get encoded and decoded JOSE-style, in base64url encoding // JSONBuffer fields get encoded and decoded JOSE-style, in base64url encoding
// with stripped padding. // with stripped padding.
type JSONBuffer []byte type JSONBuffer []byte
// URL-safe base64 encode that strips padding
func base64URLEncode(data []byte) string {
var result = base64.URLEncoding.EncodeToString(data)
return strings.TrimRight(result, "=")
}
// URL-safe base64 decoder that adds padding
func base64URLDecode(data string) ([]byte, error) {
var missing = (4 - len(data)%4) % 4
data += strings.Repeat("=", missing)
return base64.URLEncoding.DecodeString(data)
}
// MarshalJSON encodes a JSONBuffer for transmission. // MarshalJSON encodes a JSONBuffer for transmission.
func (jb JSONBuffer) MarshalJSON() (result []byte, err error) { func (jb JSONBuffer) MarshalJSON() (result []byte, err error) {
return json.Marshal(base64URLEncode(jb)) return json.Marshal(base64.RawURLEncoding.EncodeToString(jb))
} }
// UnmarshalJSON decodes a JSONBuffer to an object. // UnmarshalJSON decodes a JSONBuffer to an object.
@ -439,7 +391,7 @@ func (jb *JSONBuffer) UnmarshalJSON(data []byte) (err error) {
if err != nil { if err != nil {
return err return err
} }
*jb, err = base64URLDecode(str) *jb, err = base64.RawURLEncoding.DecodeString(strings.TrimRight(str, "="))
return return
} }
@ -534,3 +486,34 @@ type SuggestedWindow struct {
type RenewalInfo struct { type RenewalInfo struct {
SuggestedWindow SuggestedWindow `json:"suggestedWindow"` SuggestedWindow SuggestedWindow `json:"suggestedWindow"`
} }
// RenewalInfoSimple constructs a `RenewalInfo` object and suggested window
// using a very simple renewal calculation: calculate a point 2/3rds of the way
// through the validity period, then give a 2-day window around that. Both the
// `issued` and `expires` timestamps are expected to be UTC.
func RenewalInfoSimple(issued time.Time, expires time.Time) RenewalInfo {
validity := expires.Add(time.Second).Sub(issued)
renewalOffset := validity / time.Duration(3)
idealRenewal := expires.Add(-renewalOffset)
return RenewalInfo{
SuggestedWindow: SuggestedWindow{
Start: idealRenewal.Add(-24 * time.Hour),
End: idealRenewal.Add(24 * time.Hour),
},
}
}
// RenewalInfoImmediate constructs a `RenewalInfo` object with a suggested
// window in the past. Per the draft-ietf-acme-ari-00 spec, clients should
// attempt to renew immediately if the suggested window is in the past. The
// passed `now` is assumed to be a timestamp representing the current moment in
// time.
func RenewalInfoImmediate(now time.Time) RenewalInfo {
oneHourAgo := now.Add(-1 * time.Hour)
return RenewalInfo{
SuggestedWindow: SuggestedWindow{
Start: oneHourAgo,
End: oneHourAgo.Add(time.Minute * 30),
},
}
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.28.0
// protoc v3.15.6 // protoc v3.20.1
// source: core.proto // source: core.proto
package proto package proto
@ -807,6 +807,69 @@ func (x *Order) GetV2Authorizations() []int64 {
return nil return nil
} }
type CRLEntry struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Serial string `protobuf:"bytes,1,opt,name=serial,proto3" json:"serial,omitempty"`
Reason int32 `protobuf:"varint,2,opt,name=reason,proto3" json:"reason,omitempty"`
RevokedAt int64 `protobuf:"varint,3,opt,name=revokedAt,proto3" json:"revokedAt,omitempty"` // Unix timestamp (nanoseconds)
}
func (x *CRLEntry) Reset() {
*x = CRLEntry{}
if protoimpl.UnsafeEnabled {
mi := &file_core_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CRLEntry) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CRLEntry) ProtoMessage() {}
func (x *CRLEntry) ProtoReflect() protoreflect.Message {
mi := &file_core_proto_msgTypes[8]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CRLEntry.ProtoReflect.Descriptor instead.
func (*CRLEntry) Descriptor() ([]byte, []int) {
return file_core_proto_rawDescGZIP(), []int{8}
}
func (x *CRLEntry) GetSerial() string {
if x != nil {
return x.Serial
}
return ""
}
func (x *CRLEntry) GetReason() int32 {
if x != nil {
return x.Reason
}
return 0
}
func (x *CRLEntry) GetRevokedAt() int64 {
if x != nil {
return x.RevokedAt
}
return 0
}
var File_core_proto protoreflect.FileDescriptor var File_core_proto protoreflect.FileDescriptor
var file_core_proto_rawDesc = []byte{ var file_core_proto_rawDesc = []byte{
@ -935,10 +998,16 @@ var file_core_proto_rawDesc = []byte{
0x64, 0x12, 0x2a, 0x0a, 0x10, 0x76, 0x32, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x76, 0x32, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x03, 0x52, 0x10, 0x76, 0x32, 0x41, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x03, 0x52, 0x10, 0x76, 0x32, 0x41,
0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x4a, 0x04, 0x08, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x4a, 0x04, 0x08,
0x06, 0x10, 0x07, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x06, 0x10, 0x07, 0x22, 0x58, 0x0a, 0x08, 0x43, 0x52, 0x4c, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
0x6d, 0x2f, 0x6c, 0x65, 0x74, 0x73, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x2f, 0x62, 0x6f, 0x16, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x75, 0x6c, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x06, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12,
0x1c, 0x0a, 0x09, 0x72, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x64, 0x41, 0x74, 0x18, 0x03, 0x20, 0x01,
0x28, 0x03, 0x52, 0x09, 0x72, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x64, 0x41, 0x74, 0x42, 0x2b, 0x5a,
0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x65, 0x74, 0x73,
0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6c, 0x64, 0x65, 0x72, 0x2f,
0x63, 0x6f, 0x72, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
} }
var ( var (
@ -953,7 +1022,7 @@ func file_core_proto_rawDescGZIP() []byte {
return file_core_proto_rawDescData return file_core_proto_rawDescData
} }
var file_core_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_core_proto_msgTypes = make([]protoimpl.MessageInfo, 9)
var file_core_proto_goTypes = []interface{}{ var file_core_proto_goTypes = []interface{}{
(*Challenge)(nil), // 0: core.Challenge (*Challenge)(nil), // 0: core.Challenge
(*ValidationRecord)(nil), // 1: core.ValidationRecord (*ValidationRecord)(nil), // 1: core.ValidationRecord
@ -963,6 +1032,7 @@ var file_core_proto_goTypes = []interface{}{
(*Registration)(nil), // 5: core.Registration (*Registration)(nil), // 5: core.Registration
(*Authorization)(nil), // 6: core.Authorization (*Authorization)(nil), // 6: core.Authorization
(*Order)(nil), // 7: core.Order (*Order)(nil), // 7: core.Order
(*CRLEntry)(nil), // 8: core.CRLEntry
} }
var file_core_proto_depIdxs = []int32{ var file_core_proto_depIdxs = []int32{
1, // 0: core.Challenge.validationrecords:type_name -> core.ValidationRecord 1, // 0: core.Challenge.validationrecords:type_name -> core.ValidationRecord
@ -1078,6 +1148,18 @@ func file_core_proto_init() {
return nil return nil
} }
} }
file_core_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CRLEntry); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -1085,7 +1167,7 @@ func file_core_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_core_proto_rawDesc, RawDescriptor: file_core_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 8, NumMessages: 9,
NumExtensions: 0, NumExtensions: 0,
NumServices: 0, NumServices: 0,
}, },

View File

@ -93,3 +93,9 @@ message Order {
int64 created = 10; int64 created = 10;
repeated int64 v2Authorizations = 11; repeated int64 v2Authorizations = 11;
} }
message CRLEntry {
string serial = 1;
int32 reason = 2;
int64 revokedAt = 3; // Unix timestamp (nanoseconds)
}

View File

@ -13,9 +13,9 @@ import (
"expvar" "expvar"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/big" "math/big"
mrand "math/rand" mrand "math/rand"
"os"
"reflect" "reflect"
"regexp" "regexp"
"sort" "sort"
@ -23,9 +23,11 @@ import (
"time" "time"
"unicode" "unicode"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/go-jose/go-jose.v2"
) )
const Unspecified = "Unspecified"
// Package Variables Variables // Package Variables Variables
// BuildID is set by the compiler (using -ldflags "-X core.BuildID $(git rev-parse --short HEAD)") // BuildID is set by the compiler (using -ldflags "-X core.BuildID $(git rev-parse --short HEAD)")
@ -182,7 +184,7 @@ func ValidSerial(serial string) bool {
func GetBuildID() (retID string) { func GetBuildID() (retID string) {
retID = BuildID retID = BuildID
if retID == "" { if retID == "" {
retID = "Unspecified" retID = Unspecified
} }
return return
} }
@ -191,7 +193,7 @@ func GetBuildID() (retID string) {
func GetBuildTime() (retID string) { func GetBuildTime() (retID string) {
retID = BuildTime retID = BuildTime
if retID == "" { if retID == "" {
retID = "Unspecified" retID = Unspecified
} }
return return
} }
@ -200,7 +202,7 @@ func GetBuildTime() (retID string) {
func GetBuildHost() (retID string) { func GetBuildHost() (retID string) {
retID = BuildHost retID = BuildHost
if retID == "" { if retID == "" {
retID = "Unspecified" retID = Unspecified
} }
return return
} }
@ -245,7 +247,7 @@ func UniqueLowerNames(names []string) (unique []string) {
// LoadCert loads a PEM certificate specified by filename or returns an error // LoadCert loads a PEM certificate specified by filename or returns an error
func LoadCert(filename string) (*x509.Certificate, error) { func LoadCert(filename string) (*x509.Certificate, error) {
certPEM, err := ioutil.ReadFile(filename) certPEM, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,7 +1,18 @@
// Package errors provides internal-facing error types for use in Boulder. Many
// of these are transformed directly into Problem Details documents by the WFE.
// Some, like NotFound, may be handled internally. We avoid using Problem
// Details documents as part of our internal error system to avoid layering
// confusions.
//
// These errors are specifically for use in errors that cross RPC boundaries.
// An error type that does not need to be passed through an RPC can use a plain
// Go type locally. Our gRPC code is aware of these error types and will
// serialize and deserialize them automatically.
package errors package errors
import ( import (
"fmt" "fmt"
"time"
"github.com/letsencrypt/boulder/identifier" "github.com/letsencrypt/boulder/identifier"
) )
@ -12,7 +23,10 @@ import (
// BoulderError wrapping one of these types. // BoulderError wrapping one of these types.
type ErrorType int type ErrorType int
// These numeric constants are used when sending berrors through gRPC.
const ( const (
// InternalServer is deprecated. Instead, pass a plain Go error. That will get
// turned into a probs.InternalServerError by the WFE.
InternalServer ErrorType = iota InternalServer ErrorType = iota
_ _
Malformed Malformed
@ -43,6 +57,10 @@ type BoulderError struct {
Type ErrorType Type ErrorType
Detail string Detail string
SubErrors []SubBoulderError SubErrors []SubBoulderError
// RetryAfter the duration a client should wait before retrying the request
// which resulted in this error.
RetryAfter time.Duration
} }
// SubBoulderError represents sub-errors specific to an identifier that are // SubBoulderError represents sub-errors specific to an identifier that are
@ -67,6 +85,7 @@ func (be *BoulderError) WithSubErrors(subErrs []SubBoulderError) *BoulderError {
Type: be.Type, Type: be.Type,
Detail: be.Detail, Detail: be.Detail,
SubErrors: append(be.SubErrors, subErrs...), SubErrors: append(be.SubErrors, subErrs...),
RetryAfter: be.RetryAfter,
} }
} }
@ -94,10 +113,35 @@ func NotFoundError(msg string, args ...interface{}) error {
return New(NotFound, msg, args...) return New(NotFound, msg, args...)
} }
func RateLimitError(msg string, args ...interface{}) error { func RateLimitError(retryAfter time.Duration, msg string, args ...interface{}) error {
return &BoulderError{ return &BoulderError{
Type: RateLimit, Type: RateLimit,
Detail: fmt.Sprintf(msg+": see https://letsencrypt.org/docs/rate-limits/", args...), Detail: fmt.Sprintf(msg+": see https://letsencrypt.org/docs/rate-limits/", args...),
RetryAfter: retryAfter,
}
}
func DuplicateCertificateError(retryAfter time.Duration, msg string, args ...interface{}) error {
return &BoulderError{
Type: RateLimit,
Detail: fmt.Sprintf(msg+": see https://letsencrypt.org/docs/duplicate-certificate-limit/", args...),
RetryAfter: retryAfter,
}
}
func FailedValidationError(retryAfter time.Duration, msg string, args ...interface{}) error {
return &BoulderError{
Type: RateLimit,
Detail: fmt.Sprintf(msg+": see https://letsencrypt.org/docs/failed-validation-limit/", args...),
RetryAfter: retryAfter,
}
}
func RegistrationsPerIPError(retryAfter time.Duration, msg string, args ...interface{}) error {
return &BoulderError{
Type: RateLimit,
Detail: fmt.Sprintf(msg+": see https://letsencrypt.org/docs/too-many-registrations-for-this-ip/", args...),
RetryAfter: retryAfter,
} }
} }

View File

@ -1,45 +0,0 @@
// Code generated by "stringer -type=FeatureFlag"; DO NOT EDIT.
package features
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[unused-0]
_ = x[PrecertificateRevocation-1]
_ = x[StripDefaultSchemePort-2]
_ = x[NonCFSSLSigner-3]
_ = x[StoreIssuerInfo-4]
_ = x[StreamlineOrderAndAuthzs-5]
_ = x[V1DisableNewValidations-6]
_ = x[CAAValidationMethods-7]
_ = x[CAAAccountURI-8]
_ = x[EnforceMultiVA-9]
_ = x[MultiVAFullResults-10]
_ = x[MandatoryPOSTAsGET-11]
_ = x[AllowV1Registration-12]
_ = x[StoreRevokerInfo-13]
_ = x[RestrictRSAKeySizes-14]
_ = x[FasterNewOrdersRateLimit-15]
_ = x[ECDSAForAll-16]
_ = x[ServeRenewalInfo-17]
_ = x[GetAuthzReadOnly-18]
_ = x[GetAuthzUseIndex-19]
_ = x[CheckFailedAuthorizationsFirst-20]
_ = x[AllowReRevocation-21]
_ = x[MozRevocationReasons-22]
}
const _FeatureFlag_name = "unusedPrecertificateRevocationStripDefaultSchemePortNonCFSSLSignerStoreIssuerInfoStreamlineOrderAndAuthzsV1DisableNewValidationsCAAValidationMethodsCAAAccountURIEnforceMultiVAMultiVAFullResultsMandatoryPOSTAsGETAllowV1RegistrationStoreRevokerInfoRestrictRSAKeySizesFasterNewOrdersRateLimitECDSAForAllServeRenewalInfoGetAuthzReadOnlyGetAuthzUseIndexCheckFailedAuthorizationsFirstAllowReRevocationMozRevocationReasons"
var _FeatureFlag_index = [...]uint16{0, 6, 30, 52, 66, 81, 105, 128, 148, 161, 175, 193, 211, 230, 246, 265, 289, 300, 316, 332, 348, 378, 395, 415}
func (i FeatureFlag) String() string {
if i < 0 || i >= FeatureFlag(len(_FeatureFlag_index)-1) {
return "FeatureFlag(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _FeatureFlag_name[_FeatureFlag_index[i]:_FeatureFlag_index[i+1]]
}

View File

@ -1,158 +0,0 @@
//go:generate stringer -type=FeatureFlag
package features
import (
"fmt"
"sync"
)
type FeatureFlag int
const (
unused FeatureFlag = iota // unused is used for testing
// Deprecated features, these can be removed once stripped from production configs
PrecertificateRevocation
StripDefaultSchemePort
NonCFSSLSigner
StoreIssuerInfo
StreamlineOrderAndAuthzs
V1DisableNewValidations
// Currently in-use features
// Check CAA and respect validationmethods parameter.
CAAValidationMethods
// Check CAA and respect accounturi parameter.
CAAAccountURI
// EnforceMultiVA causes the VA to block on remote VA PerformValidation
// requests in order to make a valid/invalid decision with the results.
EnforceMultiVA
// MultiVAFullResults will cause the main VA to wait for all of the remote VA
// results, not just the threshold required to make a decision.
MultiVAFullResults
// MandatoryPOSTAsGET forbids legacy unauthenticated GET requests for ACME
// resources.
MandatoryPOSTAsGET
// Allow creation of new registrations in ACMEv1.
AllowV1Registration
// StoreRevokerInfo enables storage of the revoker and a bool indicating if the row
// was checked for extant unrevoked certificates in the blockedKeys table.
StoreRevokerInfo
// RestrictRSAKeySizes enables restriction of acceptable RSA public key moduli to
// the common sizes (2048, 3072, and 4096 bits).
RestrictRSAKeySizes
// FasterNewOrdersRateLimit enables use of a separate table for counting the
// new orders rate limit.
FasterNewOrdersRateLimit
// ECDSAForAll enables all accounts, regardless of their presence in the CA's
// ecdsaAllowedAccounts config value, to get issuance from ECDSA issuers.
ECDSAForAll
// ServeRenewalInfo exposes the renewalInfo endpoint in the directory and for
// GET requests. WARNING: This feature is a draft and highly unstable.
ServeRenewalInfo
// GetAuthzReadOnly causes the SA to use its read-only database connection
// (which is generally pointed at a replica rather than the primary db) when
// querying the authz2 table.
GetAuthzReadOnly
// GetAuthzUseIndex causes the SA to use to add a USE INDEX hint when it
// queries the authz2 table.
GetAuthzUseIndex
// Check the failed authorization limit before doing authz reuse.
CheckFailedAuthorizationsFirst
// AllowReRevocation causes the RA to allow the revocation reason of an
// already-revoked certificate to be updated to `keyCompromise` from any
// other reason if that compromise is demonstrated by making the second
// revocation request signed by the certificate keypair.
AllowReRevocation
// MozRevocationReasons causes the RA to enforce the following upcoming
// Mozilla policies regarding revocation:
// - A subscriber can request that their certificate be revoked with reason
// keyCompromise, even without demonstrating that compromise at the time.
// However, the cert's pubkey will not be added to the blocked keys list.
// - When an applicant other than the original subscriber requests that a
// certificate be revoked (by demonstrating control over all names in it),
// the cert will be revoked with reason cessationOfOperation, regardless of
// what revocation reason they request.
// - When anyone requests that a certificate be revoked by signing the request
// with the certificate's keypair, the cert will be revoked with reason
// keyCompromise, regardless of what revocation reason they request.
MozRevocationReasons
)
// List of features and their default value, protected by fMu
var features = map[FeatureFlag]bool{
unused: false,
CAAValidationMethods: false,
CAAAccountURI: false,
EnforceMultiVA: false,
MultiVAFullResults: false,
MandatoryPOSTAsGET: false,
AllowV1Registration: true,
V1DisableNewValidations: false,
PrecertificateRevocation: false,
StripDefaultSchemePort: false,
StoreIssuerInfo: false,
StoreRevokerInfo: false,
RestrictRSAKeySizes: false,
FasterNewOrdersRateLimit: false,
NonCFSSLSigner: false,
ECDSAForAll: false,
StreamlineOrderAndAuthzs: false,
ServeRenewalInfo: false,
GetAuthzReadOnly: false,
GetAuthzUseIndex: false,
CheckFailedAuthorizationsFirst: false,
AllowReRevocation: false,
MozRevocationReasons: false,
}
var fMu = new(sync.RWMutex)
var initial = map[FeatureFlag]bool{}
var nameToFeature = make(map[string]FeatureFlag, len(features))
func init() {
for f, v := range features {
nameToFeature[f.String()] = f
initial[f] = v
}
}
// Set accepts a list of features and whether they should
// be enabled or disabled, it will return a error if passed
// a feature name that it doesn't know
func Set(featureSet map[string]bool) error {
fMu.Lock()
defer fMu.Unlock()
for n, v := range featureSet {
f, present := nameToFeature[n]
if !present {
return fmt.Errorf("feature '%s' doesn't exist", n)
}
features[f] = v
}
return nil
}
// Enabled returns true if the feature is enabled or false
// if it isn't, it will panic if passed a feature that it
// doesn't know.
func Enabled(n FeatureFlag) bool {
fMu.RLock()
defer fMu.RUnlock()
v, present := features[n]
if !present {
panic(fmt.Sprintf("feature '%s' doesn't exist", n.String()))
}
return v
}
// Reset resets the features to their initial state
func Reset() {
fMu.Lock()
defer fMu.Unlock()
for k, v := range initial {
features[k] = v
}
}

View File

@ -6,11 +6,11 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors" "errors"
"io/ioutil" "os"
"github.com/letsencrypt/boulder/core" "github.com/letsencrypt/boulder/core"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v3"
) )
// blockedKeys is a type for maintaining a map of SHA256 hashes // blockedKeys is a type for maintaining a map of SHA256 hashes
@ -42,16 +42,14 @@ func (b blockedKeys) blocked(key crypto.PublicKey) (bool, error) {
// SHA256 hashes of SubjectPublicKeyInfo's in the input YAML file // SHA256 hashes of SubjectPublicKeyInfo's in the input YAML file
// with the expected format: // with the expected format:
// //
// ```
// blocked: // blocked:
// - cuwGhNNI6nfob5aqY90e7BleU6l7rfxku4X3UTJ3Z7M= // - cuwGhNNI6nfob5aqY90e7BleU6l7rfxku4X3UTJ3Z7M=
// <snipped> // <snipped>
// - Qebc1V3SkX3izkYRGNJilm9Bcuvf0oox4U2Rn+b4JOE= // - Qebc1V3SkX3izkYRGNJilm9Bcuvf0oox4U2Rn+b4JOE=
// ```
// //
// If no hashes are found in the input YAML an error is returned. // If no hashes are found in the input YAML an error is returned.
func loadBlockedKeysList(filename string) (*blockedKeys, error) { func loadBlockedKeysList(filename string) (*blockedKeys, error) {
yamlBytes, err := ioutil.ReadFile(filename) yamlBytes, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }

Some files were not shown because too many files have changed in this diff Show More