diff --git a/.gitignore b/.gitignore index 251b391c1..cc1c664a6 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ Makefile.main .ci/ _example/main _example/*.exe +test-server diff --git a/.travis.yml b/.travis.yml index 11541d621..e06f94818 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,21 +1,108 @@ language: go +go_import_path: github.com/src-d/go-mysql-server -go: - - 1.9.x - - 1.10.x - - tip +env: + global: + - LD_LIBRARY_PATH="/usr/local/lib":${LD_LIBRARY_PATH} + - GO111MODULE=on + - GOPROXY=https://proxy.golang.org -go_import_path: gopkg.in/src-d/go-mysql-server.v0 +addons: + apt: + packages: + - libmysqlclient-dev matrix: fast_finish: true - allow_failures: - - go: tip -sudo: required - -install: - - make dependencies +before_script: + - sudo service mysql stop script: - - make test-coverage + - make ci-script + +jobs: + include: + - go: 1.11.x + name: 'Unit tests Go 1.11' + - go: 1.12.x + name: 'Unit tests Go 1.12' + + # Integration test builds for 3rd party clients + - go: 1.12.x + name: 'Integration test go' + script: + - make TEST=go integration + + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test python-pymysql' + script: + - make TEST=python-pymysql integration + + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test python-mysql' + script: + - make TEST=python-mysql integration + + - language: python + python: '3.6' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test python-sqlalchemy' + script: + - make TEST=python-sqlalchemy integration + + - language: php + php: '7.1' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test php' + script: + - make TEST=php integration + + - language: ruby + ruby: '2.3' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test ruby' + script: + - make TEST=ruby integration + + - language: java + jdk: openjdk8 + before_install: + - eval "$(gimme 1.12.4)" 
+ name: 'Integration test jdbc-mariadb' + script: + - make TEST=jdbc-mariadb integration + + - language: node_js + node_js: '12' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test javascript' + script: + - make TEST=javascript integration + + - language: csharp + mono: none + dotnet: '2.1' + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test dotnet' + script: + - make TEST=dotnet integration + + - language: c + compiler: clang + before_install: + - eval "$(gimme 1.12.4)" + name: 'Integration test c' + script: + - make TEST=c integration diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 000000000..24f3aebe8 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,142 @@ +# Architecture overview + +This document provides an overview of all parts and pieces of the project as well as how they fit together. It is meant to help new contributors understand where things may be, and how changes in some components may interact with other components of the system. + +## Root package (`sqle`) + +This is where the engine lives. The engine is the piece that coordinates and makes all other pieces work together as well as the main API users of the system will use to create and configure an engine and perform queries. + +Because this is the point where all components fit together, it is also where integration tests are. Those integration tests can be found in `engine_test.go`. +A test should be added here, plus in any specific place where the feature/issue belonged, if needed. + +**How to add integration tests** + +The test executing all integration test is `TestQueries`, which you can run with the following command: + +``` +go test -run=TestQueries +``` + +This test is just executing all the queries in a loop. New test cases should be added to the `queries` package variable at the top of `engine_test.go`. +Simply add a new element to the slice with the query and the expected result. 
+ +## `sql` + +This package is probably the most important of the project. It has several main roles: +- Defines the main interfaces used in the rest of the packages `Node`, `Expression`, ... +- Provides implementations of components used in the rest of the packages `Row`, `Context`, `ProcessList`, `Catalog`, ... +- Defines the `information_schema` table, which is a special table available in all databases and contains some data about the schemas of other tables. + +### `sql/analyzer` + +The analyzer is the most complex component of the project. It contains a main component, which is the `Analyzer`, in charge of executing its registered rules on execution trees for resolving some parts, removing redundant data, optimizing things for performance, etc. + +There are several phases in the analyzer, because some rules need to be run before others, some need to be executed several times, others just once, etc. +Inside `rules.go` are all the default rules and the phases in which they're executed. + +On top of that, all available rules are defined in this package. Each rule has a specific role in the analyzer. Rules should be as small and atomic as possible and try to do only one job and always produce a tree that is as resolved as the one it received or more. + +### `sql/expression` + +This package includes the implementation of all the SQL expressions available in go-mysql-server, except functions. Arithmetic operators, logic operators, conversions, etc are implemented here. + +Inside `registry.go` there is a registry of all the default functions, even if they're not defined here. + +`Inspect` and `Walk` utility functions are provided to inspect expressions. + +### `sql/expression/function` + +Implementation of all the functions available in go-mysql-server. + +### `sql/expression/function/aggregation` + +Implementation of all the aggregation functions available in go-mysql-server. 
+ +### `sql/index` + +Contains the index driver configuration file implementation and other utilities for dealing with index drivers. + +### `sql/index/pilosa` + +Actual implementation of an index driver. Underneath, it's using a bitmap database called pilosa (hence the name) to implement bitmap indexes. + +### `sql/parse` + +This package exposes the `Parse` function, which parses a SQL query and translates it into an execution plan. + +Parsing is done using `vitess` parser, but sometimes there are queries vitess cannot parse. In this case, custom parsers are used. Otherwise, vitess is used to parse the query and then converted to a go-mysql-server execution plan. + +### `sql/plan` + +All the different nodes of the execution plan (except for very specific nodes used in some optimisation rules) are defined here. + +For example, `SELECT foo FROM bar` is translated into the following plan: + +``` +Project(foo) + |- Table(bar) +``` + +Which means, the execution plan is a `Project` node projecting `foo` and has a `ResolvedTable`, which is `bar` as its children. + +Each node inside this package implements at least the `sql.Node` interface, but it can implement more. `sql.Expressioner`, for example. + +Along with the nodes, `Inspect` and `Walk` functions are provided as utilities to inspect an execution tree. + +## `server` + +Contains all the code to turn an engine into a runnable server that can communicate using the MySQL wire protocol. + +## `auth` + +This package contains all the code related to the audit log, authentication and permission management in go-mysql-server. + +There are two authentication methods: +- **None:** no authentication needed. +- **Native:** authentication performed with user and password. Read, write or all permissions can be specified for those users. It can also be configured using a JSON file. 
+ +## `internal/similartext` + +Contains a function to `Find` the most similar name from an +array to a given one using the Levenshtein distance algorithm. Used for suggestions on errors. + +## `internal/regex` + +go-mysql-server has multiple regular expression engines, such as oniguruma and the standard Go regexp engine. In this package, a common interface for regular expression engines is defined. +This means the Go standard library `regexp` package should not be used in any user-facing feature; instead, this package should be used. + +The default engine is oniguruma, but the Go standard library engine can be used using the `mysql_go_regex` build tag. + +## `test` + +Test contains pieces that are only used for tests, such as an opentracing tracer that stores spans in memory to be inspected later in the tests. + +## `_integration` + +To ensure compatibility with some clients, there is a small example connecting and querying a go-mysql-server server from those clients. Each folder corresponds to a different client. + +For more info about supported clients see [SUPPORTED_CLIENTS.md](/SUPPORTED_CLIENTS.md). + +These integration tests can be run using this command: + +``` +make TEST=${CLIENT FOLDER NAME} integration +``` + +It will take care of setting up the test server and shutting it down. + +## `_example` + +A small example of how to use go-mysql-server to create a server and run it. + +# Connecting the dots + +`server` uses the engine defined in `sql`. + +Engine uses audit logs and authentication defined in `auth`, parses using `sql/parse` to convert a query into an execution plan, with nodes defined in `sql/plan` and expressions defined in `sql/expression`, `sql/expression/function` and `sql/expression/function/aggregation`. + +After parsing, the obtained execution plan is analyzed using the analyzer defined in `sql/analyzer` and its rules to resolve tables, fields, databases, apply optimisation rules, etc. 
+ +If indexes can be used, the analyzer will transform the query so it uses indexes reading from the drivers in `sql/index` (in this case `sql/index/pilosa` because there is only one driver). + +Once the plan is analyzed, it will be executed recursively from the top of the tree to the bottom to obtain the results and they will be sent back to the client using the MySQL wire protocol. diff --git a/CNAME b/CNAME deleted file mode 100644 index 8b0cae80b..000000000 --- a/CNAME +++ /dev/null @@ -1 +0,0 @@ -sqle.io \ No newline at end of file diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..f48cffb3a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @src-d/data-processing diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..71a98af05 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,62 @@ +# source{d} Contributing Guidelines + +source{d} projects accept contributions via GitHub pull requests. +This document outlines some of the +conventions on development workflow, commit message formatting, contact points, +and other resources to make it easier to get your contribution accepted. + +## Certificate of Origin + +By contributing to this project you agree to the [Developer Certificate of +Origin (DCO)](DCO). This document was created by the Linux Kernel community and is a +simple statement that you, as a contributor, have the legal right to make the +contribution. + +In order to show your agreement with the DCO you should include at the end of commit message, +the following line: `Signed-off-by: John Doe `, using your real name. + +This can be done easily using the [`-s`](https://github.com/git/git/blob/b2c150d3aa82f6583b9aadfecc5f8fa1c74aca09/Documentation/git-commit.txt#L154-L161) flag on the `git commit`. + +## Support Channels + +The official support channels, for both users and contributors, are: + +- GitHub issues: each repository has its own list of issues. 
+- Slack: join the [source{d} Slack](https://join.slack.com/t/sourced-community/shared_invite/enQtMjc4Njk5MzEyNzM2LTFjNzY4NjEwZGEwMzRiNTM4MzRlMzQ4MmIzZjkwZmZlM2NjODUxZmJjNDI1OTcxNDAyMmZlNmFjODZlNTg0YWM) community. + +*Before opening a new issue or submitting a new pull request, it's helpful to +search the project - it's likely that another user has already reported the +issue you're facing, or it's a known issue that we're already aware of. + +## How to Contribute + +Pull Requests (PRs) are the main and exclusive way to contribute code to source{d} projects. +In order for a PR to be accepted it needs to pass a list of requirements: + +- The contribution must be correctly explained with natural language and providing a minimum working example that reproduces it. +- All PRs must be written idiomatically: + - for Go: formatted according to [gofmt](https://golang.org/cmd/gofmt/), and without any warnings from [go lint](https://github.com/golang/lint) nor [go vet](https://golang.org/cmd/vet/) + - for other languages, similar constraints apply. +- They should in general include tests, and those shall pass. + - If the PR is a bug fix, it has to include a new unit test that fails before the patch is merged. + - If the PR is a new feature, it has to come with a suite of unit tests, that tests the new functionality. + - In any case, all the PRs have to pass the personal evaluation of at least one of the [maintainers](MAINTAINERS) of the project. + +### Getting started + +If you are a new contributor to the project, reading [ARCHITECTURE.md](/ARCHITECTURE.md) is highly recommended, as it contains all the details about the architecture of go-mysql-server and its components. + + +### Format of the commit message + +Every commit message should describe what was changed, under which context and, if applicable, the GitHub issue it relates to: + +``` +plumbing: packp, Skip argument validations for unknown capabilities. 
Fixes #623 +``` + +The format can be described more formally as follows: + +``` +: , . [Fixes #] +``` diff --git a/DCO b/DCO new file mode 100644 index 000000000..716561d5d --- /dev/null +++ b/DCO @@ -0,0 +1,36 @@ +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. diff --git a/LICENSE b/LICENSE index 8dada3eda..261eeb9e9 100644 --- a/LICENSE +++ b/LICENSE @@ -178,7 +178,7 @@ APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" + boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright {yyyy} {name of copyright owner} + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/MAINTAINERS b/MAINTAINERS index dc6817245..8d8de2618 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -1 +1,2 @@ -Antonio Navarro Perez (@ajnavarro) +Miguel Molina (@erizocosmico) +Juanjo Álvarez Martinez (@juanjux) diff --git a/Makefile b/Makefile index a6e658291..81c0e52e0 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,27 @@ # Package configuration PROJECT = go-mysql-server COMMANDS = +UNAME_S := $(shell uname -s) # Including ci Makefile -MAKEFILE = Makefile.main -CI_REPOSITORY = https://github.com/src-d/ci.git -CI_FOLDER = .ci - +CI_REPOSITORY ?= https://github.com/src-d/ci.git +CI_BRANCH ?= v1 +CI_PATH ?= .ci +MAKEFILE := $(CI_PATH)/Makefile.main $(MAKEFILE): - @git clone --quiet $(CI_REPOSITORY) $(CI_FOLDER); \ - cp $(CI_FOLDER)/$(MAKEFILE) .; - + git clone --quiet --depth 1 -b $(CI_BRANCH) $(CI_REPOSITORY) $(CI_PATH); -include $(MAKEFILE) + +integration: + ./_integration/run ${TEST} + +oniguruma: +ifeq ($(UNAME_S),Linux) + $(shell apt-get install libonig-dev) +endif + +ifeq ($(UNAME_S),Darwin) + $(shell brew install oniguruma) +endif + +.PHONY: integration \ No newline at end of file diff --git a/README.md b/README.md index cd3b5feac..1ab7e9ee5 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,46 @@ -# go-mysql-server +_**Notice: This repository 
is no longer actively maintained, and no further updates will be done, nor issues/PRs will be answered or attended. An alternative actively maintained can be found at https://github.com/dolthub/go-mysql-server repository.**_ -Build Status +# go-mysql-server +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +Build Status codecov GoDoc +[![Issues](http://img.shields.io/github/issues/src-d/go-mysql-server.svg)](https://github.com/src-d/go-mysql-server/issues) + +**go-mysql-server** is a SQL engine which parses standard SQL (based on MySQL syntax), resolves and optimizes queries. +It provides simple interfaces to allow custom tabular data source implementations. + +**go-mysql-server** also provides a server implementation compatible with the MySQL wire protocol. +That means it is compatible with MySQL ODBC, JDBC, or the default MySQL client shell interface. + +## Scope of this project + +These are the goals of **go-mysql-server**: + +- Be a generic extensible SQL engine that performs queries on your data sources. +- Provide interfaces so you can implement your own custom data sources without providing any (except for the `mem` data source that is used for testing purposes). +- Have a runnable server you can use on your specific implementation. +- Parse and optimize queries while still allow specific implementations to add their own analysis steps and optimizations. +- Provide some common index driver implementations so the user does not have to bring their own index implementation, and still be able to do so if they need to. -**go-mysql-server** is an extensible MySQL server implementation in Go. +What are not the goals of **go-mysql-server**: + +- Be a drop-in MySQL database replacement. +- Be an application/server you can use directly. +- Provide any kind of backend implementation (other than the `mem` one used for testing) such as json, csv, yaml, ... That's for clients to implement and use. 
+ +What's the use case of **go-mysql-server**? + +Having data in another format that you want as tabular data to query using SQL, such as git. As an example of this, we have [gitbase](https://github.com/src-d/gitbase). ## Installation -The import path for the package is `gopkg.in/src-d/go-mysql-server.v0`. +The import path for the package is `github.com/src-d/go-mysql-server`. To install it, run: ``` -go get gopkg.in/src-d/go-mysql-server.v0 +go get github.com/src-d/go-mysql-server ``` ## Documentation @@ -23,80 +50,170 @@ go get gopkg.in/src-d/go-mysql-server.v0 ## SQL syntax -We are continuously adding more functionality to go-mysql-server. We support a subset of what is supported in MySQL, currently including: - -| | Supported | -|:----------------------:|:---------------------------------------------------------------------------------:| -| Comparison expressions | !=, ==, >, <, >=,<=, BETWEEN, REGEXP, IN, NOT IN | -| Null check expressions | IS NULL, IS NOT NULL | -| Grouping expressions | COUNT, MIN, MAX ,AVG | -| Standard expressions | ALIAS, LITERAL, STAR (*) | -| Statements | CROSS JOIN, INNER JOIN, DESCRIBE, FILTER (WHERE), GROUP BY, LIMIT/OFFSET, SELECT, SHOW TABLES, SORT, DISTINCT, CREATE TABLE, INSERT | -| Functions | SUBSTRING, ARRAY_LENGTH | -| Time functions | YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, DAYOFYEAR | - -## Custom functions +We are continuously adding more functionality to go-mysql-server. We support a subset of what is supported in MySQL, to see what is currently included check the [SUPPORTED](./SUPPORTED.md) file. + +## Third-party clients + +We support and actively test against certain third-party clients to ensure compatibility between them and go-mysql-server. You can check out the list of supported third party clients in the [SUPPORTED_CLIENTS](./SUPPORTED_CLIENTS.md) file along with some examples on how to connect to go-mysql-server using them. 
+ +## Available functions + + +| Name | Description | +|:-------------|:-------------------------------------------------------------------------------------------------------------------------------| +|`ARRAY_LENGTH(json)`|if the json representation is an array, this function returns its size.| +|`AVG(expr)`| returns the average value of expr in all rows.| +|`CEIL(number)`| returns the smallest integer value that is greater than or equal to `number`.| +|`CEILING(number)`| returns the smallest integer value that is greater than or equal to `number`.| +|`CHAR_LENGTH(str)`| returns the length of the string in characters.| +|`COALESCE(...)`| returns the first non-null value in a list.| +|`CONCAT(...)`| concatenates any group of fields into a single string.| +|`CONCAT_WS(sep, ...)`| concatenates any group of fields into a single string. The first argument is the separator for the rest of the arguments. The separator is added between the strings to be concatenated. The separator can be a string, as can the rest of the arguments. 
If the separator is NULL, the result is NULL.| +|`CONNECTION_ID()`| returns the current connection ID.| +|`COUNT(expr)`| returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.| +|`DATE_ADD(date, interval)`| adds the interval to the given `date`.| +|`DATE_SUB(date, interval)`| subtracts the interval from the given `date`.| +|`DAY(date)`| is a synonym for DAYOFMONTH().| +|`DATE(date)`| returns the date part of the given `date`.| +|`DAYOFMONTH(date)`| returns the day of the month (0-31).| +|`DAYOFWEEK(date)`| returns the day of the week of the given `date`.| +|`DAYOFYEAR(date)`| returns the day of the year of the given `date`.| +|`FIRST(expr)`| returns the first value in a sequence of elements of an aggregation.| +|`FLOOR(number)`| returns the largest integer value that is less than or equal to `number`.| +|`FROM_BASE64(str)`| decodes the base64-encoded string `str`.| +|`GREATEST(...)`| returns the greatest numeric or string value.| +|`HOUR(date)`| returns the hours of the given `date`.| +|`IFNULL(expr1, expr2)`| if `expr1` is not NULL, it returns `expr1`; otherwise it returns `expr2`.| +|`IS_BINARY(blob)`| returns whether a `blob` is a binary file or not.| +|`JSON_EXTRACT(json_doc, path, ...)`| extracts data from a json document using json paths. Extracting a string will result in that string being quoted. To avoid this, use `JSON_UNQUOTE(JSON_EXTRACT(json_doc, path, ...))`.| +|`JSON_UNQUOTE(json)`| unquotes JSON value and returns the result as a utf8mb4 string.| +|`LAST(expr)`| returns the last value in a sequence of elements of an aggregation.| +|`LEAST(...)`| returns the smaller numeric or string value.| +|`LENGTH(str)`| returns the length of the string in bytes.| +|`LN(X)`| returns the natural logarithm of `X`.| +|`LOG(X), LOG(B, X)`| if called with one parameter, this function returns the natural logarithm of `X`. If called with two parameters, this function returns the logarithm of `X` to the base `B`. 
If `X` is less than or equal to 0, or if `B` is less than or equal to 1, then NULL is returned.| +|`LOG10(X)`| returns the base-10 logarithm of `X`.| +|`LOG2(X)`| returns the base-2 logarithm of `X`.| +|`LOWER(str)`| returns the string `str` with all characters in lower case.| +|`LPAD(str, len, padstr)`| returns the string `str`, left-padded with the string `padstr` to a length of `len` characters.| +|`LTRIM(str)`| returns the string `str` with leading space characters removed.| +|`MAX(expr)`| returns the maximum value of `expr` in all rows.| +|`MID(str, pos, [len])`| returns a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.| +|`MIN(expr)`| returns the minimum value of `expr` in all rows.| +|`MINUTE(date)`| returns the minutes of the given `date`.| +|`MONTH(date)`| returns the month of the given `date`.| +|`NOW()`| returns the current timestamp.| +|`NULLIF(expr1, expr2)`| returns NULL if `expr1 = expr2` is true, otherwise returns `expr1`.| +|`POW(X, Y)`| returns the value of `X` raised to the power of `Y`.| +|`REGEXP_MATCHES(text, pattern, [flags])`| returns an array with the matches of the `pattern` in the given `text`. Flags can be given to control certain behaviours of the regular expression. 
Currently, only the `i` flag is supported, to make the comparison case insensitive.| +|`REPEAT(str, count)`| returns a string consisting of the string `str` repeated `count` times.| +|`REPLACE(str,from_str,to_str)`| returns the string `str` with all occurrences of the string `from_str` replaced by the string `to_str`.| +|`REVERSE(str)`| returns the string `str` with the order of the characters reversed.| +|`ROUND(number, decimals)`| rounds the `number` to `decimals` decimal places.| +|`RPAD(str, len, padstr)`| returns the string `str`, right-padded with the string `padstr` to a length of `len` characters.| +|`RTRIM(str)`| returns the string `str` with trailing space characters removed.| +|`SECOND(date)`| returns the seconds of the given `date`.| +|`SLEEP(seconds)`| waits for the specified number of seconds (can be fractional).| +|`SOUNDEX(str)`| returns the soundex of a string.| +|`SPLIT(str,sep)`| returns the parts of the string `str` split by the separator `sep` as a JSON array of strings.| +|`SQRT(X)`| returns the square root of a nonnegative number `X`.| +|`SUBSTR(str, pos, [len])`| returns a substring from the string `str` starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.| +|`SUBSTRING(str, pos, [len])`| returns a substring from the string `str` starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.| +|`SUM(expr)`| returns the sum of `expr` in all rows.| +|`TO_BASE64(str)`| encodes the string `str` in base64 format.| +|`TRIM(str)`| returns the string `str` with all spaces removed.| +|`UPPER(str)`| returns the string `str` with all characters in upper case.| +|`WEEKDAY(date)`| returns the weekday of the given `date`.| +|`YEAR(date)`| returns the year of the given `date`.| +|`YEARWEEK(date, mode)`| returns year and week for a date. 
The year in the result may be different from the year in the date argument for the first and the last week of the year.| + + +## Configuration + +The behaviour of certain parts of go-mysql-server can be configured using either environment variables or session variables. + +Session variables are set using the following SQL queries: + +```sql +SET = +``` -- `IS_BINARY(blob)`: returns whether a BLOB is a binary file or not + +| Name | Type | Description | +|:-----|:-----|:------------| +|`INMEMORY_JOINS`|environment|If set it will perform all joins in memory. Default is off.| +|`inmemory_joins`|session|If set it will perform all joins in memory. Default is off. This has precedence over `INMEMORY_JOINS`.| +|`MAX_MEMORY`|environment|The maximum number of memory, in megabytes, that can be consumed by go-mysql-server. Any in-memory caches or computations will no longer try to use memory when the limit is reached. Note that this may cause certain queries to fail if there is not enough memory available, such as queries using DISTINCT, ORDER BY or GROUP BY with groupings.| +|`DEBUG_ANALYZER`|environment|If set, the analyzer will print debug messages. Default is off.| +|`PILOSA_INDEX_THREADS`|environment|Number of threads used in index creation. Default is the number of cores available in the machine.| +|`pilosa_index_threads`|environment|Number of threads used in index creation. Default is the number of cores available in the machine. This has precedence over `PILOSA_INDEX_THREADS`.| + ## Example -`go-mysql-server` has a sql engine and a server implementation, so to start a server you must instantiate the engine and give it your `sql.Database` implementation that will be in charge to handle all the logic about retrieving the data from your source : +`go-mysql-server` contains a SQL engine and server implementation. So, if you want to start a server, first instantiate the engine and pass your `sql.Database` implementation. 
+ +It will be in charge of handling all the logic to retrieve the data from your source. +Here you can see an example using the in-memory database implementation: ```go ... func main() { - driver := sqle.New() - driver.AddDatabase(createTestDatabase()) - - auth := mysql.NewAuthServerStatic() - auth.Entries["user"] = []*mysql.AuthServerStaticEntry{{ - Password: "pass", - }} - - config := server.Config{ - Protocol: "tcp", - Address: "localhost:3306", - Auth: auth, - } + driver := sqle.NewDefault() + driver.AddDatabase(createTestDatabase()) - s, err := server.NewDefaultServer(config, driver) - if err != nil { - panic(err) - } + config := server.Config{ + Protocol: "tcp", + Address: "localhost:3306", + Auth: auth.NewNativeSingle("user", "pass", auth.AllPermissions), + } - s.Start() + s, err := server.NewDefaultServer(config, driver) + if err != nil { + panic(err) + } + + s.Start() } -func createTestDatabase() *mem.Database { - const ( - dbName = "test" - tableName = "mytable" - ) - - db := mem.NewDatabase(dbName).(*mem.Database) - table := mem.NewTable(tableName, sql.Schema{ - {Name: "name", Type: sql.Text, Nullable: false, Source: tableName}, - {Name: "email", Type: sql.Text, Nullable: false, Source: tableName}, - {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName}, - {Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: tableName}, - }) +func createTestDatabase() *memory.Database { + const ( + dbName = "test" + tableName = "mytable" + ) + + db := memory.NewDatabase(dbName) + table := memory.NewTable(tableName, sql.Schema{ + {Name: "name", Type: sql.Text, Nullable: false, Source: tableName}, + {Name: "email", Type: sql.Text, Nullable: false, Source: tableName}, + {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName}, + {Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: tableName}, + }) + + db.AddTable(tableName, table) + ctx := sql.NewEmptyContext() + + rows := []sql.Row{ + sql.NewRow("John 
Doe", "john@doe.com", []string{"555-555-555"}, time.Now()), + sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now()), + sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now()), + sql.NewRow("Evil Bob", "evilbob@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now()), + } + + for _, row := range rows { + table.Insert(ctx, row) + } - db.AddTable(tableName, table) - table.Insert(sql.NewRow("John Doe", "john@doe.com", []string{"555-555-555"}, time.Now())) - table.Insert(sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now())) - table.Insert(sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now())) - table.Insert(sql.NewRow("Evil Bob", "evilbob@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now())) - return db + return db } ... ``` -Then, you can connect to the server with any mysql client: +Then, you can connect to the server with any MySQL client: ```bash -> mysql --host=127.0.0.1 --port=3306 -u user -ppass db -e "SELECT * FROM mytable" +> mysql --host=127.0.0.1 --port=3306 -u user -ppass test -e "SELECT * FROM mytable" +----------+-------------------+-------------------------------+---------------------+ | name | email | phone_numbers | created_at | +----------+-------------------+-------------------------------+---------------------+ @@ -137,10 +254,191 @@ SELECT email FROM mytable WHERE name = 'Evil Bob' +-------------------+ ``` +## Custom data source implementation + +To be able to create your own data source implementation you need to implement the following interfaces: + +- `sql.Database` interface. This interface will provide tables from your data source. + - If your database implementation supports adding more tables, you might want to add support for `sql.Alterable` interface + +- `sql.Table` interface. It will be in charge of transforming any kind of data into an iterator of Rows. 
Depending on how much you want to optimize the queries, you also can implement other interfaces on your tables: + - `sql.PushdownProjectionTable` interface will provide a way to get only the columns needed for the executed query. + - `sql.PushdownProjectionAndFiltersTable` interface will provide the same functionality described before, but also will push down the filters used in the executed query. It allows filtering data in advance, speeding up queries. + - `sql.Indexable` adds index capabilities to your table. By implementing this interface you can create and use indexes on this table. + - `sql.Inserter` can be implemented if your data source tables allow insertions. + +- If you need some custom tree modifications, you can also implement your own `analyzer.Rules`. + +You can see a really simple data source implementation in our `memory` package. + +## Indexes + +`go-mysql-server` exposes a series of interfaces to allow you to implement your own indexes so you can speed up your queries. + +Taking a look at the main [index interface](https://github.com/src-d/go-mysql-server/blob/master/sql/index.go#L35), you must note a couple of constraints: + +- This abstraction lets you create an index for multiple columns (one or more) or for **only one** expression (e.g. function applied on multiple columns). +- If you want to index an expression that is not a column you will only be able to index **one and only one** expression at a time. + +## Custom index driver implementation + +Index drivers provide different backends for storing and querying indexes. To implement a custom index driver you need to implement a few things: + +- `sql.IndexDriver` interface, which will be the driver itself. Note that your driver must return a unique ID in the `ID` method. This ID is unique for your driver and should not clash with any other registered driver. It's the driver's responsibility to be fault tolerant and be able to automatically detect and recover from corruption in indexes. 
+- `sql.Index` interface, returned by your driver when an index is loaded or created. + - Your `sql.Index` may optionally implement the `sql.AscendIndex` and/or `sql.DescendIndex` interfaces, if you want to support more comparison operators like `>`, `<`, `>=`, `<=` or `BETWEEN`. +- `sql.IndexLookup` interface, returned by your index in any of the implemented operations to get a subset of the indexed values. + - Your `sql.IndexLookup` may optionally implement the `sql.Mergeable` and `sql.SetOperations` interfaces if you want to support set operations to merge your index lookups. +- `sql.IndexValueIter` interface, which will be returned by your `sql.IndexLookup` and should return the values of the index. +- Don't forget to register the index driver in your `sql.Catalog` using `catalog.RegisterIndexDriver(mydriver)` to be able to use it. + +To create indexes using your custom index driver you need to use `USING driverid` on the index creation query. For example: + +```sql +CREATE INDEX foo ON table USING driverid (col1, col2) +``` + +You can see an example of a driver implementation inside the `sql/index/pilosa` package, where the pilosa driver is implemented. + +Index creation is synchronous by default, to make it asynchronous, use `WITH (async = true)`, for example: + +```sql +CREATE INDEX foo ON table USING driverid (col1, col2) WITH (async = true) +``` + +### Old `pilosalib` driver + +`pilosalib` driver was renamed to `pilosa` and now `pilosa` does not require an external pilosa server. `pilosa` is not supported on Windows. + +### Metrics + +`go-mysql-server` utilizes `github.com/go-kit/kit/metrics` module to expose metrics (counters, gauges, histograms) for certain packages (so far for `engine`, `analyzer`, `regex`, `pilosa`). If you already have metrics server (prometheus, statsd/statsite, influxdb, etc.) 
and you want to gather metrics also from `go-mysql-server` components, you will need to initialize some global variables by particular implementations to satisfy following interfaces: +```go +// Counter describes a metric that accumulates values monotonically. +type Counter interface { + With(labelValues ...string) Counter + Add(delta float64) +} + +// Gauge describes a metric that takes specific values over time. +type Gauge interface { + With(labelValues ...string) Gauge + Set(value float64) + Add(delta float64) +} + +// Histogram describes a metric that takes repeated observations of the same +// kind of thing, and produces a statistical summary of those observations, +// typically expressed as quantiles or buckets. +type Histogram interface { + With(labelValues ...string) Histogram + Observe(value float64) +} +``` + +You can use one of `go-kit` implementations or try your own. +For instance, we want to expose metrics for _prometheus_ server. So, before we start _mysql engine_, we have to set up the following variables: +```go + +import( + "github.com/go-kit/kit/metrics/prometheus" + promopts "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +//.... 
+ +// engine metrics +sqle.QueryCounter = prometheus.NewCounterFrom(promopts.CounterOpts{ + Namespace: "go_mysql_server", + Subsystem: "engine", + Name: "query_counter", + }, []string{ + "query", + }) +sqle.QueryErrorCounter = prometheus.NewCounterFrom(promopts.CounterOpts{ + Namespace: "go_mysql_server", + Subsystem: "engine", + Name: "query_error_counter", +}, []string{ + "query", + "error", +}) +sqle.QueryHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: "go_mysql_server", + Subsystem: "engine", + Name: "query_histogram", +}, []string{ + "query", + "duration", +}) + +// analyzer metrics +analyzer.ParallelQueryCounter = prometheus.NewCounterFrom(promopts.CounterOpts{ + Namespace: "go_mysql_server", + Subsystem: "analyzer", + Name: "parallel_query_counter", +}, []string{ + "parallelism", +}) + +// regex metrics +regex.CompileHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: "go_mysql_server", + Subsystem: "regex", + Name: "compile_histogram", +}, []string{ + "regex", + "duration", +}) +regex.MatchHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: "go_mysql_server", + Subsystem: "regex", + Name: "match_histogram", +}, []string{ + "string", + "duration", +}) + +// pilosa index driver metrics +pilosa.RowsGauge = prometheus.NewGaugeFrom(promopts.GaugeOpts{ + Namespace: "go_mysql_server", + Subsystem: "index", + Name: "indexed_rows_gauge", +}, []string{ + "driver", +}) +pilosa.TotalHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: "go_mysql_server", + Subsystem: "index", + Name: "index_created_total_histogram", +}, []string{ + "driver", + "duration", +}) +pilosa.MappingHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: "go_mysql_server", + Subsystem: "index", + Name: "index_created_mapping_histogram", +}, []string{ + "driver", + "duration", +}) +pilosa.BitmapHistogram = prometheus.NewHistogramFrom(promopts.HistogramOpts{ + Namespace: 
"go_mysql_server", + Subsystem: "index", + Name: "index_created_bitmap_histogram", +}, []string{ + "driver", + "duration", +}) +``` +One _important note_ - internally we set some _labels_ for metrics, that's why we have to pass those keys like "duration", "query", "driver", ... when we register metrics in `prometheus`. Other systems may have different requirements. ## Powered by go-mysql-server -* [gitquery](https://github.com/src-d/gitquery) +* [gitbase](https://github.com/src-d/gitbase) +* [dolt](https://github.com/liquidata-inc/dolt) ## License diff --git a/SUPPORTED.md b/SUPPORTED.md new file mode 100644 index 000000000..87841ac72 --- /dev/null +++ b/SUPPORTED.md @@ -0,0 +1,134 @@ +# Supported SQL Syntax + +## Comparison expressions +- != +- == +- \> +- < +- \>= +- <= +- BETWEEN +- IN +- NOT IN +- REGEXP + +## Null check expressions +- IS NOT NULL +- IS NULL + +## Grouping expressions +- AVG +- COUNT and COUNT(DISTINCT) +- MAX +- MIN +- SUM (always returns DOUBLE) + +## Standard expressions +- ALIAS (AS) +- CAST/CONVERT +- CREATE TABLE +- DESCRIBE/DESC/EXPLAIN FORMAT=TREE [query] +- DISTINCT +- FILTER (WHERE) +- GROUP BY +- INSERT INTO +- LIMIT/OFFSET +- LITERAL +- ORDER BY +- SELECT +- SHOW TABLES +- SORT +- STAR (*) +- SHOW PROCESSLIST +- SHOW TABLE STATUS +- SHOW VARIABLES +- SHOW CREATE DATABASE +- SHOW CREATE TABLE +- SHOW FIELDS FROM +- LOCK/UNLOCK +- USE +- SHOW DATABASES +- SHOW WARNINGS +- INTERVALS + +## Index expressions +- CREATE INDEX (an index can be created using either column names or a single arbitrary expression). 
+- DROP INDEX +- SHOW {INDEXES | INDEX | KEYS} {FROM | IN} [table name] + +## Join expressions +- CROSS JOIN +- INNER JOIN +- NATURAL JOIN + +## Logical expressions +- AND +- NOT +- OR + +## Arithmetic expressions +- \+ (including between dates and intervals) +- \- (including between dates and intervals) +- \* +- \\ +- << +- \>> +- & +- \| +- ^ +- div +- % + +## Functions +- ARRAY_LENGTH +- CEIL +- CEILING +- COALESCE +- CONCAT +- CONCAT_WS +- CONNECTION_ID +- DATABASE +- FLOOR +- FROM_BASE64 +- GREATEST +- IS_BINARY +- IS_BINARY +- JSON_EXTRACT +- JSON_UNQUOTE +- LEAST +- LN +- LOG10 +- LOG2 +- LOWER +- LPAD +- POW +- POWER +- ROUND +- RPAD +- SLEEP +- SOUNDEX +- SPLIT +- SQRT +- SUBSTRING +- TO_BASE64 +- UPPER + +## Time functions +- DATE +- DATE_ADD +- DATE_SUB +- DAY +- DAYOFMONTH +- DAYOFWEEK +- DAYOFYEAR +- HOUR +- MINUTE +- MONTH +- NOW +- SECOND +- WEEKDAY +- YEAR +- YEARWEEK + +## Subqueries +Supported both as a table and as expressions but they can't access the parent query scope. diff --git a/SUPPORTED_CLIENTS.md b/SUPPORTED_CLIENTS.md new file mode 100644 index 000000000..f0f4f8adb --- /dev/null +++ b/SUPPORTED_CLIENTS.md @@ -0,0 +1,274 @@ +# Supported clients + +These are the clients we actively test against to check that they are compatible with go-mysql-server. Other clients may also work, but we don't check on every build if we remain compatible with them. 
+ +- Python + - [pymysql](#pymysql) + - [mysql-connector](#python-mysql-connector) + - [sqlalchemy](#python-sqlalchemy) +- Ruby + - [ruby-mysql](#ruby-mysql) +- [PHP](#php) +- Node.js + - [mysqljs/mysql](#mysqljs) +- .NET Core + - [MysqlConnector](#mysqlconnector) +- Java/JVM + - [mariadb-java-client](#mariadb-java-client) +- Go + - [go-mysql-driver/mysql](#go-sql-drivermysql) +- C + - [mysql-connector-c](#mysql-connector-c) +- Grafana +- Tableau Desktop + +## Example client usage + +### pymysql + +```python +import pymysql.cursors + +connection = pymysql.connect(host='127.0.0.1', + user='root', + password='', + db='mydb', + cursorclass=pymysql.cursors.DictCursor) + +try: + with connection.cursor() as cursor: + sql = "SELECT * FROM mytable LIMIT 1" + cursor.execute(sql) + rows = cursor.fetchall() + + # use rows +finally: + connection.close() +``` + +### Python mysql-connector + +```python +import mysql.connector + +connection = mysql.connector.connect(host='127.0.0.1', + user='root', + passwd='', + port=3306, + database='mydb') + +try: + cursor = connection.cursor() + sql = "SELECT * FROM mytable LIMIT 1" + cursor.execute(sql) + rows = cursor.fetchall() + + # use rows +finally: + connection.close() +``` + +### Python sqlalchemy + +```python +import pandas as pd +import sqlalchemy + +engine = sqlalchemy.create_engine('mysql+pymysql://root:@127.0.0.1:3306/mydb') +with engine.connect() as conn: + repo_df = pd.read_sql_table("mytable", con=conn) + for table_name in repo_df.to_dict(): + print(table_name) +``` + +### ruby-mysql + +```ruby +require "mysql" + +conn = Mysql::new("127.0.0.1", "root", "", "mydb") +resp = conn.query "SELECT * FROM mytable LIMIT 1" + +# use resp + +conn.close() +``` + +### php + +```php +try { + $conn = new PDO("mysql:host=127.0.0.1:3306;dbname=mydb", "root", ""); + $conn->setAttribute(PDO::ATTR_ERRMODE, PDO::ERRMODE_EXCEPTION); + + $stmt = $conn->query('SELECT * FROM mytable LIMIT 1'); + $result = $stmt->fetchAll(PDO::FETCH_ASSOC); + + // use 
result +} catch (PDOException $e) { + // handle error +} +``` + +### mysqljs + +```js +import mysql from 'mysql'; + +const connection = mysql.createConnection({ + host: '127.0.0.1', + port: 3306, + user: 'root', + password: '', + database: 'mydb' +}); +connection.connect(); + +const query = 'SELECT * FROM mytable LIMIT 1'; +connection.query(query, function (error, results, _) { + if (error) throw error; + + // use results +}); + +connection.end(); +``` + +### MysqlConnector + +```csharp +using MySql.Data.MySqlClient; +using System.Threading.Tasks; + +namespace something +{ + public class Something + { + public async Task DoQuery() + { + var connectionString = "server=127.0.0.1;user id=root;password=;port=3306;database=mydb;"; + + using (var conn = new MySqlConnection(connectionString)) + { + await conn.OpenAsync(); + + var sql = "SELECT * FROM mytable LIMIT 1"; + + using (var cmd = new MySqlCommand(sql, conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + while (await reader.ReadAsync()) { + // use reader + } + } + } + } +} +``` + +### mariadb-java-client + +```java +package org.testing.mariadbjavaclient; + +import java.sql.*; + +class Main { + public static void main(String[] args) { + String dbUrl = "jdbc:mariadb://127.0.0.1:3306/mydb?user=root&password="; + String query = "SELECT * FROM mytable LIMIT 1"; + + try (Connection connection = DriverManager.getConnection(dbUrl)) { + try (PreparedStatement stmt = connection.prepareStatement(query)) { + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + // use rs + } + } + } + } catch (SQLException e) { + // handle failure + } + } +} +``` + +### go-sql-driver/mysql + +```go +package main + +import ( + "database/sql" + + _ "github.com/go-sql-driver/mysql" +) + +func main() { + db, err := sql.Open("mysql", "root:@tcp(127.0.0.1:3306)/mydb") + if err != nil { + // handle error + } + + rows, err := db.Query("SELECT * FROM mytable LIMIT 1") + if err != nil { + // handle error + } + + // use rows +} 
+``` + +### mysql-connector-c + +```c +#include +#include + +void finish_with_error(MYSQL *con) +{ + fprintf(stderr, "%s\n", mysql_error(con)); + mysql_close(con); + exit(1); +} + +int main(int argc, char **argv) +{ + MYSQL *con = NULL; + MYSQL_RES *result = NULL; + int num_fields = 0; + MYSQL_ROW row; + + printf("MySQL client version: %s\n", mysql_get_client_info()); + + con = mysql_init(NULL); + if (con == NULL) { + finish_with_error(con); + } + + if (mysql_real_connect(con, "127.0.0.1", "root", "", "mydb", 3306, NULL, 0) == NULL) { + finish_with_error(con); + } + + if (mysql_query(con, "SELECT name, email, phone_numbers FROM mytable")) { + finish_with_error(con); + } + + result = mysql_store_result(con); + if (result == NULL) { + finish_with_error(con); + } + + num_fields = mysql_num_fields(result); + while ((row = mysql_fetch_row(result))) { + for(int i = 0; i < num_fields; i++) { + printf("%s ", row[i] ? row[i] : "NULL"); + } + printf("\n"); + } + + mysql_free_result(result); + mysql_close(con); + + return 0; +} +``` diff --git a/_example/main.go b/_example/main.go index 7c51481da..55a007247 100644 --- a/_example/main.go +++ b/_example/main.go @@ -3,17 +3,17 @@ package main import ( "time" - "gopkg.in/src-d/go-mysql-server.v0" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/server" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-vitess.v0/mysql" + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/server" + "github.com/src-d/go-mysql-server/sql" ) // Example of how to implement a MySQL server based on a Engine: // // ``` -// > mysql --host=127.0.0.1 --port=5123 -u user1 -ppassword1 db -e "SELECT * FROM mytable" +// > mysql --host=127.0.0.1 --port=5123 -u user -ppass db -e "SELECT * FROM mytable" // +----------+-------------------+-------------------------------+---------------------+ // | name | email | 
phone_numbers | created_at | // +----------+-------------------+-------------------------------+---------------------+ @@ -24,21 +24,17 @@ import ( // +----------+-------------------+-------------------------------+---------------------+ // ``` func main() { - driver := sqle.New() - driver.AddDatabase(createTestDatabase()) - - auth := mysql.NewAuthServerStatic() - auth.Entries["user"] = []*mysql.AuthServerStaticEntry{{ - Password: "pass", - }} + engine := sqle.NewDefault() + engine.AddDatabase(createTestDatabase()) + engine.AddDatabase(sql.NewInformationSchemaDatabase(engine.Catalog)) config := server.Config{ Protocol: "tcp", Address: "localhost:3306", - Auth: auth, + Auth: auth.NewNativeSingle("root", "", auth.AllPermissions), } - s, err := server.NewDefaultServer(config, driver) + s, err := server.NewDefaultServer(config, engine) if err != nil { panic(err) } @@ -46,14 +42,14 @@ func main() { s.Start() } -func createTestDatabase() *mem.Database { +func createTestDatabase() *memory.Database { const ( - dbName = "test" + dbName = "mydb" tableName = "mytable" ) - db := mem.NewDatabase(dbName).(*mem.Database) - table := mem.NewTable(tableName, sql.Schema{ + db := memory.NewDatabase(dbName) + table := memory.NewTable(tableName, sql.Schema{ {Name: "name", Type: sql.Text, Nullable: false, Source: tableName}, {Name: "email", Type: sql.Text, Nullable: false, Source: tableName}, {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName}, @@ -61,9 +57,10 @@ func createTestDatabase() *mem.Database { }) db.AddTable(tableName, table) - table.Insert(sql.NewRow("John Doe", "john@doe.com", []string{"555-555-555"}, time.Now())) - table.Insert(sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now())) - table.Insert(sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now())) - table.Insert(sql.NewRow("Evil Bob", "evilbob@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now())) + ctx := sql.NewEmptyContext() + table.Insert(ctx, sql.NewRow("John 
Doe", "john@doe.com", []string{"555-555-555"}, time.Now())) + table.Insert(ctx, sql.NewRow("John Doe", "johnalt@doe.com", []string{}, time.Now())) + table.Insert(ctx, sql.NewRow("Jane Doe", "jane@doe.com", []string{}, time.Now())) + table.Insert(ctx, sql.NewRow("Evil Bob", "evilbob@gmail.com", []string{"555-666-555", "666-666-666"}, time.Now())) return db } diff --git a/_integration/c/Makefile b/_integration/c/Makefile new file mode 100644 index 000000000..24a769011 --- /dev/null +++ b/_integration/c/Makefile @@ -0,0 +1,21 @@ +# +# Darwin: brew install mysql-connector-c +# Linux: apt-get install libmysqlclient-dev +# +CFLAGS=-Wall `mysql_config --cflags --libs` +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + CFLAGS += mysqlclient +endif + +%.c: + @echo CFLAGS: $(CFLAGS) + $(CC) *.c $(CFLAGS) + +test: %.c + ./a.out + +clean: + @rm -f *.o a.out + +.PHONY: test clean \ No newline at end of file diff --git a/_integration/c/test.c b/_integration/c/test.c new file mode 100644 index 000000000..b6e26ad7c --- /dev/null +++ b/_integration/c/test.c @@ -0,0 +1,68 @@ +#include +#include + +#include +#include + +#define TEST(s1, s2) do { printf("'%s' =?= '%s'\n", s1, s2); assert(0 == strcmp(s1, s2)); } while(0) + +static void finish_with_error(MYSQL *con) +{ + fprintf(stderr, "%s\n", mysql_error(con)); + mysql_close(con); + exit(1); +} + +int main(int argc, char **argv) +{ + MYSQL *con = NULL; + MYSQL_RES *result = NULL; + MYSQL_ROW row; + + int n = 0; + const int expected_num_records = 4; + const char *expected_name[expected_num_records] = { + "John Doe\0", + "John Doe\0", + "Jane Doe\0", + "Evil Bob\0" + }; + const char *expected_email[expected_num_records] = { + "john@doe.com\0", + "johnalt@doe.com\0", + "jane@doe.com\0", + "evilbob@gmail.com\0" + }; + + printf("MySQL client version: %s\n", mysql_get_client_info()); + + con = mysql_init(NULL); + if (con == NULL) { + finish_with_error(con); + } + + if (mysql_real_connect(con, "127.0.0.1", "root", "", "mydb", 3306, 
NULL, 0) == NULL) { + finish_with_error(con); + } + + if (mysql_query(con, "SELECT name, email FROM mytable")) { + finish_with_error(con); + } + + result = mysql_store_result(con); + if (result == NULL) { + finish_with_error(con); + } + + while ((row = mysql_fetch_row(result))) { + TEST(expected_name[n], row[0]); + TEST(expected_email[n], row[1]); + ++n; + } + assert(expected_num_records == n); + + mysql_free_result(result); + mysql_close(con); + + return 0; +} \ No newline at end of file diff --git a/_integration/dotnet/.gitignore b/_integration/dotnet/.gitignore new file mode 100644 index 000000000..8d4a6c08a --- /dev/null +++ b/_integration/dotnet/.gitignore @@ -0,0 +1,2 @@ +bin +obj \ No newline at end of file diff --git a/_integration/dotnet/Makefile b/_integration/dotnet/Makefile new file mode 100644 index 000000000..14ea2e0e3 --- /dev/null +++ b/_integration/dotnet/Makefile @@ -0,0 +1,4 @@ +test: + dotnet test + +.PHONY: test \ No newline at end of file diff --git a/_integration/dotnet/MySQLTest.cs b/_integration/dotnet/MySQLTest.cs new file mode 100644 index 000000000..1a0b18b60 --- /dev/null +++ b/_integration/dotnet/MySQLTest.cs @@ -0,0 +1,42 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MySql.Data.MySqlClient; +using System.Threading.Tasks; + +namespace dotnet +{ + [TestClass] + public class MySQLTest + { + [TestMethod] + public async Task TestCanConnect() + { + var connectionString = "server=127.0.0.1;user id=root;password=;port=3306;database=mydb;"; + var expected = new string[][]{ + new string[]{"Evil Bob", "evilbob@gmail.com"}, + new string[]{"Jane Doe", "jane@doe.com"}, + new string[]{"John Doe", "john@doe.com"}, + new string[]{"John Doe", "johnalt@doe.com"}, + }; + + using (var conn = new MySqlConnection(connectionString)) + { + await conn.OpenAsync(); + + var sql = "SELECT name, email FROM mytable ORDER BY name, email"; + var i = 0; + + using (var cmd = new MySqlCommand(sql, conn)) + using (var reader = await 
cmd.ExecuteReaderAsync()) + while (await reader.ReadAsync()) { + if (i >= expected.Length) { + Assert.Fail("more rows than expected"); + } + + Assert.AreEqual(expected[i][0], reader.GetString(0)); + Assert.AreEqual(expected[i][1], reader.GetString(1)); + i++; + } + } + } + } +} diff --git a/_integration/dotnet/dotnet.csproj b/_integration/dotnet/dotnet.csproj new file mode 100644 index 000000000..648bba047 --- /dev/null +++ b/_integration/dotnet/dotnet.csproj @@ -0,0 +1,16 @@ + + + + netcoreapp2.1 + + false + + + + + + + + + + diff --git a/_integration/go/.gitignore b/_integration/go/.gitignore new file mode 100644 index 000000000..d1dc5643c --- /dev/null +++ b/_integration/go/.gitignore @@ -0,0 +1,2 @@ +!vendor +!go.sum \ No newline at end of file diff --git a/_integration/go/Makefile b/_integration/go/Makefile new file mode 100644 index 000000000..df72709e1 --- /dev/null +++ b/_integration/go/Makefile @@ -0,0 +1,4 @@ +test: + go test . -v + +.PHONY: test \ No newline at end of file diff --git a/_integration/go/go.mod b/_integration/go/go.mod new file mode 100644 index 000000000..01e7159d1 --- /dev/null +++ b/_integration/go/go.mod @@ -0,0 +1,6 @@ +module github.com/src-d/go-mysql-server/integration/go + +require ( + github.com/go-sql-driver/mysql v1.4.0 + google.golang.org/appengine v1.2.0 // indirect +) diff --git a/_integration/go/go.sum b/_integration/go/go.sum new file mode 100644 index 000000000..33d775ef1 --- /dev/null +++ b/_integration/go/go.sum @@ -0,0 +1,7 @@ +github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.2.0 
h1:S0iUepdCWODXRvtE+gcRDd15L+k+k1AiHlMiMjefH24= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/_integration/go/mysql_test.go b/_integration/go/mysql_test.go new file mode 100644 index 000000000..926477b06 --- /dev/null +++ b/_integration/go/mysql_test.go @@ -0,0 +1,148 @@ +package testmysql + +import ( + "database/sql" + "reflect" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +const connectionString = "root:@tcp(127.0.0.1:3306)/mydb" + +func TestMySQL(t *testing.T) { + db, err := sql.Open("mysql", connectionString) + if err != nil { + t.Fatalf("can't connect to mysql: %s", err) + } + + rs, err := db.Query("SELECT name, email FROM mytable ORDER BY name, email") + if err != nil { + t.Fatalf("unable to get rows: %s", err) + } + + var rows [][2]string + for rs.Next() { + var row [2]string + if err := rs.Scan(&row[0], &row[1]); err != nil { + t.Errorf("got error scanning row: %s", err) + } + + rows = append(rows, row) + } + + if err := rs.Err(); err != nil { + t.Errorf("got unexpected error: %s", err) + } + + expected := [][2]string{ + {"Evil Bob", "evilbob@gmail.com"}, + {"Jane Doe", "jane@doe.com"}, + {"John Doe", "john@doe.com"}, + {"John Doe", "johnalt@doe.com"}, + } + + if len(expected) != len(rows) { + t.Errorf("got %d rows, expecting %d", len(rows), len(expected)) + } + + for i := range rows { + if rows[i][0] != expected[i][0] || rows[i][1] != expected[i][1] { + t.Errorf( + "incorrect row %d, got: {%s, %s}, expected: {%s, %s}", + i, + rows[i][0], rows[i][1], + expected[i][0], expected[i][1], + ) + } + } +} + +func TestGrafana(t *testing.T) { + db, err := sql.Open("mysql", connectionString) + if err != nil { + t.Fatalf("can't connect to mysql: %s", err) + } + + tests := []struct { + query string + expected [][]string + }{ + { + `SELECT 1`, + [][]string{{"1"}}, + }, + { + `select @@version_comment limit 1`, + [][]string{{""}}, + }, + { + `describe table mytable`, + [][]string{ + {"name", "TEXT"}, + 
{"email", "TEXT"}, + {"phone_numbers", "JSON"}, + {"created_at", "TIMESTAMP"}, + }, + }, + { + `select count(*) from mytable where created_at ` + + `between '2000-01-01T00:00:00Z' and '2999-01-01T00:00:00Z'`, + [][]string{{"4"}}, + }, + } + + for _, c := range tests { + rs, err := db.Query(c.query) + if err != nil { + t.Fatalf("unable to execute query: %s", err) + } + + result := getResult(t, rs) + + if !reflect.DeepEqual(result, c.expected) { + t.Fatalf("rows do not match, expected: %v, got: %v", c.expected, result) + } + } +} + +func getResult(t *testing.T, rs *sql.Rows) [][]string { + t.Helper() + + columns, err := rs.Columns() + if err != nil { + t.Fatalf("unable to get columns: %s", err) + } + + var result [][]string + p := make([]interface{}, len(columns)) + + for rs.Next() { + row := make([]interface{}, len(columns)) + for i := range row { + p[i] = &row[i] + } + + err = rs.Scan(p...) + if err != nil { + t.Fatalf("could not retrieve row: %s", err) + } + + result = append(result, getStringSlice(row)) + } + + return result +} + +func getStringSlice(row []interface{}) []string { + rowStrings := make([]string, len(row)) + for i, r := range row { + if r == nil { + rowStrings[i] = "NULL" + } else { + rowStrings[i] = string(r.([]uint8)) + } + } + + return rowStrings +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/.gitignore b/_integration/go/vendor/github.com/go-sql-driver/mysql/.gitignore new file mode 100644 index 000000000..2de28da16 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +Icon? 
+ehthumbs.db +Thumbs.db +.idea diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/.travis.yml b/_integration/go/vendor/github.com/go-sql-driver/mysql/.travis.yml new file mode 100644 index 000000000..cc1268c36 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/.travis.yml @@ -0,0 +1,107 @@ +sudo: false +language: go +go: + - 1.7.x + - 1.8.x + - 1.9.x + - 1.10.x + - master + +before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + +before_script: + - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf + - sudo service mysql restart + - .travis/wait_mysql.sh + - mysql -e 'create database gotest;' + +matrix: + include: + - env: DB=MYSQL8 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:8.0 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MYSQL57 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mysql:5.7 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp 
.travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA55 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:5.5 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + + - env: DB=MARIA10_1 + sudo: required + dist: trusty + go: 1.10.x + services: + - docker + before_install: + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls + - docker pull mariadb:10.1 + - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret + mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 + - cp .travis/docker.cnf ~/.my.cnf + - .travis/wait_mysql.sh + before_script: + - export MYSQL_TEST_USER=gotest + - export MYSQL_TEST_PASS=secret + - export MYSQL_TEST_ADDR=127.0.0.1:3307 + - export MYSQL_TEST_CONCURRENT=1 + +script: + - go test -v -covermode=count -coverprofile=coverage.out + - go vet ./... 
+ - .travis/gofmt.sh +after_script: + - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/AUTHORS b/_integration/go/vendor/github.com/go-sql-driver/mysql/AUTHORS new file mode 100644 index 000000000..73ff68fbc --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -0,0 +1,89 @@ +# This is the official list of Go-MySQL-Driver authors for copyright purposes. + +# If you are submitting a patch, please add your name or the name of the +# organization which holds the copyright to this list in alphabetical order. + +# Names should be added to this file as +# Name +# The email address is not required for organizations. +# Please keep the list sorted. + + +# Individual Persons + +Aaron Hopkins +Achille Roussel +Alexey Palazhchenko +Andrew Reid +Arne Hormann +Asta Xie +Bulat Gaifullin +Carlos Nieto +Chris Moos +Craig Wilson +Daniel Montoya +Daniel Nichter +Daniël van Eeden +Dave Protasowski +DisposaBoy +Egor Smolyakov +Evan Shaw +Frederick Mayle +Gustavo Kristic +Hajime Nakagami +Hanno Braun +Henri Yandell +Hirotaka Yamamoto +ICHINOSE Shogo +INADA Naoki +Jacek Szwec +James Harr +Jeff Hodges +Jeffrey Charles +Jian Zhen +Joshua Prunier +Julien Lefevre +Julien Schmidt +Justin Li +Justin Nuß +Kamil Dziedzic +Kevin Malachowski +Kieron Woodhouse +Lennart Rudolph +Leonardo YongUk Kim +Linh Tran Tuan +Lion Yang +Luca Looz +Lucas Liu +Luke Scott +Maciej Zimnoch +Michael Woolnough +Nicola Peduzzi +Olivier Mengué +oscarzhao +Paul Bonser +Peter Schultz +Rebecca Chin +Reed Allman +Richard Wilkes +Robert Russell +Runrioter Wung +Shuode Li +Soroush Pour +Stan Putrya +Stanley Gunawan +Xiangyu Hu +Xiaobing Jiang +Xiuming Chen +Zhenye Xie + +# Organizations + +Barracuda Networks, Inc. +Counting Ltd. +Google Inc. +InfoSum Ltd. +Keybase Inc. +Percona LLC +Pivotal Inc. +Stripe Inc. 
diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md b/_integration/go/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md new file mode 100644 index 000000000..2d87d74c9 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -0,0 +1,167 @@ +## Version 1.4 (2018-06-03) + +Changes: + + - Documentation fixes (#530, #535, #567) + - Refactoring (#575, #579, #580, #581, #603, #615, #704) + - Cache column names (#444) + - Sort the DSN parameters in DSNs generated from a config (#637) + - Allow native password authentication by default (#644) + - Use the default port if it is missing in the DSN (#668) + - Removed the `strict` mode (#676) + - Do not query `max_allowed_packet` by default (#680) + - Dropped support Go 1.6 and lower (#696) + - Updated `ConvertValue()` to match the database/sql/driver implementation (#760) + - Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783) + - Improved the compatibility of the authentication system (#807) + +New Features: + + - Multi-Results support (#537) + - `rejectReadOnly` DSN option (#604) + - `context.Context` support (#608, #612, #627, #761) + - Transaction isolation level support (#619, #744) + - Read-Only transactions support (#618, #634) + - `NewConfig` function which initializes a config with default values (#679) + - Implemented the `ColumnType` interfaces (#667, #724) + - Support for custom string types in `ConvertValue` (#623) + - Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710) + - `caching_sha2_password` authentication plugin support (#794, #800, #801, #802) + - Implemented `driver.SessionResetter` (#779) + - `sha256_password` authentication plugin support (#808) + +Bugfixes: + + - Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718) + - Fixed LOAD LOCAL DATA INFILE for empty files (#590) + - Removed columns definition cache since it sometimes cached invalid data 
(#592) + - Don't mutate registered TLS configs (#600) + - Make RegisterTLSConfig concurrency-safe (#613) + - Handle missing auth data in the handshake packet correctly (#646) + - Do not retry queries when data was written to avoid data corruption (#302, #736) + - Cache the connection pointer for error handling before invalidating it (#678) + - Fixed imports for appengine/cloudsql (#700) + - Fix sending STMT_LONG_DATA for 0 byte data (#734) + - Set correct capacity for []bytes read from length-encoded strings (#766) + - Make RegisterDial concurrency-safe (#773) + + +## Version 1.3 (2016-12-01) + +Changes: + + - Go 1.1 is no longer supported + - Use decimals fields in MySQL to format time types (#249) + - Buffer optimizations (#269) + - TLS ServerName defaults to the host (#283) + - Refactoring (#400, #410, #437) + - Adjusted documentation for second generation CloudSQL (#485) + - Documented DSN system var quoting rules (#502) + - Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512) + +New Features: + + - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) + - Support for returning table alias on Columns() (#289, #359, #382) + - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) + - Support for uint64 parameters with high bit set (#332, #345) + - Cleartext authentication plugin support (#327) + - Exported ParseDSN function and the Config struct (#403, #419, #429) + - Read / Write timeouts (#401) + - Support for JSON field type (#414) + - Support for multi-statements and multi-results (#411, #431) + - DSN parameter to set the driver-side max_allowed_packet value manually (#489) + - Native password authentication plugin support (#494, #524) + +Bugfixes: + + - Fixed handling of queries without columns and rows (#255) + - Fixed a panic when SetKeepAlive() failed (#298) + - Handle ERR packets while reading rows (#321) + - Fixed reading NULL length-encoded integers in MySQL 5.6+ 
(#349) + - Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356) + - Actually zero out bytes in handshake response (#378) + - Fixed race condition in registering LOAD DATA INFILE handler (#383) + - Fixed tests with MySQL 5.7.9+ (#380) + - QueryUnescape TLS config names (#397) + - Fixed "broken pipe" error by writing to closed socket (#390) + - Fixed LOAD LOCAL DATA INFILE buffering (#424) + - Fixed parsing of floats into float64 when placeholders are used (#434) + - Fixed DSN tests with Go 1.7+ (#459) + - Handle ERR packets while waiting for EOF (#473) + - Invalidate connection on error while discarding additional results (#513) + - Allow terminating packets of length 0 (#516) + + +## Version 1.2 (2014-06-03) + +Changes: + + - We switched back to a "rolling release". `go get` installs the current master branch again + - Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver + - Exported errors to allow easy checking from application code + - Enabled TCP Keepalives on TCP connections + - Optimized INFILE handling (better buffer size calculation, lazy init, ...) + - The DSN parser also checks for a missing separating slash + - Faster binary date / datetime to string formatting + - Also exported the MySQLWarning type + - mysqlConn.Close returns the first error encountered instead of ignoring all errors + - writePacket() automatically writes the packet size to the header + - readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets + +New Features: + + - `RegisterDial` allows the usage of a custom dial function to establish the network connection + - Setting the connection collation is possible with the `collation` DSN parameter. 
This parameter should be preferred over the `charset` parameter + - Logging of critical errors is configurable with `SetLogger` + - Google CloudSQL support + +Bugfixes: + + - Allow more than 32 parameters in prepared statements + - Various old_password fixes + - Fixed TestConcurrent test to pass Go's race detection + - Fixed appendLengthEncodedInteger for large numbers + - Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo) + + +## Version 1.1 (2013-11-02) + +Changes: + + - Go-MySQL-Driver now requires Go 1.1 + - Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore + - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors + - `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")` + - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'. + - Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries + - Optimized the buffer for reading + - stmt.Query now caches column metadata + - New Logo + - Changed the copyright header to include all contributors + - Improved the LOAD INFILE documentation + - The driver struct is now exported to make the driver directly accessible + - Refactored the driver tests + - Added more benchmarks and moved all to a separate file + - Other small refactoring + +New Features: + + - Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure + - Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs + - Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. 
Custom TLS configs can be registered and used + +Bugfixes: + + - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification + - Convert to DB timezone when inserting `time.Time` + - Splitted packets (more than 16MB) are now merged correctly + - Fixed false positive `io.EOF` errors when the data was fully read + - Avoid panics on reuse of closed connections + - Fixed empty string producing false nil values + - Fixed sign byte for positive TIME fields + + +## Version 1.0 (2013-05-14) + +Initial Release diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md b/_integration/go/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md new file mode 100644 index 000000000..8fe16bcb4 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# Contributing Guidelines + +## Reporting Issues + +Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed). + +## Contributing Code + +By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file. +Don't forget to add yourself to the AUTHORS file. + +### Code Review + +Everyone is invited to review and comment on pull requests. +If it looks fine to you, comment with "LGTM" (Looks good to me). + +If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes. + +Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM". + +## Development Ideas + +If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page. 
diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/LICENSE b/_integration/go/vendor/github.com/go-sql-driver/mysql/LICENSE new file mode 100644 index 000000000..14e2f777f --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. 
"Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. 
Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. 
Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. 
Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. 
However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. 
Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. 
* +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. 
Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. 
+ +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/README.md b/_integration/go/vendor/github.com/go-sql-driver/mysql/README.md new file mode 100644 index 000000000..2e9b07eeb --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/README.md @@ -0,0 +1,490 @@ +# Go-MySQL-Driver + +A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package + +![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") + +--------------------------------------- + * [Features](#features) + * [Requirements](#requirements) + * [Installation](#installation) + * [Usage](#usage) + * [DSN (Data Source Name)](#dsn-data-source-name) + * [Password](#password) + * [Protocol](#protocol) + * [Address](#address) + * [Parameters](#parameters) + * [Examples](#examples) + * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) + * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) + * [time.Time support](#timetime-support) + * [Unicode support](#unicode-support) + * [Testing / Development](#testing--development) + * [License](#license) + +--------------------------------------- + +## Features + * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") + * Native Go implementation. 
No C-bindings, just pure Go + * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) + * Automatic handling of broken connections + * Automatic Connection Pooling *(by database/sql package)* + * Supports queries larger than 16MB + * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. + * Intelligent `LONG DATA` handling in prepared statements + * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support + * Optional `time.Time` parsing + * Optional placeholder interpolation + +## Requirements + * Go 1.7 or higher. We aim to support the 3 latest versions of Go. + * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + +--------------------------------------- + +## Installation +Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: +```bash +$ go get -u github.com/go-sql-driver/mysql +``` +Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. + +## Usage +_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. + +Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: +```go +import "database/sql" +import _ "github.com/go-sql-driver/mysql" + +db, err := sql.Open("mysql", "user:password@/dbname") +``` + +[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). + + +### DSN (Data Source Name) + +The Data Source Name has a common format, like e.g. 
[PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets): +``` +[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] +``` + +A DSN in its fullest form: +``` +username:password@protocol(address)/dbname?param=value +``` + +Except for the databasename, all values are optional. So the minimal DSN is: +``` +/dbname +``` + +If you do not want to preselect a database, leave `dbname` empty: +``` +/ +``` +This has the same effect as an empty DSN string: +``` + +``` + +Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. + +#### Password +Passwords can consist of any character. Escaping is **not** necessary. + +#### Protocol +See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. +In general you should use an Unix domain socket if available and TCP otherwise for best performance. + +#### Address +For TCP and UDP networks, addresses have the form `host[:port]`. +If `port` is omitted, the default port will be used. +If `host` is a literal IPv6 address, it must be enclosed in square brackets. +The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form. + +For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. + +#### Parameters +*Parameters are case-sensitive!* + +Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`. 
+ +##### `allowAllFiles` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. +[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) + +##### `allowCleartextPasswords` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + +##### `allowNativePasswords` + +``` +Type: bool +Valid Values: true, false +Default: true +``` +`allowNativePasswords=false` disallows the usage of MySQL native password method. + +##### `allowOldPasswords` + +``` +Type: bool +Valid Values: true, false +Default: false +``` +`allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). + +##### `charset` + +``` +Type: string +Valid Values: +Default: none +``` + +Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). 
+ +Usage of the `charset` parameter is discouraged because it issues additional queries to the server. +Unless you need the fallback behavior, please use `collation` instead. + +##### `collation` + +``` +Type: string +Valid Values: +Default: utf8_general_ci +``` + +Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. + +A list of valid charsets for a server is retrievable with `SHOW COLLATION`. + +##### `clientFoundRows` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. + +##### `columnsWithAlias` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +When `columnsWithAlias` is true, calls to `sql.Rows.Columns()` will return the table alias and the column name separated by a dot. For example: + +``` +SELECT u.id FROM users as u +``` + +will return `u.id` instead of just `id` if `columnsWithAlias=true`. + +##### `interpolateParams` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. + +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* + +##### `loc` + +``` +Type: string +Valid Values: +Default: UTC +``` + +Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. 
See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. + +Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. + +Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. + +##### `maxAllowedPacket` +``` +Type: decimal number +Default: 4194304 +``` + +Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. + +##### `multiStatements` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. + +When `multiStatements` is used, `?` parameters must only be used in the first statement. + +##### `parseTime` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` +The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`. + + +##### `readTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + +##### `rejectReadOnly` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + + +`rejectReadOnly=true` causes the driver to reject read-only connections. 
This +is for a possible race condition during an automatic failover, where the mysql +client gets connected to a read-only replica after the failover. + +Note that this should be a fairly rare case, as an automatic failover normally +happens when the primary is down, and the race condition shouldn't happen +unless it comes back up online as soon as the failover is kicked off. On the +other hand, when this happens, a MySQL application can get stuck on a +read-only connection until restarted. It is however fairly easy to reproduce, +for example, using a manual failover on AWS Aurora's MySQL-compatible cluster. + +If you are not relying on read-only transactions to reject writes that aren't +supposed to happen, setting this on some MySQL providers (such as AWS Aurora) +is safer for failovers. + +Note that ERROR 1290 can be returned for a `read-only` server and this option will +cause a retry for that error. However the same error number is used for some +other cases. You should ensure your application will never cause an ERROR 1290 +except for `read-only` mode when enabling this option. + + +##### `serverPubKey` + +``` +Type: string +Valid Values: +Default: none +``` + +Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN. +Public keys are used to transmit encrypted data, e.g. for authentication. +If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. + + +##### `timeout` + +``` +Type: duration +Default: OS default +``` + +Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. 
+ + +##### `tls` + +``` +Type: bool / string +Valid Values: true, false, skip-verify, +Default: false +``` + +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). + + +##### `writeTimeout` + +``` +Type: duration +Default: 0 +``` + +I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + + +##### System Variables + +Any other parameters are interpreted as system variables: + * `=`: `SET =` + * `=`: `SET =` + * `=%27%27`: `SET =''` + +Rules: +* The values for string variables must be quoted with `'`. +* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! + (which implies values of string variables must be wrapped with `%27`). + +Examples: + * `autocommit=1`: `SET autocommit=1` + * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` + * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` + + +#### Examples +``` +user@unix(/path/to/socket)/dbname +``` + +``` +root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local +``` + +``` +user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true +``` + +Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): +``` +user:password@/dbname?sql_mode=TRADITIONAL +``` + +TCP via IPv6: +``` +user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci +``` + +TCP on a remote host, e.g. 
Amazon RDS: +``` +id:password@tcp(your-amazonaws-uri.com:3306)/dbname +``` + +Google Cloud SQL on App Engine (First Generation MySQL Server): +``` +user@cloudsql(project-id:instance-name)/dbname +``` + +Google Cloud SQL on App Engine (Second Generation MySQL Server): +``` +user@cloudsql(project-id:regionname:instance-name)/dbname +``` + +TCP using default port (3306) on localhost: +``` +user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped +``` + +Use the default protocol (tcp) and host (localhost:3306): +``` +user:password@/dbname +``` + +No Database preselected: +``` +user:password@/ +``` + + +### Connection pool and timeouts +The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. + +## `ColumnType` Support +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. + +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. + + +### `LOAD DATA LOCAL INFILE` support +For this feature you need direct access to the package. 
Therefore you must change the import path (no `_`): +```go +import "github.com/go-sql-driver/mysql" +``` + +Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). + +To use an `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns an `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. + +See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. + + +### `time.Time` support +The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. + +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. + +**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). + +Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. 
+ + +### Unicode support +Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. + +Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. + +Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. + +See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. + +## Testing / Development +To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. + +Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. +If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). + +See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details. + +--------------------------------------- + +## License +Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) + +Mozilla summarizes the license scope as follows: +> MPL: The copyleft applies to any files containing MPLed code. + + +That means: + * You can **use** the **unchanged** source code both in private and commercially. + * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). + * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. 
+ +Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. + +You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). + +![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") + diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/appengine.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/appengine.go new file mode 100644 index 000000000..be41f2ee6 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -0,0 +1,19 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build appengine + +package mysql + +import ( + "google.golang.org/appengine/cloudsql" +) + +func init() { + RegisterDial("cloudsql", cloudsql.Dial) +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/auth.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/auth.go new file mode 100644 index 000000000..0b59f52ee --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/auth.go @@ -0,0 +1,420 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. 
+func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i 
:= range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc, false) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { + switch plugin 
{ + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, (authResp == nil), nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, false, ErrOldPassword + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) + return authResp, true, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, false, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return []byte(mc.cfg.Passwd), true, nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, false, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. 
+ authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, false, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return nil, true, nil + } + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + return []byte(mc.cfg.Passwd), true, nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, false, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, false, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, false, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. 
+ if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data := mc.buf.takeSmallBuffer(4 + 1) + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + data, err := mc.readPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + pub, 
err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/buffer.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/buffer.go new file mode 100644 index 000000000..eb4748bf4 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -0,0 +1,147 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "io" + "net" + "time" +) + +const defaultBufSize = 4096 + +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish +// Also highly optimized for this particular use case. +type buffer struct { + buf []byte + nc net.Conn + idx int + length int + timeout time.Duration +} + +func newBuffer(nc net.Conn) buffer { + var b [defaultBufSize]byte + return buffer{ + buf: b[:], + nc: nc, + } +} + +// fill reads into the buffer until at least _need_ bytes are in it +func (b *buffer) fill(need int) error { + n := b.length + + // move existing data to the beginning + if n > 0 && b.idx > 0 { + copy(b.buf[0:n], b.buf[b.idx:]) + } + + // grow buffer if necessary + // TODO: let the buffer shrink again at some point + // Maybe keep the org buf slice and swap back? 
+ if need > len(b.buf) { + // Round up to the next multiple of the default size + newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + copy(newBuf, b.buf) + b.buf = newBuf + } + + b.idx = 0 + + for { + if b.timeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + return err + } + } + + nn, err := b.nc.Read(b.buf[n:]) + n += nn + + switch err { + case nil: + if n < need { + continue + } + b.length = n + return nil + + case io.EOF: + if n >= need { + b.length = n + return nil + } + return io.ErrUnexpectedEOF + + default: + return err + } + } +} + +// returns next N bytes from buffer. +// The returned slice is only guaranteed to be valid until the next read +func (b *buffer) readNext(need int) ([]byte, error) { + if b.length < need { + // refill + if err := b.fill(need); err != nil { + return nil, err + } + } + + offset := b.idx + b.idx += need + b.length -= need + return b.buf[offset:b.idx], nil +} + +// returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeBuffer(length int) []byte { + if b.length > 0 { + return nil + } + + // test (cheap) general case first + if length <= defaultBufSize || length <= cap(b.buf) { + return b.buf[:length] + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf + } + return make([]byte, length) +} + +// shortcut which can be used if the requested buffer is guaranteed to be +// smaller than defaultBufSize +// Only one buffer (total) can be used at a time. +func (b *buffer) takeSmallBuffer(length int) []byte { + if b.length > 0 { + return nil + } + return b.buf[:length] +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// Only one buffer (total) can be used at a time. 
+func (b *buffer) takeCompleteBuffer() []byte { + if b.length > 0 { + return nil + } + return b.buf +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/collations.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/collations.go new file mode 100644 index 000000000..136c9e4d1 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/collations.go @@ -0,0 +1,251 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +const defaultCollation = "utf8_general_ci" +const binaryCollation = "binary" + +// A list of available collations mapped to the internal ID. +// To update this map use the following MySQL query: +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS +var collations = map[string]byte{ + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + "ucs2_general_ci": 35, + "cp866_general_ci": 36, + 
"keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + "utf16_general_ci": 54, + "utf16_bin": 55, + "utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + "utf32_general_ci": 60, + "utf32_bin": 61, + "utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + "ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + "utf16_unicode_ci": 101, + "utf16_icelandic_ci": 102, + "utf16_latvian_ci": 103, + "utf16_romanian_ci": 104, + "utf16_slovenian_ci": 105, + "utf16_polish_ci": 106, + "utf16_estonian_ci": 107, + "utf16_spanish_ci": 108, + "utf16_swedish_ci": 109, + "utf16_turkish_ci": 110, + "utf16_czech_ci": 111, + "utf16_danish_ci": 112, + "utf16_lithuanian_ci": 113, + "utf16_slovak_ci": 114, + "utf16_spanish2_ci": 115, + "utf16_roman_ci": 116, + "utf16_persian_ci": 117, + "utf16_esperanto_ci": 118, + "utf16_hungarian_ci": 119, + "utf16_sinhala_ci": 120, + "utf16_german2_ci": 121, + "utf16_croatian_ci": 122, + "utf16_unicode_520_ci": 
123, + "utf16_vietnamese_ci": 124, + "ucs2_unicode_ci": 128, + "ucs2_icelandic_ci": 129, + "ucs2_latvian_ci": 130, + "ucs2_romanian_ci": 131, + "ucs2_slovenian_ci": 132, + "ucs2_polish_ci": 133, + "ucs2_estonian_ci": 134, + "ucs2_spanish_ci": 135, + "ucs2_swedish_ci": 136, + "ucs2_turkish_ci": 137, + "ucs2_czech_ci": 138, + "ucs2_danish_ci": 139, + "ucs2_lithuanian_ci": 140, + "ucs2_slovak_ci": 141, + "ucs2_spanish2_ci": 142, + "ucs2_roman_ci": 143, + "ucs2_persian_ci": 144, + "ucs2_esperanto_ci": 145, + "ucs2_hungarian_ci": 146, + "ucs2_sinhala_ci": 147, + "ucs2_german2_ci": 148, + "ucs2_croatian_ci": 149, + "ucs2_unicode_520_ci": 150, + "ucs2_vietnamese_ci": 151, + "ucs2_general_mysql500_ci": 159, + "utf32_unicode_ci": 160, + "utf32_icelandic_ci": 161, + "utf32_latvian_ci": 162, + "utf32_romanian_ci": 163, + "utf32_slovenian_ci": 164, + "utf32_polish_ci": 165, + "utf32_estonian_ci": 166, + "utf32_spanish_ci": 167, + "utf32_swedish_ci": 168, + "utf32_turkish_ci": 169, + "utf32_czech_ci": 170, + "utf32_danish_ci": 171, + "utf32_lithuanian_ci": 172, + "utf32_slovak_ci": 173, + "utf32_spanish2_ci": 174, + "utf32_roman_ci": 175, + "utf32_persian_ci": 176, + "utf32_esperanto_ci": 177, + "utf32_hungarian_ci": 178, + "utf32_sinhala_ci": 179, + "utf32_german2_ci": 180, + "utf32_croatian_ci": 181, + "utf32_unicode_520_ci": 182, + "utf32_vietnamese_ci": 183, + "utf8_unicode_ci": 192, + "utf8_icelandic_ci": 193, + "utf8_latvian_ci": 194, + "utf8_romanian_ci": 195, + "utf8_slovenian_ci": 196, + "utf8_polish_ci": 197, + "utf8_estonian_ci": 198, + "utf8_spanish_ci": 199, + "utf8_swedish_ci": 200, + "utf8_turkish_ci": 201, + "utf8_czech_ci": 202, + "utf8_danish_ci": 203, + "utf8_lithuanian_ci": 204, + "utf8_slovak_ci": 205, + "utf8_spanish2_ci": 206, + "utf8_roman_ci": 207, + "utf8_persian_ci": 208, + "utf8_esperanto_ci": 209, + "utf8_hungarian_ci": 210, + "utf8_sinhala_ci": 211, + "utf8_german2_ci": 212, + "utf8_croatian_ci": 213, + "utf8_unicode_520_ci": 214, + 
"utf8_vietnamese_ci": 215, + "utf8_general_mysql500_ci": 223, + "utf8mb4_unicode_ci": 224, + "utf8mb4_icelandic_ci": 225, + "utf8mb4_latvian_ci": 226, + "utf8mb4_romanian_ci": 227, + "utf8mb4_slovenian_ci": 228, + "utf8mb4_polish_ci": 229, + "utf8mb4_estonian_ci": 230, + "utf8mb4_spanish_ci": 231, + "utf8mb4_swedish_ci": 232, + "utf8mb4_turkish_ci": 233, + "utf8mb4_czech_ci": 234, + "utf8mb4_danish_ci": 235, + "utf8mb4_lithuanian_ci": 236, + "utf8mb4_slovak_ci": 237, + "utf8mb4_spanish2_ci": 238, + "utf8mb4_roman_ci": 239, + "utf8mb4_persian_ci": 240, + "utf8mb4_esperanto_ci": 241, + "utf8mb4_hungarian_ci": 242, + "utf8mb4_sinhala_ci": 243, + "utf8mb4_german2_ci": 244, + "utf8mb4_croatian_ci": 245, + "utf8mb4_unicode_520_ci": 246, + "utf8mb4_vietnamese_ci": 247, +} + +// A blacklist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[string]bool{ + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/connection.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/connection.go new file mode 100644 index 000000000..e57061412 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/connection.go @@ -0,0 +1,461 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +import ( + "database/sql/driver" + "io" + "net" + "strconv" + "strings" + "time" +) + +// a copy of context.Context for Go 1.7 and earlier +type mysqlContext interface { + Done() <-chan struct{} + Err() error + + // defined in context.Context, but not used in this driver: + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} +} + +type mysqlConn struct { + buf buffer + netConn net.Conn + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + parseTime bool + + // for context support (Go 1.8+) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed +} + +// Handles parameters set in DSN after the connection is established +func (mc *mysqlConn) handleParams() (err error) { + for param, val := range mc.cfg.Params { + switch param { + // Charset + case "charset": + charsets := strings.Split(val, ",") + for i := range charsets { + // ignore errors here - a charset may not exist + err = mc.exec("SET NAMES " + charsets[i]) + if err == nil { + break + } + } + if err != nil { + return + } + + // System Vars + default: + err = mc.exec("SET " + param + "=" + val + "") + if err != nil { + return + } + } + } + + return +} + +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + +func (mc *mysqlConn) Begin() (driver.Tx, error) { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err 
:= mc.exec(q) + if err == nil { + return &mysqlTx{mc}, err + } + return nil, mc.markBadConn(err) +} + +func (mc *mysqlConn) Close() (err error) { + // Makes Close idempotent + if !mc.closed.IsSet() { + err = mc.writeCommandPacket(comQuit) + } + + mc.cleanup() + + return +} + +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. This function +// is called before auth or on auth failure because MySQL will have already +// closed the network connection. +func (mc *mysqlConn) cleanup() { + if !mc.closed.TrySet(true) { + return + } + + // Makes cleanup idempotent + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil +} + +func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := mc.writeCommandPacketStr(comStmtPrepare, query) + if err != nil { + return nil, mc.markBadConn(err) + } + + stmt := &mysqlStmt{ + mc: mc, + } + + // Read Result + columnCount, err := stmt.readPrepareResultPacket() + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if columnCount > 0 { + err = mc.readUntilEOF() + } + } + + return stmt, err +} + +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { + // Number of ? should be same to len(args) + if strings.Count(query, "?") != len(args) { + return "", driver.ErrSkip + } + + buf := mc.buf.takeCompleteBuffer() + if buf == nil { + // can not take the buffer. 
Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return "", ErrInvalidConn + } + buf = buf[:0] + argPos := 0 + + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + continue + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + v := v.In(mc.cfg.Loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond + year := v.Year() + year100 := year / 100 + year1 := year % 100 + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf = append(buf, []byte{ + '\'', + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], + }...) + + if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 + buf = append(buf, []byte{ + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], + }...) + } + buf = append(buf, '\'') + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) 
+ if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxAllowedPacket { + return "", driver.ErrSkip + } + } + if argPos != len(args) { + return "", driver.ErrSkip + } + return string(buf), nil +} + +func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + mc.affectedRows = 0 + mc.insertId = 0 + + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, mc.markBadConn(err) +} + +// Internal function to execute commands +func (mc *mysqlConn) exec(query string) error { + // Send command + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } + + return mc.discardResults() +} + +func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) 
(*textRows, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + // Columns + rows.rs.columns, err = mc.readColumns(resLen) + return rows, err + } + } + return nil, mc.markBadConn(err) +} + +// Gets the value of the given MySQL System Variable +// The returned byte slice is only valid until the next read +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { + // Send command + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() + } + } + return nil, err +} + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. 
+func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/connection_go18.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/connection_go18.go new file mode 100644 index 000000000..62796bfce --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/connection_go18.go @@ -0,0 +1,208 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, 
err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + if ctx.Done() == nil { + return nil + } + + mc.watching = true + select { + default: + case 
<-ctx.Done(): + return ctx.Err() + } + if mc.watcher == nil { + return nil + } + + mc.watcher <- ctx + + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan mysqlContext, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx mysqlContext + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + return nil +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/const.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/const.go new file mode 100644 index 000000000..b1e6b85ef --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/const.go @@ -0,0 +1,174 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +const ( + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 + maxPacketSize = 1<<24 - 1 + timeFormat = "2006-01-02 15:04:05.999999" +) + +// MySQL constants documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +const ( + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff +) + +// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags +type clientFlag uint32 + +const ( + clientLongPassword clientFlag = 1 << iota + clientFoundRows + clientLongFlag + clientConnectWithDB + clientNoSchema + clientCompress + clientODBC + clientLocalFiles + clientIgnoreSpace + clientProtocol41 + clientInteractive + clientSSL + clientIgnoreSIGPIPE + clientTransactions + clientReserved + clientSecureConn + clientMultiStatements + clientMultiResults + clientPSMultiResults + clientPluginAuth + clientConnectAttrs + clientPluginAuthLenEncClientData + clientCanHandleExpiredPasswords + clientSessionTrack + clientDeprecateEOF +) + +const ( + comQuit byte = iota + 1 + comInitDB + comQuery + comFieldList + comCreateDB + comDropDB + comRefresh + comShutdown + comStatistics + comProcessInfo + comConnect + comProcessKill + comDebug + comPing + comTime + comDelayedInsert + comChangeUser + comBinlogDump + comTableDump + comConnectOut + comRegisterSlave + comStmtPrepare + comStmtExecute + comStmtSendLongData + comStmtClose + comStmtReset + comSetOption + comStmtFetch +) + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + +const ( + fieldTypeDecimal fieldType = iota + fieldTypeTiny + fieldTypeShort + fieldTypeLong + fieldTypeFloat + fieldTypeDouble + fieldTypeNULL + fieldTypeTimestamp + fieldTypeLongLong + fieldTypeInt24 + fieldTypeDate + fieldTypeTime + fieldTypeDateTime + fieldTypeYear + fieldTypeNewDate + fieldTypeVarChar + 
fieldTypeBit +) +const ( + fieldTypeJSON fieldType = iota + 0xf5 + fieldTypeNewDecimal + fieldTypeEnum + fieldTypeSet + fieldTypeTinyBLOB + fieldTypeMediumBLOB + fieldTypeLongBLOB + fieldTypeBLOB + fieldTypeVarString + fieldTypeString + fieldTypeGeometry +) + +type fieldFlag uint16 + +const ( + flagNotNULL fieldFlag = 1 << iota + flagPriKey + flagUniqueKey + flagMultipleKey + flagBLOB + flagUnsigned + flagZeroFill + flagBinary + flagEnum + flagAutoIncrement + flagTimestamp + flagSet + flagUnknown1 + flagUnknown2 + flagUnknown3 + flagUnknown4 +) + +// http://dev.mysql.com/doc/internals/en/status-flags.html +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusReserved // Not in documentation + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/driver.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/driver.go new file mode 100644 index 000000000..1a75a16ec --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/driver.go @@ -0,0 +1,169 @@ +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mysql provides a MySQL driver for Go's database/sql package. 
+// +// The driver should be used via the database/sql package: +// +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" +// +// db, err := sql.Open("mysql", "user:password@/dbname") +// +// See https://github.com/go-sql-driver/mysql#usage for details +package mysql + +import ( + "database/sql" + "database/sql/driver" + "net" + "sync" +) + +// watcher interface is used for context support (From Go 1.8) +type watcher interface { + startWatcher() +} + +// MySQLDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type MySQLDriver struct{} + +// DialFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDial +type DialFunc func(addr string) (net.Conn, error) + +var ( + dialsLock sync.RWMutex + dials map[string]DialFunc +) + +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +func RegisterDial(net string, dial DialFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() + if dials == nil { + dials = make(map[string]DialFunc) + } + dials[net] = dial +} + +// Open new Connection. 
+// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formated +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + } + mc.cfg, err = ParseDSN(dsn) + if err != nil { + return nil, err + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + } + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + if s, ok := interface{}(mc).(watcher); ok { + s.startWatcher() + } + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + // Send Client Authentication Packet + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, addNUL, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + 
if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +func init() { + sql.Register("mysql", &MySQLDriver{}) +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/dsn.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/dsn.go new file mode 100644 index 000000000..be014babe --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -0,0 +1,611 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +var ( + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") +) + +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowNativePasswords bool // Allows the native password authentication method + AllowOldPasswords bool // Allows the old insecure password method + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders 
into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + } +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + + return nil +} + +// FormatDSN formats the given Config into a DSN string which can be passed to +// the driver. 
+func (cfg *Config) FormatDSN() string { + var buf bytes.Buffer + + // [username[:password]@] + if len(cfg.User) > 0 { + buf.WriteString(cfg.User) + if len(cfg.Passwd) > 0 { + buf.WriteByte(':') + buf.WriteString(cfg.Passwd) + } + buf.WriteByte('@') + } + + // [protocol[(address)]] + if len(cfg.Net) > 0 { + buf.WriteString(cfg.Net) + if len(cfg.Addr) > 0 { + buf.WriteByte('(') + buf.WriteString(cfg.Addr) + buf.WriteByte(')') + } + } + + // /dbname + buf.WriteByte('/') + buf.WriteString(cfg.DBName) + + // [?param1=value1&...¶mN=valueN] + hasParam := false + + if cfg.AllowAllFiles { + hasParam = true + buf.WriteString("?allowAllFiles=true") + } + + if cfg.AllowCleartextPasswords { + if hasParam { + buf.WriteString("&allowCleartextPasswords=true") + } else { + hasParam = true + buf.WriteString("?allowCleartextPasswords=true") + } + } + + if !cfg.AllowNativePasswords { + if hasParam { + buf.WriteString("&allowNativePasswords=false") + } else { + hasParam = true + buf.WriteString("?allowNativePasswords=false") + } + } + + if cfg.AllowOldPasswords { + if hasParam { + buf.WriteString("&allowOldPasswords=true") + } else { + hasParam = true + buf.WriteString("?allowOldPasswords=true") + } + } + + if cfg.ClientFoundRows { + if hasParam { + buf.WriteString("&clientFoundRows=true") + } else { + hasParam = true + buf.WriteString("?clientFoundRows=true") + } + } + + if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + if hasParam { + buf.WriteString("&collation=") + } else { + hasParam = true + buf.WriteString("?collation=") + } + buf.WriteString(col) + } + + if cfg.ColumnsWithAlias { + if hasParam { + buf.WriteString("&columnsWithAlias=true") + } else { + hasParam = true + buf.WriteString("?columnsWithAlias=true") + } + } + + if cfg.InterpolateParams { + if hasParam { + buf.WriteString("&interpolateParams=true") + } else { + hasParam = true + buf.WriteString("?interpolateParams=true") + } + } + + if cfg.Loc != time.UTC && cfg.Loc != nil { + if hasParam { + 
buf.WriteString("&loc=") + } else { + hasParam = true + buf.WriteString("?loc=") + } + buf.WriteString(url.QueryEscape(cfg.Loc.String())) + } + + if cfg.MultiStatements { + if hasParam { + buf.WriteString("&multiStatements=true") + } else { + hasParam = true + buf.WriteString("?multiStatements=true") + } + } + + if cfg.ParseTime { + if hasParam { + buf.WriteString("&parseTime=true") + } else { + hasParam = true + buf.WriteString("?parseTime=true") + } + } + + if cfg.ReadTimeout > 0 { + if hasParam { + buf.WriteString("&readTimeout=") + } else { + hasParam = true + buf.WriteString("?readTimeout=") + } + buf.WriteString(cfg.ReadTimeout.String()) + } + + if cfg.RejectReadOnly { + if hasParam { + buf.WriteString("&rejectReadOnly=true") + } else { + hasParam = true + buf.WriteString("?rejectReadOnly=true") + } + } + + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + + if cfg.Timeout > 0 { + if hasParam { + buf.WriteString("&timeout=") + } else { + hasParam = true + buf.WriteString("?timeout=") + } + buf.WriteString(cfg.Timeout.String()) + } + + if len(cfg.TLSConfig) > 0 { + if hasParam { + buf.WriteString("&tls=") + } else { + hasParam = true + buf.WriteString("?tls=") + } + buf.WriteString(url.QueryEscape(cfg.TLSConfig)) + } + + if cfg.WriteTimeout > 0 { + if hasParam { + buf.WriteString("&writeTimeout=") + } else { + hasParam = true + buf.WriteString("?writeTimeout=") + } + buf.WriteString(cfg.WriteTimeout.String()) + } + + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { + if hasParam { + buf.WriteString("&maxAllowedPacket=") + } else { + hasParam = true + buf.WriteString("?maxAllowedPacket=") + } + buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket)) + + } + + // other params + if cfg.Params != nil { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + 
sort.Strings(params) + for _, param := range params { + if hasParam { + buf.WriteByte('&') + } else { + hasParam = true + buf.WriteByte('?') + } + + buf.WriteString(param) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(cfg.Params[param])) + } + } + + return buf.String() +} + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *Config, err error) { + // New config with some default values + cfg = NewConfig() + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' 
{ + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + if err = cfg.normalize(); err != nil { + return nil, err + } + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *Config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.AllowAllFiles, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use cleartext authentication mode (MySQL 5.5.10+) + case "allowCleartextPasswords": + var isBool bool + cfg.AllowCleartextPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use native password authentication + case "allowNativePasswords": + var isBool bool + cfg.AllowNativePasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.AllowOldPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.ClientFoundRows, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Collation + case "collation": + cfg.Collation = value + break + + case "columnsWithAlias": + var isBool bool + cfg.ColumnsWithAlias, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Compression + case "compress": + return errors.New("compression 
not implemented yet") + + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.InterpolateParams, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.Loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // multiple statements in one query + case "multiStatements": + var isBool bool + cfg.MultiStatements, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // time.Time parsing + case "parseTime": + var isBool bool + cfg.ParseTime, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // I/O read Timeout + case "readTimeout": + cfg.ReadTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // Reject read-only connections + case "rejectReadOnly": + var isBool bool + cfg.RejectReadOnly, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + + if pubKey := getServerPubKey(name); pubKey != nil { + cfg.ServerPubKey = name + cfg.pubKey = pubKey + } else { + return errors.New("invalid value / unknown server pub key name: " + name) + } + + // Strict mode + case "strict": + panic("strict mode has been removed. 
See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + + // Dial Timeout + case "timeout": + cfg.Timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.TLSConfig = "true" + cfg.tls = &tls.Config{} + } else { + cfg.TLSConfig = "false" + } + } else if vl := strings.ToLower(value); vl == "skip-verify" { + cfg.TLSConfig = vl + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else { + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for TLS config name: %v", err) + } + + if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { + cfg.TLSConfig = name + cfg.tls = tlsConfig + } else { + return errors.New("invalid value / unknown config name: " + name) + } + } + + // I/O write Timeout + case "writeTimeout": + cfg.WriteTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + case "maxAllowedPacket": + cfg.MaxAllowedPacket, err = strconv.Atoi(value) + if err != nil { + return + } + default: + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/errors.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/errors.go new file mode 100644 index 000000000..760782ff2 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/errors.go @@ -0,0 +1,65 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. 
// --- vendored: go-sql-driver/mysql errors.go (package mysql) ---
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

// Various errors the driver might return. Can change between driver versions.
var (
	ErrInvalidConn       = errors.New("invalid connection")
	ErrMalformPkt        = errors.New("malformed packet")
	ErrNoTLS             = errors.New("TLS requested but server does not support TLS")
	ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
	ErrNativePassword    = errors.New("this user requires mysql native password authentication.")
	ErrOldPassword       = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
	ErrUnknownPlugin     = errors.New("this authentication plugin is not supported")
	ErrOldProtocol       = errors.New("MySQL server does not support required protocol 41+")
	ErrPktSync           = errors.New("commands out of sync. You can't run this command now")
	ErrPktSyncMul        = errors.New("commands out of sync. Did you run multiple statements at once?")
	ErrPktTooLarge       = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
	ErrBusyBuffer        = errors.New("busy buffer")

	// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
	// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
	// to trigger a resend.
	// See https://github.com/go-sql-driver/mysql/pull/302
	errBadConnNoWrite = errors.New("bad connection")
)

// errLog is the destination for the driver's critical error messages; it can
// be swapped out via SetLogger.
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

// Logger is used to log critical error messages.
type Logger interface {
	Print(v ...interface{})
}

// SetLogger is used to set the logger for critical errors.
// The initial logger is os.Stderr.
func SetLogger(logger Logger) error {
	if logger == nil {
		return errors.New("logger is nil")
	}
	errLog = logger
	return nil
}

// MySQLError is an error type which represents a single MySQL error.
type MySQLError struct {
	Number  uint16
	Message string
}

// Error renders the error in MySQL's conventional "Error <number>: <message>" form.
func (e *MySQLError) Error() string {
	return fmt.Sprintf("Error %d: %s", e.Number, e.Message)
}
+ +package mysql + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + 
scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, 
fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/infile.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/infile.go new file mode 100644 index 000000000..273cb0ba5 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/infile.go @@ -0,0 +1,182 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "fmt" + "io" + "os" + "strings" + "sync" +) + +var ( + fileRegister map[string]bool + fileRegisterLock sync.RWMutex + readerRegister map[string]func() io.Reader + readerRegisterLock sync.RWMutex +) + +// RegisterLocalFile adds the given file to the file whitelist, +// so that it can be used by "LOAD DATA LOCAL INFILE ". +// Alternatively you can allow the use of all local files with +// the DSN parameter 'allowAllFiles=true' +// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterLocalFile(filePath string) { + fileRegisterLock.Lock() + // lazy map init + if fileRegister == nil { + fileRegister = make(map[string]bool) + } + + fileRegister[strings.Trim(filePath, `"`)] = true + fileRegisterLock.Unlock() +} + +// DeregisterLocalFile removes the given filepath from the whitelist. 
+func DeregisterLocalFile(filePath string) { + fileRegisterLock.Lock() + delete(fileRegister, strings.Trim(filePath, `"`)) + fileRegisterLock.Unlock() +} + +// RegisterReaderHandler registers a handler function which is used +// to receive a io.Reader. +// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::". +// If the handler returns a io.ReadCloser Close() is called when the +// request is finished. +// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterReaderHandler(name string, handler func() io.Reader) { + readerRegisterLock.Lock() + // lazy map init + if readerRegister == nil { + readerRegister = make(map[string]func() io.Reader) + } + + readerRegister[name] = handler + readerRegisterLock.Unlock() +} + +// DeregisterReaderHandler removes the ReaderHandler function with +// the given name from the registry. +func DeregisterReaderHandler(name string) { + readerRegisterLock.Lock() + delete(readerRegister, name) + readerRegisterLock.Unlock() +} + +func deferredClose(err *error, closer io.Closer) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} + +func (mc *mysqlConn) handleInFileRequest(name string) (err error) { + var rdr io.Reader + var data []byte + packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } + + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader + // The server might return an an absolute path. See issue #355. 
+ name = name[idx+8:] + + readerRegisterLock.RLock() + handler, inMap := readerRegister[name] + readerRegisterLock.RUnlock() + + if inMap { + rdr = handler() + if rdr != nil { + if cl, ok := rdr.(io.Closer); ok { + defer deferredClose(&err, cl) + } + } else { + err = fmt.Errorf("Reader '%s' is ", name) + } + } else { + err = fmt.Errorf("Reader '%s' is not registered", name) + } + } else { // File + name = strings.Trim(name, `"`) + fileRegisterLock.RLock() + fr := fileRegister[name] + fileRegisterLock.RUnlock() + if mc.cfg.AllowAllFiles || fr { + var file *os.File + var fi os.FileInfo + + if file, err = os.Open(name); err == nil { + defer deferredClose(&err, file) + + // get file size + if fi, err = file.Stat(); err == nil { + rdr = file + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize + } + } + } + } else { + err = fmt.Errorf("local file '%s' is not registered", name) + } + } + + // send content packets + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { + data := make([]byte, 4+packetSize) + var n int + for err == nil { + n, err = rdr.Read(data[4:]) + if n > 0 { + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + return ioErr + } + } + } + if err == io.EOF { + err = nil + } + } + + // send empty packet (termination) + if data == nil { + data = make([]byte, 4) + } + if ioErr := mc.writePacket(data[:4]); ioErr != nil { + return ioErr + } + + // read OK packet + if err == nil { + return mc.readResultOK() + } + + mc.readPacket() + return err +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/packets.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/packets.go new file mode 100644 index 000000000..d873a97b2 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/packets.go @@ -0,0 +1,1301 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 
+// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "time" +) + +// Packets documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +// Read packet to buffer 'data' +func (mc *mysqlConn) readPacket() ([]byte, error) { + var prevData []byte + for { + // read packet header + data, err := mc.buf.readNext(4) + if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } + errLog.Print(err) + mc.Close() + return nil, ErrInvalidConn + } + + // packet length [24 bit] + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + + // check packet sync [8 bit] + if data[3] != mc.sequence { + if data[3] > mc.sequence { + return nil, ErrPktSyncMul + } + return nil, ErrPktSync + } + mc.sequence++ + + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, ErrInvalidConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] + data, err = mc.buf.readNext(pktLen) + if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } + errLog.Print(err) + mc.Close() + return nil, ErrInvalidConn + } + + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } + + return append(prevData, data...), nil + } + + prevData = append(prevData, data...) 
+ } +} + +// Write packet buffer 'data' +func (mc *mysqlConn) writePacket(data []byte) error { + pktLen := len(data) - 4 + + if pktLen > mc.maxAllowedPacket { + return ErrPktTooLarge + } + + for { + var size int + if pktLen >= maxPacketSize { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = maxPacketSize + } else { + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + size = pktLen + } + data[3] = mc.sequence + + // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return err + } + } + + n, err := mc.netConn.Write(data[:4+size]) + if err == nil && n == 4+size { + mc.sequence++ + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] + continue + } + + // Handle error + if err == nil { // n != len(data) + mc.cleanup() + errLog.Print(ErrMalformPkt) + } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() + errLog.Print(err) + } + return ErrInvalidConn + } +} + +/****************************************************************************** +* Initialization Process * +******************************************************************************/ + +// Handshake Initialization Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { + data, err := mc.readPacket() + if err != nil { + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. 
+ if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return nil, "", err + } + + if data[0] == iERR { + return nil, "", mc.handleErrorPacket(data) + } + + // protocol version [1 byte] + if data[0] < minProtocolVersion { + return nil, "", fmt.Errorf( + "unsupported protocol version %d. Version %d or higher is required", + data[0], + minProtocolVersion, + ) + } + + // server version [null terminated string] + // connection id [4 bytes] + pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 + + // first part of the password cipher [8 bytes] + authData := data[pos : pos+8] + + // (filler) always 0x00 [1 byte] + pos += 8 + 1 + + // capability flags (lower 2 bytes) [2 bytes] + mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if mc.flags&clientProtocol41 == 0 { + return nil, "", ErrOldProtocol + } + if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { + return nil, "", ErrNoTLS + } + pos += 2 + + plugin := "" + if len(data) > pos { + // character set [1 byte] + // status flags [2 bytes] + // capability flags (upper 2 bytes) [2 bytes] + // length of auth-plugin-data [1 byte] + // reserved (all [00]) [10 bytes] + pos += 1 + 2 + 2 + 1 + 10 + + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // + // The web documentation is ambiguous about the length. However, + // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, + // the 13th byte is "\0 byte, terminating the second part of + // a scramble". So the second part of the password cipher is + // a NULL terminated string that's at least 13 bytes with the + // last byte being NULL. + // + // The official Python library uses the fixed length 12 + // which seems to work but technically could have a hidden bug. + authData = append(authData, data[pos:pos+12]...) 
+ pos += 13 + + // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } + + // make a memory safe copy of the cipher slice + var b [20]byte + copy(b[:], authData) + return b[:], plugin, nil + } + + plugin = defaultAuthPlugin + + // make a memory safe copy of the cipher slice + var b [8]byte + copy(b[:], authData) + return b[:], plugin, nil +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { + // Adjust client flags based on server support + clientFlags := clientProtocol41 | + clientSecureConn | + clientLongPassword | + clientTransactions | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + mc.flags&clientLongFlag + + if mc.cfg.ClientFoundRows { + clientFlags |= clientFoundRows + } + + // To enable TLS / SSL + if mc.cfg.tls != nil { + clientFlags |= clientSSL + } + + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } + + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } + + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + if addNUL { + pktLen++ + } + + // To specify a db name + if n := len(mc.cfg.DBName); n > 0 { + clientFlags |= clientConnectWithDB + pktLen += n + 1 + } + + // Calculate packet length and get buffer with that size + data := mc.buf.takeSmallBuffer(pktLen + 4) + if data == nil { + // cannot take the buffer. 
Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // ClientFlags [32 bit] + data[4] = byte(clientFlags) + data[5] = byte(clientFlags >> 8) + data[6] = byte(clientFlags >> 16) + data[7] = byte(clientFlags >> 24) + + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 + + // Charset [1 byte] + var found bool + data[12], found = collations[mc.cfg.Collation] + if !found { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + return errors.New("unknown collation") + } + + // SSL Connection Request Packet + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if mc.cfg.tls != nil { + // Send TLS / SSL request packet + if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + return err + } + + // Switch to TLS + tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + if err := tlsConn.Handshake(); err != nil { + return err + } + mc.netConn = tlsConn + mc.buf.nc = tlsConn + } + + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + + // User [null terminated string] + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) + } + data[pos] = 0x00 + pos++ + + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) + if addNUL { + data[pos] = 0x00 + pos++ + } + + // Databasename [null terminated string] + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) + data[pos] = 0x00 + pos++ + } + + pos += copy(data[pos:], plugin) + data[pos] = 0x00 + + // Send Auth packet + return mc.writePacket(data) +} + +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error 
{ + pktLen := 4 + len(authData) + if addNUL { + pktLen++ + } + data := mc.buf.takeSmallBuffer(pktLen) + if data == nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // Add the auth data [EOF] + copy(data[4:], authData) + if addNUL { + data[pktLen-1] = 0x00 + } + + return mc.writePacket(data) +} + +/****************************************************************************** +* Command Packets * +******************************************************************************/ + +func (mc *mysqlConn) writeCommandPacket(command byte) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1) + if data == nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { + // Reset Packet Sequence + mc.sequence = 0 + + pktLen := 1 + len(arg) + data := mc.buf.takeBuffer(pktLen + 4) + if data == nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Add arg + copy(data[5:], arg) + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1 + 4) + if data == nil { + // cannot take the buffer. 
Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + + // Send CMD packet + return mc.writePacket(data) +} + +/****************************************************************************** +* Result Packets * +******************************************************************************/ + +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { + data, err := mc.readPacket() + if err != nil { + return nil, "", err + } + + // packet indicator + switch data[0] { + + case iOK: + return nil, "", mc.handleOkPacket(data) + + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) < 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { + data, err := mc.readPacket() + if err != nil { + return err + } + + if data[0] == iOK { + return mc.handleOkPacket(data) + } + return mc.handleErrorPacket(data) +} + +// Result Set Header Packet +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset +func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { + data, err := mc.readPacket() + if err == nil { + switch data[0] { + + case iOK: + return 0, mc.handleOkPacket(data) + + case iERR: + return 0, mc.handleErrorPacket(data) + + case iLocalInFile: + return 0, 
mc.handleInFileRequest(string(data[1:])) + } + + // column count + num, _, n := readLengthEncodedInteger(data) + if n-len(data) == 0 { + return int(num), nil + } + + return 0, ErrMalformPkt + } + return 0, err +} + +// Error Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +func (mc *mysqlConn) handleErrorPacket(data []byte) error { + if data[0] != iERR { + return ErrMalformPkt + } + + // 0xff [1 byte] + + // Error Number [16 bit uint] + errno := binary.LittleEndian.Uint16(data[1:3]) + + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. 
+ mc.Close() + return driver.ErrBadConn + } + + pos := 3 + + // SQL State [optional: # + 5bytes string] + if data[3] == 0x23 { + //sqlstate := string(data[4 : 4+5]) + pos = 9 + } + + // Error Message [string] + return &MySQLError{ + Number: errno, + Message: string(data[pos:]), + } +} + +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + +// Ok Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +func (mc *mysqlConn) handleOkPacket(data []byte) error { + var n, m int + + // 0x00 [1 byte] + + // Affected rows [Length Coded Binary] + mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + + // Insert id [Length Coded Binary] + mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + + // server_status [2 bytes] + mc.status = readStatus(data[1+n+m : 1+n+m+2]) + if mc.status&statusMoreResultsExists != 0 { + return nil + } + + // warning count [2 bytes] + + return nil +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) + + for i := 0; ; i++ { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + // EOF Packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if i == count { + return columns, nil + } + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) + } + + // Catalog + pos, err := skipLengthEncodedString(data) + if err != nil { + return nil, err + } + + // Database [len coded string] + n, err := skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Table [len coded string] + if mc.cfg.ColumnsWithAlias { + tableName, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + columns[i].tableName = 
string(tableName) + } else { + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + } + + // Original table [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Name [len coded string] + name, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + columns[i].name = string(name) + pos += n + + // Original name [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Filler [uint8] + pos++ + + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + + // Length [uint32] + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 + + // Field type [uint8] + columns[i].fieldType = fieldType(data[pos]) + pos++ + + // Flags [uint16] + columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + pos += 2 + + // Decimals [uint8] + columns[i].decimals = data[pos] + //pos++ + + // Default value [len coded binary] + //if pos < len(data) { + // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) + //} + } +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +func (rows *textRows) readRow(dest []driver.Value) error { + mc := rows.mc + + if rows.rs.done { + return io.EOF + } + + data, err := mc.readPacket() + if err != nil { + return err + } + + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } + return io.EOF + } + if data[0] == iERR { + rows.mc = nil + return mc.handleErrorPacket(data) + } + + // RowSet Packet + var n int + var isNull bool + pos := 0 + + for i := range dest { + // Read bytes and convert to string + dest[i], 
isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + if !mc.parseTime { + continue + } else { + switch rows.rs.columns[i].fieldType { + case fieldTypeTimestamp, fieldTypeDateTime, + fieldTypeDate, fieldTypeNewDate: + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + mc.cfg.Loc, + ) + if err == nil { + continue + } + default: + continue + } + } + + } else { + dest[i] = nil + continue + } + } + return err // err != nil + } + + return nil +} + +// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read +func (mc *mysqlConn) readUntilEOF() error { + for { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch data[0] { + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + if len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return nil + } + } +} + +/****************************************************************************** +* Prepared Statements * +******************************************************************************/ + +// Prepare Result Packets +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { + data, err := stmt.mc.readPacket() + if err == nil { + // packet indicator [1 byte] + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) + } + + // statement id [4 bytes] + stmt.id = binary.LittleEndian.Uint32(data[1:5]) + + // Column count [16 bit uint] + columnCount := binary.LittleEndian.Uint16(data[5:7]) + + // Param count [16 bit uint] + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) + + // Reserved [8 bit] + + // Warning count [16 bit uint] + + return columnCount, nil + } + return 0, err +} + +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { + maxLen := stmt.mc.maxAllowedPacket - 1 + pktLen := maxLen + + // After the header 
(bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + + // Cannot use the write buffer since + // a) the buffer is too small + // b) it is in use + data := make([]byte, 4+1+4+2+len(arg)) + + copy(data[4+dataOffset:], arg) + + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen + } + + stmt.mc.sequence = 0 + // Add command byte [1 byte] + data[4] = comStmtSendLongData + + // Add stmtID [32 bit] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // Add paramID [16 bit] + data[9] = byte(paramID) + data[10] = byte(paramID >> 8) + + // Send CMD packet + err := stmt.mc.writePacket(data[:4+pktLen]) + if err == nil { + data = data[pktLen-dataOffset:] + continue + } + return err + + } + + // Reset Packet Sequence + stmt.mc.sequence = 0 + return nil +} + +// Execute Prepared Statement +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { + if len(args) != stmt.paramCount { + return fmt.Errorf( + "argument count mismatch (got: %d; has: %d)", + len(args), + stmt.paramCount, + ) + } + + const minPktLen = 4 + 1 + 4 + 1 + 4 + mc := stmt.mc + + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + + // Reset packet-sequence + mc.sequence = 0 + + var data []byte + + if len(args) == 0 { + data = mc.buf.takeBuffer(minPktLen) + } else { + data = mc.buf.takeCompleteBuffer() + } + if data == nil { + // cannot take the buffer. 
Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return errBadConnNoWrite + } + + // command [1 byte] + data[4] = comStmtExecute + + // statement_id [4 bytes] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] + data[9] = 0x00 + + // iteration_count (uint32(1)) [4 bytes] + data[10] = 0x01 + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + pos := minPktLen + + var nullMask []byte + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + // buffer has to be extended but we don't know by how much so + // we depend on append after all data with known sizes fit. + // We stop at that because we deal with a lot of columns here + // which makes the required allocation size hard to guess. + tmp := make([]byte, pos+maskLen+typesLen) + copy(tmp[:pos], data[:pos]) + data = tmp + nullMask = data[pos : pos+maskLen] + pos += maskLen + } else { + nullMask = data[pos : pos+maskLen] + for i := 0; i < maskLen; i++ { + nullMask[i] = 0 + } + pos += maskLen + } + + // newParameterBoundFlag 1 [1 byte] + data[pos] = 0x01 + pos++ + + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += len(args) * 2 + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i, arg := range args { + // build NULL-bitmap + if arg == nil { + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = byte(fieldTypeNULL) + paramTypes[i+i+1] = 0x00 + continue + } + + // cache types and values + switch v := arg.(type) { + case int64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = 
append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = byte(fieldTypeDouble) + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = byte(fieldTypeTiny) + paramTypes[i+i+1] = 0x00 + + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + // Common case (non-nil value) first + if v != nil { + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + if len(v) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + continue + } + + // Handle []byte(nil) as a NULL value + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = byte(fieldTypeNULL) + paramTypes[i+i+1] = 0x00 + + case string: + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + if len(v) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + var a [64]byte + var b = a[:0] + + if v.IsZero() { + b = append(b, "0000-00-00"...) + } else { + b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) + } + + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(b)), + ) + paramValues = append(paramValues, b...) 
+ + default: + return fmt.Errorf("cannot convert type: %T", arg) + } + } + + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + mc.buf.buf = data + } + + pos += len(paramValues) + data = data[:pos] + } + + return mc.writePacket(data) +} + +func (mc *mysqlConn) discardResults() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } + } + return nil +} + +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +func (rows *binaryRows) readRow(dest []driver.Value) error { + data, err := rows.mc.readPacket() + if err != nil { + return err + } + + // packet indicator [1 byte] + if data[0] != iOK { + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } + return io.EOF + } + mc := rows.mc + rows.mc = nil + + // Error otherwise + return mc.handleErrorPacket(data) + } + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + pos := 1 + (len(dest)+7+2)>>3 + nullMask := data[1:pos] + + for i := range dest { + // Field is NULL + // (byte >> bit-pos) % 2 == 1 + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + dest[i] = nil + continue + } + + // Convert to byte-coded string + switch rows.rs.columns[i].fieldType { + case fieldTypeNULL: + dest[i] = nil + continue + + // Numeric Types + case fieldTypeTiny: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(data[pos]) + } else { + dest[i] = int64(int8(data[pos])) + } + pos++ + continue + + case fieldTypeShort, fieldTypeYear: + if rows.rs.columns[i].flags&flagUnsigned != 0 
{ + dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) + } else { + dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) + } + pos += 2 + continue + + case fieldTypeInt24, fieldTypeLong: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) + } else { + dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) + } + pos += 4 + continue + + case fieldTypeLongLong: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + val := binary.LittleEndian.Uint64(data[pos : pos+8]) + if val > math.MaxInt64 { + dest[i] = uint64ToString(val) + } else { + dest[i] = int64(val) + } + } else { + dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) + } + pos += 8 + continue + + case fieldTypeFloat: + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) + pos += 4 + continue + + case fieldTypeDouble: + dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) + pos += 8 + continue + + // Length coded Binary Strings + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + var isNull bool + var n int + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + continue + } else { + dest[i] = nil + continue + } + } + return err + + case + fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD + fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] + fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] + + num, isNull, n := readLengthEncodedInteger(data[pos:]) + pos += n + + switch { + case isNull: + dest[i] = nil + continue + case rows.rs.columns[i].fieldType == fieldTypeTime: + // database/sql does not support an equivalent to TIME, return a string + var dstlen uint8 + 
switch decimals := rows.rs.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 8 + case 1, 2, 3, 4, 5, 6: + dstlen = 8 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, + ) + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + case rows.mc.parseTime: + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) + default: + var dstlen uint8 + if rows.rs.columns[i].fieldType == fieldTypeDate { + dstlen = 10 + } else { + switch decimals := rows.rs.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 19 + case 1, 2, 3, 4, 5, 6: + dstlen = 19 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, + ) + } + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + } + + if err == nil { + pos += int(num) + continue + } else { + return err + } + + // Please report if this happens! + default: + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) + } + } + + return nil +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/result.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/result.go new file mode 100644 index 000000000..c6438d034 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/result.go @@ -0,0 +1,22 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +type mysqlResult struct { + affectedRows int64 + insertId int64 +} + +func (res *mysqlResult) LastInsertId() (int64, error) { + return res.insertId, nil +} + +func (res *mysqlResult) RowsAffected() (int64, error) { + return res.affectedRows, nil +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/rows.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/rows.go new file mode 100644 index 000000000..d3b1e2822 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/rows.go @@ -0,0 +1,216 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "io" + "math" + "reflect" +) + +type resultSet struct { + columns []mysqlField + columnNames []string + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs resultSet + finish func() +} + +type binaryRows struct { + mysqlRows +} + +type textRows struct { + mysqlRows +} + +func (rows *mysqlRows) Columns() []string { + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { + for i := range columns { + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." 
+ rows.rs.columns[i].name + } else { + columns[i] = rows.rs.columns[i].name + } + } + } else { + for i := range columns { + columns[i] = rows.rs.columns[i].name + } + } + + rows.rs.columnNames = columns + return columns +} + +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + + mc := rows.mc + if mc == nil { + return nil + } + if err := mc.error(); err != nil { + return err + } + + // Remove unread packets from stream + if !rows.rs.done { + err = mc.readUntilEOF() + } + if err == nil { + if err = mc.discardResults(); err != nil { + return err + } + } + + rows.mc = nil + return err +} + +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == 
nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *binaryRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} + +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/statement.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/statement.go new file mode 100644 index 000000000..ce7fe4cd0 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/statement.go @@ -0,0 +1,211 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 
+// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "fmt" + "io" + "reflect" + "strconv" +) + +type mysqlStmt struct { + mc *mysqlConn + id uint32 + paramCount int +} + +func (stmt *mysqlStmt) Close() error { + if stmt.mc == nil || stmt.mc.closed.IsSet() { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + stmt.mc = nil + return err +} + +func (stmt *mysqlStmt) NumInput() int { + return stmt.paramCount +} + +func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { + return converter{} +} + +func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil +} + +func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { + 
errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + rows := new(binaryRows) + + if resLen > 0 { + rows.mc = mc + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + return rows, err +} + +type converter struct{} + +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. +func (c converter) ConvertValue(v interface{}) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return strconv.FormatUint(u64, 10), nil + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() 
+ if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/transaction.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/transaction.go new file mode 100644 index 000000000..417d72793 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -0,0 +1,31 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +package mysql + +type mysqlTx struct { + mc *mysqlConn +} + +func (tx *mysqlTx) Commit() (err error) { + if tx.mc == nil || tx.mc.closed.IsSet() { + return ErrInvalidConn + } + err = tx.mc.exec("COMMIT") + tx.mc = nil + return +} + +func (tx *mysqlTx) Rollback() (err error) { + if tx.mc == nil || tx.mc.closed.IsSet() { + return ErrInvalidConn + } + err = tx.mc.exec("ROLLBACK") + tx.mc = nil + return +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/utils.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils.go new file mode 100644 index 000000000..84d595b6b --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils.go @@ -0,0 +1,710 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Registry for custom tls.Configs +var ( + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config +) + +// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. +// Use the key as a value in the DSN where tls=value. +// +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. 
+// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") +// +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + return fmt.Errorf("key '%s' is reserved", key) + } + + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) + } + + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() + return nil +} + +// DeregisterTLSConfig removes the tls.Config associated with key. +func DeregisterTLSConfig(key string) { + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) + } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = cloneTLSConfig(v) + } + tlsConfigLock.RUnlock() + return +} + +// Returns the bool value of the input. 
+// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + +/****************************************************************************** +* Time related utils * +******************************************************************************/ + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. 
+func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { + base := "0000-00-00 00:00:00.0000000" + switch len(str) { + case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" + if str == base[:len(str)] { + return + } + t, err = time.Parse(timeFormat[:len(str)], str) + default: + err = fmt.Errorf("invalid time string: %s", str) + return + } + + // Adjust location + if err == nil && loc != time.UTC { + y, mo, d := t.Date() + h, mi, s := t.Clock() + t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil + } + + return +} + +func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { + switch num { + case 0: + return time.Time{}, nil + case 4: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + 0, 0, 0, 0, + loc, + ), nil + case 7: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + 0, + loc, + ), nil + case 11: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds + loc, + ), nil + } + return nil, fmt.Errorf("invalid DATETIME packet length %d", num) +} + +// zeroDateTime is used in formatBinaryDateTime to avoid an allocation +// if the DATE or DATETIME has the zero value. +// It must never be changed. +// The current behavior depends on database/sql copying the result. 
+var zeroDateTime = []byte("0000-00-00 00:00:00.000000") + +const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" + +func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + if justTime { + return zeroDateTime[11 : 11+length], nil + } + return zeroDateTime[:length], nil + } + var dst []byte // return value + var pt, p1, p2, p3 byte // current digit pair + var zOffs byte // offset of value in zeroDateTime + if justTime { + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds + default: + return nil, fmt.Errorf("illegal TIME length %d", length) + } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + if src[1] != 0 { + hour := uint16(src[1])*24 + uint16(src[5]) + pt = byte(hour / 100) + p1 = byte(hour - 100*uint16(pt)) + dst = append(dst, digits01[pt]) + } else { + p1 = src[5] + } + zOffs = 11 + src = src[6:] + } else { + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s length %d", t, length) + } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt = byte(year / 100) + p1 = byte(year - 
100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + } + // p1 is 2-digit hour, src is after hour + p2, p3 = src[0], src[1] + dst = append(dst, + digits10[p1], digits01[p1], ':', + digits10[p2], digits01[p2], ':', + digits10[p3], digits01[p3], + ) + if length <= byte(len(dst)) { + return dst, nil + } + src = src[2:] + if len(src) == 0 { + return append(dst, zeroDateTime[19:zOffs+length]...), nil + } + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 = byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 = byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 = byte(microsecs) + switch decimals := zOffs + length - 20; decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ), nil + case 1: + return append(dst, '.', + digits10[p1], + ), nil + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ), nil + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ), nil + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ), nil + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ), nil + } +} + +/****************************************************************************** +* Convert from and to bytes * +******************************************************************************/ + +func uint64ToBytes(n uint64) []byte { + return []byte{ + byte(n), + byte(n >> 8), + byte(n >> 16), + byte(n >> 24), + byte(n >> 32), + byte(n >> 40), + byte(n >> 48), + byte(n >> 56), + } +} + +func uint64ToString(n 
uint64) []byte { + var a [20]byte + i := 20 + + // U+0030 = 0 + // ... + // U+0039 = 9 + + var q uint64 + for n >= 10 { + i-- + q = n / 10 + a[i] = uint8(n-q*10) + 0x30 + n = q + } + + i-- + a[i] = uint8(n) + 0x30 + + return a[i:] +} + +// treats string value as unsigned integer representation +func stringToInt(b []byte) int { + val := 0 + for i := range b { + val *= 10 + val += int(b[i] - 0x30) + } + return val +} + +// returns the string read as a bytes slice, wheter the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { + // Get length + num, isNull, n := readLengthEncodedInteger(b) + if num < 1 { + return b[n:n], isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return b[n-int(num) : n : n], false, n, nil + } + return nil, false, n, io.EOF +} + +// returns the number of bytes skipped and an error, in case the string is +// longer than the input slice +func skipLengthEncodedString(b []byte) (int, error) { + // Get length + num, _, n := readLengthEncodedInteger(b) + if num < 1 { + return n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return n, nil + } + return n, io.EOF +} + +// returns the number read, whether the value is NULL and the number of bytes read +func readLengthEncodedInteger(b []byte) (uint64, bool, int) { + // See issue #349 + if len(b) == 0 { + return 0, true, 1 + } + + switch b[0] { + // 251: NULL + case 0xfb: + return 0, true, 1 + + // 252: value of following 2 + case 0xfc: + return uint64(b[1]) | uint64(b[2])<<8, false, 3 + + // 253: value of following 3 + case 0xfd: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 + + // 254: value of following 8 + case 0xfe: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | + uint64(b[7])<<48 | uint64(b[8])<<56, + false, 9 + } + + // 
0-250: value of first byte + return uint64(b[0]), false, 1 +} + +// encodes a uint64 value and appends it to the given bytes slice +func appendLengthEncodedInteger(b []byte, n uint64) []byte { + switch { + case n <= 250: + return append(b, byte(n)) + + case n <= 0xffff: + return append(b, 0xfc, byte(n), byte(n>>8)) + + case n <= 0xffffff: + return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) + } + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) +} + +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash escapes []byte with backslashes (\) +// This escapes the contents of a string (provided as []byte) by adding backslashes before special +// characters, and turning others into specific escape sequences, such as +// turning newlines into \n and null bytes into \0. 
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 +func escapeBytesBackslash(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. +// This escapes the contents of a string by doubling up any apostrophes that +// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in +// effect on the server. 
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 +func escapeBytesQuotes(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringQuotes is similar to escapeBytesQuotes but for string. +func escapeStringQuotes(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. 
+type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns wether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns wether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go17.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go17.go new file mode 100644 index 000000000..f59563456 --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go17.go @@ -0,0 +1,40 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +// +build go1.7 +// +build !go1.8 + +package mysql + +import "crypto/tls" + +func cloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go18.go b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go18.go new file mode 100644 index 000000000..c35c2a6aa --- /dev/null +++ b/_integration/go/vendor/github.com/go-sql-driver/mysql/utils_go18.go @@ -0,0 +1,50 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +// +build go1.8 + +package mysql + +import ( + "crypto/tls" + "database/sql" + "database/sql/driver" + "errors" + "fmt" +) + +func cloneTLSConfig(c *tls.Config) *tls.Config { + return c.Clone() +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/_integration/go/vendor/google.golang.org/appengine/LICENSE b/_integration/go/vendor/google.golang.org/appengine/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/_integration/go/vendor/google.golang.org/appengine/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql.go b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql.go new file mode 100644 index 000000000..7b27e6b12 --- /dev/null +++ b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql.go @@ -0,0 +1,62 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +/* +Package cloudsql exposes access to Google Cloud SQL databases. + +This package does not work in App Engine "flexible environment". + +This package is intended for MySQL drivers to make App Engine-specific +connections. Applications should use this package through database/sql: +Select a pure Go MySQL driver that supports this package, and use sql.Open +with protocol "cloudsql" and an address of the Cloud SQL instance. 
+ +A Go MySQL driver that has been tested to work well with Cloud SQL +is the go-sql-driver: + import "database/sql" + import _ "github.com/go-sql-driver/mysql" + + db, err := sql.Open("mysql", "user@cloudsql(project-id:instance-name)/dbname") + + +Another driver that works well with Cloud SQL is the mymysql driver: + import "database/sql" + import _ "github.com/ziutek/mymysql/godrv" + + db, err := sql.Open("mymysql", "cloudsql:instance-name*dbname/user/password") + + +Using either of these drivers, you can perform a standard SQL query. +This example assumes there is a table named 'users' with +columns 'first_name' and 'last_name': + + rows, err := db.Query("SELECT first_name, last_name FROM users") + if err != nil { + log.Errorf(ctx, "db.Query: %v", err) + } + defer rows.Close() + + for rows.Next() { + var firstName string + var lastName string + if err := rows.Scan(&firstName, &lastName); err != nil { + log.Errorf(ctx, "rows.Scan: %v", err) + continue + } + log.Infof(ctx, "First: %v - Last: %v", firstName, lastName) + } + if err := rows.Err(); err != nil { + log.Errorf(ctx, "Row error: %v", err) + } +*/ +package cloudsql + +import ( + "net" +) + +// Dial connects to the named Cloud SQL instance. +func Dial(instance string) (net.Conn, error) { + return connect(instance) +} diff --git a/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go new file mode 100644 index 000000000..af62dba14 --- /dev/null +++ b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go @@ -0,0 +1,17 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +// +build appengine + +package cloudsql + +import ( + "net" + + "appengine/cloudsql" +) + +func connect(instance string) (net.Conn, error) { + return cloudsql.Dial(instance) +} diff --git a/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go new file mode 100644 index 000000000..90fa7b31e --- /dev/null +++ b/_integration/go/vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go @@ -0,0 +1,16 @@ +// Copyright 2013 Google Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// +build !appengine + +package cloudsql + +import ( + "errors" + "net" +) + +func connect(instance string) (net.Conn, error) { + return nil, errors.New(`cloudsql: not supported in App Engine "flexible environment"`) +} diff --git a/_integration/go/vendor/modules.txt b/_integration/go/vendor/modules.txt new file mode 100644 index 000000000..c61517b57 --- /dev/null +++ b/_integration/go/vendor/modules.txt @@ -0,0 +1,4 @@ +# github.com/go-sql-driver/mysql v1.4.0 +github.com/go-sql-driver/mysql +# google.golang.org/appengine v1.2.0 +google.golang.org/appengine/cloudsql diff --git a/_integration/javascript/.gitignore b/_integration/javascript/.gitignore new file mode 100644 index 000000000..b512c09d4 --- /dev/null +++ b/_integration/javascript/.gitignore @@ -0,0 +1 @@ +node_modules \ No newline at end of file diff --git a/_integration/javascript/Makefile b/_integration/javascript/Makefile new file mode 100644 index 000000000..fef87a074 --- /dev/null +++ b/_integration/javascript/Makefile @@ -0,0 +1,9 @@ +node_modules: + npm install + +dependencies: node_modules + +test: dependencies + npm test + +.PHONY: test \ No newline at end of file diff --git a/_integration/javascript/package-lock.json b/_integration/javascript/package-lock.json new file mode 100644 index 000000000..b37fc5abe --- /dev/null +++ 
b/_integration/javascript/package-lock.json @@ -0,0 +1,3562 @@ +{ + "name": "go-mysql-server-js", + "version": "1.0.0", + "lockfileVersion": 1, + "requires": true, + "dependencies": { + "@ava/babel-plugin-throws-helper": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@ava/babel-plugin-throws-helper/-/babel-plugin-throws-helper-4.0.0.tgz", + "integrity": "sha512-3diBLIVBPPh3j4+hb5lo0I1D+S/O/VDJPI4Y502apBxmwEqjyXG4gTSPFUlm41sSZeZzMarT/Gzovw9kV7An0w==", + "dev": true + }, + "@ava/babel-preset-stage-4": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@ava/babel-preset-stage-4/-/babel-preset-stage-4-4.0.0.tgz", + "integrity": "sha512-lZEV1ZANzfzSYBU6WHSErsy7jLPbD1iIgAboASPMcKo7woVni5/5IKWeT0RxC8rY802MFktur3OKEw2JY1Tv2w==", + "dev": true, + "requires": { + "@babel/plugin-proposal-async-generator-functions": "^7.2.0", + "@babel/plugin-proposal-dynamic-import": "^7.5.0", + "@babel/plugin-proposal-optional-catch-binding": "^7.2.0", + "@babel/plugin-transform-dotall-regex": "^7.4.4", + "@babel/plugin-transform-modules-commonjs": "^7.5.0" + } + }, + "@ava/babel-preset-transform-test-files": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/@ava/babel-preset-transform-test-files/-/babel-preset-transform-test-files-6.0.0.tgz", + "integrity": "sha512-8eKhFzZp7Qcq1VLfoC75ggGT8nQs9q8fIxltU47yCB7Wi7Y8Qf6oqY1Bm0z04fIec24vEgr0ENhDHEOUGVDqnA==", + "dev": true, + "requires": { + "@ava/babel-plugin-throws-helper": "^4.0.0", + "babel-plugin-espower": "^3.0.1" + } + }, + "@babel/code-frame": { + "version": "7.5.5", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.5.5.tgz", + "integrity": "sha512-27d4lZoomVyo51VegxI20xZPuSHusqbQag/ztrBC7wegWoQ1nLREPVSKSW8byhTlzTKyNE4ifaTA6lCp7JjpFw==", + "dev": true, + "requires": { + "@babel/highlight": "^7.0.0" + } + }, + "@babel/core": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.6.0.tgz", + "integrity": 
"sha512-FuRhDRtsd6IptKpHXAa+4WPZYY2ZzgowkbLBecEDDSje1X/apG7jQM33or3NdOmjXBKWGOg4JmSiRfUfuTtHXw==", + "dev": true, + "requires": { + "@babel/code-frame": "^7.5.5", + "@babel/generator": "^7.6.0", + "@babel/helpers": "^7.6.0", + "@babel/parser": "^7.6.0", + "@babel/template": "^7.6.0", + "@babel/traverse": "^7.6.0", + "@babel/types": "^7.6.0", + "convert-source-map": "^1.1.0", + "debug": "^4.1.0", + "json5": "^2.1.0", + "lodash": "^4.17.13", + "resolve": "^1.3.2", + "semver": "^5.4.1", + "source-map": "^0.5.0" + } + }, + "@babel/generator": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.6.0.tgz", + "integrity": "sha512-Ms8Mo7YBdMMn1BYuNtKuP/z0TgEIhbcyB8HVR6PPNYp4P61lMsABiS4A3VG1qznjXVCf3r+fVHhm4efTYVsySA==", + "dev": true, + "requires": { + "@babel/types": "^7.6.0", + "jsesc": "^2.5.1", + "lodash": "^4.17.13", + "source-map": "^0.5.0", + "trim-right": "^1.0.1" + } + }, + "@babel/helper-annotate-as-pure": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.0.0.tgz", + "integrity": "sha512-3UYcJUj9kvSLbLbUIfQTqzcy5VX7GRZ/CCDrnOaZorFFM01aXp1+GJwuFGV4NDDoAS+mOUyHcO6UD/RfqOks3Q==", + "dev": true, + "requires": { + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-function-name": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/@babel/helper-function-name/-/helper-function-name-7.1.0.tgz", + "integrity": "sha512-A95XEoCpb3TO+KZzJ4S/5uW5fNe26DjBGqf1o9ucyLyCmi1dXq/B3c8iaWTfBk3VvetUxl16e8tIrd5teOCfGw==", + "dev": true, + "requires": { + "@babel/helper-get-function-arity": "^7.0.0", + "@babel/template": "^7.1.0", + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-get-function-arity": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/@babel/helper-get-function-arity/-/helper-get-function-arity-7.0.0.tgz", + "integrity": "sha512-r2DbJeg4svYvt3HOS74U4eWKsUAMRH01Z1ds1zx8KNTPtpTL5JAsdFv8BNyOpVqdFhHkkRDIg5B4AsxmkjAlmQ==", + 
"dev": true, + "requires": { + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-module-imports": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.0.0.tgz", + "integrity": "sha512-aP/hlLq01DWNEiDg4Jn23i+CXxW/owM4WpDLFUbpjxe4NS3BhLVZQ5i7E0ZrxuQ/vwekIeciyamgB1UIYxxM6A==", + "dev": true, + "requires": { + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-module-transforms": { + "version": "7.5.5", + "resolved": "https://registry.npmjs.org/@babel/helper-module-transforms/-/helper-module-transforms-7.5.5.tgz", + "integrity": "sha512-jBeCvETKuJqeiaCdyaheF40aXnnU1+wkSiUs/IQg3tB85up1LyL8x77ClY8qJpuRJUcXQo+ZtdNESmZl4j56Pw==", + "dev": true, + "requires": { + "@babel/helper-module-imports": "^7.0.0", + "@babel/helper-simple-access": "^7.1.0", + "@babel/helper-split-export-declaration": "^7.4.4", + "@babel/template": "^7.4.4", + "@babel/types": "^7.5.5", + "lodash": "^4.17.13" + } + }, + "@babel/helper-plugin-utils": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/@babel/helper-plugin-utils/-/helper-plugin-utils-7.0.0.tgz", + "integrity": "sha512-CYAOUCARwExnEixLdB6sDm2dIJ/YgEAKDM1MOeMeZu9Ld/bDgVo8aiWrXwcY7OBh+1Ea2uUcVRcxKk0GJvW7QA==", + "dev": true + }, + "@babel/helper-regex": { + "version": "7.5.5", + "resolved": "https://registry.npmjs.org/@babel/helper-regex/-/helper-regex-7.5.5.tgz", + "integrity": "sha512-CkCYQLkfkiugbRDO8eZn6lRuR8kzZoGXCg3149iTk5se7g6qykSpy3+hELSwquhu+TgHn8nkLiBwHvNX8Hofcw==", + "dev": true, + "requires": { + "lodash": "^4.17.13" + } + }, + "@babel/helper-remap-async-to-generator": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/@babel/helper-remap-async-to-generator/-/helper-remap-async-to-generator-7.1.0.tgz", + "integrity": "sha512-3fOK0L+Fdlg8S5al8u/hWE6vhufGSn0bN09xm2LXMy//REAF8kDCrYoOBKYmA8m5Nom+sV9LyLCwrFynA8/slg==", + "dev": true, + "requires": { + "@babel/helper-annotate-as-pure": "^7.0.0", + "@babel/helper-wrap-function": "^7.1.0", + 
"@babel/template": "^7.1.0", + "@babel/traverse": "^7.1.0", + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-simple-access": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/@babel/helper-simple-access/-/helper-simple-access-7.1.0.tgz", + "integrity": "sha512-Vk+78hNjRbsiu49zAPALxTb+JUQCz1aolpd8osOF16BGnLtseD21nbHgLPGUwrXEurZgiCOUmvs3ExTu4F5x6w==", + "dev": true, + "requires": { + "@babel/template": "^7.1.0", + "@babel/types": "^7.0.0" + } + }, + "@babel/helper-split-export-declaration": { + "version": "7.4.4", + "resolved": "https://registry.npmjs.org/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.4.4.tgz", + "integrity": "sha512-Ro/XkzLf3JFITkW6b+hNxzZ1n5OQ80NvIUdmHspih1XAhtN3vPTuUFT4eQnela+2MaZ5ulH+iyP513KJrxbN7Q==", + "dev": true, + "requires": { + "@babel/types": "^7.4.4" + } + }, + "@babel/helper-wrap-function": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/helper-wrap-function/-/helper-wrap-function-7.2.0.tgz", + "integrity": "sha512-o9fP1BZLLSrYlxYEYyl2aS+Flun5gtjTIG8iln+XuEzQTs0PLagAGSXUcqruJwD5fM48jzIEggCKpIfWTcR7pQ==", + "dev": true, + "requires": { + "@babel/helper-function-name": "^7.1.0", + "@babel/template": "^7.1.0", + "@babel/traverse": "^7.1.0", + "@babel/types": "^7.2.0" + } + }, + "@babel/helpers": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.6.0.tgz", + "integrity": "sha512-W9kao7OBleOjfXtFGgArGRX6eCP0UEcA2ZWEWNkJdRZnHhW4eEbeswbG3EwaRsnQUAEGWYgMq1HsIXuNNNy2eQ==", + "dev": true, + "requires": { + "@babel/template": "^7.6.0", + "@babel/traverse": "^7.6.0", + "@babel/types": "^7.6.0" + } + }, + "@babel/highlight": { + "version": "7.5.0", + "resolved": "https://registry.npmjs.org/@babel/highlight/-/highlight-7.5.0.tgz", + "integrity": "sha512-7dV4eu9gBxoM0dAnj/BCFDW9LFU0zvTrkq0ugM7pnHEgguOEeOz1so2ZghEdzviYzQEED0r4EAgpsBChKy1TRQ==", + "dev": true, + "requires": { + "chalk": "^2.0.0", + "esutils": "^2.0.2", + "js-tokens": 
"^4.0.0" + } + }, + "@babel/parser": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.6.0.tgz", + "integrity": "sha512-+o2q111WEx4srBs7L9eJmcwi655eD8sXniLqMB93TBK9GrNzGrxDWSjiqz2hLU0Ha8MTXFIP0yd9fNdP+m43ZQ==", + "dev": true + }, + "@babel/plugin-proposal-async-generator-functions": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-proposal-async-generator-functions/-/plugin-proposal-async-generator-functions-7.2.0.tgz", + "integrity": "sha512-+Dfo/SCQqrwx48ptLVGLdE39YtWRuKc/Y9I5Fy0P1DDBB9lsAHpjcEJQt+4IifuSOSTLBKJObJqMvaO1pIE8LQ==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0", + "@babel/helper-remap-async-to-generator": "^7.1.0", + "@babel/plugin-syntax-async-generators": "^7.2.0" + } + }, + "@babel/plugin-proposal-dynamic-import": { + "version": "7.5.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-proposal-dynamic-import/-/plugin-proposal-dynamic-import-7.5.0.tgz", + "integrity": "sha512-x/iMjggsKTFHYC6g11PL7Qy58IK8H5zqfm9e6hu4z1iH2IRyAp9u9dL80zA6R76yFovETFLKz2VJIC2iIPBuFw==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0", + "@babel/plugin-syntax-dynamic-import": "^7.2.0" + } + }, + "@babel/plugin-proposal-optional-catch-binding": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-proposal-optional-catch-binding/-/plugin-proposal-optional-catch-binding-7.2.0.tgz", + "integrity": "sha512-mgYj3jCcxug6KUcX4OBoOJz3CMrwRfQELPQ5560F70YQUBZB7uac9fqaWamKR1iWUzGiK2t0ygzjTScZnVz75g==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0", + "@babel/plugin-syntax-optional-catch-binding": "^7.2.0" + } + }, + "@babel/plugin-syntax-async-generators": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-async-generators/-/plugin-syntax-async-generators-7.2.0.tgz", + "integrity": 
"sha512-1ZrIRBv2t0GSlcwVoQ6VgSLpLgiN/FVQUzt9znxo7v2Ov4jJrs8RY8tv0wvDmFN3qIdMKWrmMMW6yZ0G19MfGg==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0" + } + }, + "@babel/plugin-syntax-dynamic-import": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-dynamic-import/-/plugin-syntax-dynamic-import-7.2.0.tgz", + "integrity": "sha512-mVxuJ0YroI/h/tbFTPGZR8cv6ai+STMKNBq0f8hFxsxWjl94qqhsb+wXbpNMDPU3cfR1TIsVFzU3nXyZMqyK4w==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0" + } + }, + "@babel/plugin-syntax-optional-catch-binding": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-optional-catch-binding/-/plugin-syntax-optional-catch-binding-7.2.0.tgz", + "integrity": "sha512-bDe4xKNhb0LI7IvZHiA13kff0KEfaGX/Hv4lMA9+7TEc63hMNvfKo6ZFpXhKuEp+II/q35Gc4NoMeDZyaUbj9w==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0" + } + }, + "@babel/plugin-transform-dotall-regex": { + "version": "7.4.4", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-dotall-regex/-/plugin-transform-dotall-regex-7.4.4.tgz", + "integrity": "sha512-P05YEhRc2h53lZDjRPk/OektxCVevFzZs2Gfjd545Wde3k+yFDbXORgl2e0xpbq8mLcKJ7Idss4fAg0zORN/zg==", + "dev": true, + "requires": { + "@babel/helper-plugin-utils": "^7.0.0", + "@babel/helper-regex": "^7.4.4", + "regexpu-core": "^4.5.4" + } + }, + "@babel/plugin-transform-modules-commonjs": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/plugin-transform-modules-commonjs/-/plugin-transform-modules-commonjs-7.6.0.tgz", + "integrity": "sha512-Ma93Ix95PNSEngqomy5LSBMAQvYKVe3dy+JlVJSHEXZR5ASL9lQBedMiCyVtmTLraIDVRE3ZjTZvmXXD2Ozw3g==", + "dev": true, + "requires": { + "@babel/helper-module-transforms": "^7.4.4", + "@babel/helper-plugin-utils": "^7.0.0", + "@babel/helper-simple-access": "^7.1.0", + "babel-plugin-dynamic-import-node": "^2.3.0" + } + }, + "@babel/template": { + "version": "7.6.0", + 
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.6.0.tgz", + "integrity": "sha512-5AEH2EXD8euCk446b7edmgFdub/qfH1SN6Nii3+fyXP807QRx9Q73A2N5hNwRRslC2H9sNzaFhsPubkS4L8oNQ==", + "dev": true, + "requires": { + "@babel/code-frame": "^7.0.0", + "@babel/parser": "^7.6.0", + "@babel/types": "^7.6.0" + } + }, + "@babel/traverse": { + "version": "7.6.0", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.6.0.tgz", + "integrity": "sha512-93t52SaOBgml/xY74lsmt7xOR4ufYvhb5c5qiM6lu4J/dWGMAfAh6eKw4PjLes6DI6nQgearoxnFJk60YchpvQ==", + "dev": true, + "requires": { + "@babel/code-frame": "^7.5.5", + "@babel/generator": "^7.6.0", + "@babel/helper-function-name": "^7.1.0", + "@babel/helper-split-export-declaration": "^7.4.4", + "@babel/parser": "^7.6.0", + "@babel/types": "^7.6.0", + "debug": "^4.1.0", + "globals": "^11.1.0", + "lodash": "^4.17.13" + } + }, + "@babel/types": { + "version": "7.6.1", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.6.1.tgz", + "integrity": "sha512-X7gdiuaCmA0uRjCmRtYJNAVCc/q+5xSgsfKJHqMN4iNLILX39677fJE1O40arPMh0TTtS9ItH67yre6c7k6t0g==", + "dev": true, + "requires": { + "esutils": "^2.0.2", + "lodash": "^4.17.13", + "to-fast-properties": "^2.0.0" + } + }, + "@concordance/react": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@concordance/react/-/react-2.0.0.tgz", + "integrity": "sha512-huLSkUuM2/P+U0uy2WwlKuixMsTODD8p4JVQBI4VKeopkiN0C7M3N9XYVawb4M+4spN5RrO/eLhk7KoQX6nsfA==", + "dev": true, + "requires": { + "arrify": "^1.0.1" + }, + "dependencies": { + "arrify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/arrify/-/arrify-1.0.1.tgz", + "integrity": "sha1-iYUI2iIm84DfkEcoRWhJwVAaSw0=", + "dev": true + } + } + }, + "@nodelib/fs.scandir": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.2.tgz", + "integrity": "sha512-wrIBsjA5pl13f0RN4Zx4FNWmU71lv03meGKnqRUoCyan17s4V3WL92f3w3AIuWbNnpcrQyFBU5qMavJoB8d27w==", + 
"dev": true, + "requires": { + "@nodelib/fs.stat": "2.0.2", + "run-parallel": "^1.1.9" + } + }, + "@nodelib/fs.stat": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.2.tgz", + "integrity": "sha512-z8+wGWV2dgUhLqrtRYa03yDx4HWMvXKi1z8g3m2JyxAx8F7xk74asqPk5LAETjqDSGLFML/6CDl0+yFunSYicw==", + "dev": true + }, + "@nodelib/fs.walk": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.3.tgz", + "integrity": "sha512-l6t8xEhfK9Sa4YO5mIRdau7XSOADfmh3jCr0evNHdY+HNkW6xuQhgMH7D73VV6WpZOagrW0UludvMTiifiwTfA==", + "dev": true, + "requires": { + "@nodelib/fs.scandir": "2.1.2", + "fastq": "^1.6.0" + } + }, + "@sindresorhus/is": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@sindresorhus/is/-/is-0.14.0.tgz", + "integrity": "sha512-9NET910DNaIPngYnLLPeg+Ogzqsi9uM4mSboU5y6p8S5DzMTVEsJZrawi+BoDNUVBa2DhJqQYUFvMDfgU062LQ==", + "dev": true + }, + "@szmarczak/http-timer": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@szmarczak/http-timer/-/http-timer-1.1.2.tgz", + "integrity": "sha512-XIB2XbzHTN6ieIjfIMV9hlVcfPU26s2vafYWQcZHWXHOxiaRZYEDKEwdl129Zyg50+foYV2jCgtrqSA6qNuNSA==", + "dev": true, + "requires": { + "defer-to-connect": "^1.0.1" + } + }, + "@types/events": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@types/events/-/events-3.0.0.tgz", + "integrity": "sha512-EaObqwIvayI5a8dCzhFrjKzVwKLxjoG9T6Ppd5CEo07LRKfQ8Yokw54r5+Wq7FaBQ+yXRvQAYPrHwya1/UFt9g==", + "dev": true + }, + "@types/glob": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/@types/glob/-/glob-7.1.1.tgz", + "integrity": "sha512-1Bh06cbWJUHMC97acuD6UMG29nMt0Aqz1vF3guLfG+kHHJhy3AyohZFFxYk2f7Q1SQIrNwvncxAE0N/9s70F2w==", + "dev": true, + "requires": { + "@types/events": "*", + "@types/minimatch": "*", + "@types/node": "*" + } + }, + "@types/minimatch": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/@types/minimatch/-/minimatch-3.0.3.tgz", + 
"integrity": "sha512-tHq6qdbT9U1IRSGf14CL0pUlULksvY9OZ+5eEgl1N7t+OA3tGvNpxJCzuKQlsNgCVwbAs670L1vcVQi8j9HjnA==", + "dev": true + }, + "@types/node": { + "version": "12.7.5", + "resolved": "https://registry.npmjs.org/@types/node/-/node-12.7.5.tgz", + "integrity": "sha512-9fq4jZVhPNW8r+UYKnxF1e2HkDWOWKM5bC2/7c9wPV835I0aOrVbS/Hw/pWPk2uKrNXQqg9Z959Kz+IYDd5p3w==", + "dev": true + }, + "ansi-align": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/ansi-align/-/ansi-align-3.0.0.tgz", + "integrity": "sha512-ZpClVKqXN3RGBmKibdfWzqCY4lnjEuoNzU5T0oEFpfd/z5qJHVarukridD4juLO2FXMiwUQxr9WqQtaYa8XRYw==", + "dev": true, + "requires": { + "string-width": "^3.0.0" + }, + "dependencies": { + "emoji-regex": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-7.0.3.tgz", + "integrity": "sha512-CwBLREIQ7LvYFB0WyRvwhq5N5qPhc6PMjD6bYggFlI5YyDgl+0vxq5VHbMOFqLg7hfWzmu8T5Z1QofhmTIhItA==", + "dev": true + }, + "string-width": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-3.1.0.tgz", + "integrity": "sha512-vafcv6KjVZKSgz06oM/H6GDBrAtz8vdhQakGjFIvNrHA6y3HCF1CInLy+QLq8dTJPQ1b+KDUqDFctkdRW44e1w==", + "dev": true, + "requires": { + "emoji-regex": "^7.0.1", + "is-fullwidth-code-point": "^2.0.0", + "strip-ansi": "^5.1.0" + } + } + } + }, + "ansi-escapes": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/ansi-escapes/-/ansi-escapes-4.2.1.tgz", + "integrity": "sha512-Cg3ymMAdN10wOk/VYfLV7KCQyv7EDirJ64500sU7n9UlmioEtDuU5Gd+hj73hXSU/ex7tHJSssmyftDdkMLO8Q==", + "dev": true, + "requires": { + "type-fest": "^0.5.2" + } + }, + "ansi-regex": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-4.1.0.tgz", + "integrity": "sha512-1apePfXM1UOSqw0o9IiFAovVz9M5S1Dg+4TrDwfMewQ6p/rmMueb7tWZjQ1rx4Loy1ArBggoqGpfqqdI4rondg==", + "dev": true + }, + "ansi-styles": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.1.0.tgz", 
+ "integrity": "sha512-Qts4KCLKG+waHc9C4m07weIY8qyeixoS0h6RnbsNVD6Fw+pEZGW3vTyObL3WXpE09Mq4Oi7/lBEyLmOiLtlYWQ==", + "dev": true, + "requires": { + "color-convert": "^2.0.1" + }, + "dependencies": { + "color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "requires": { + "color-name": "~1.1.4" + } + }, + "color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + } + } + }, + "anymatch": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.0.tgz", + "integrity": "sha512-Ozz7l4ixzI7Oxj2+cw+p0tVUt27BpaJ+1+q1TCeANWxHpvyn2+Un+YamBdfKu0uh8xLodGhoa1v7595NhKDAuA==", + "dev": true, + "requires": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + } + }, + "argparse": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz", + "integrity": "sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==", + "dev": true, + "requires": { + "sprintf-js": "~1.0.2" + } + }, + "arr-flatten": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/arr-flatten/-/arr-flatten-1.1.0.tgz", + "integrity": "sha512-L3hKV5R/p5o81R7O02IGnwpDmkp6E982XhtbuwSe3O4qOtMMMtodicASA1Cny2U+aCXcNpml+m4dPsvsJ3jatg==", + "dev": true + }, + "array-find-index": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/array-find-index/-/array-find-index-1.0.2.tgz", + "integrity": "sha1-3wEKoSh+Fku9pvlyOwqWoexBh6E=", + "dev": true + }, + "array-union": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", + "integrity": 
"sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", + "dev": true + }, + "array-uniq": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/array-uniq/-/array-uniq-2.1.0.tgz", + "integrity": "sha512-bdHxtev7FN6+MXI1YFW0Q8mQ8dTJc2S8AMfju+ZR77pbg2yAdVyDlwkaUI7Har0LyOMRFPHrJ9lYdyjZZswdlQ==", + "dev": true + }, + "arrify": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/arrify/-/arrify-2.0.1.tgz", + "integrity": "sha512-3duEwti880xqi4eAMN8AyR4a0ByT90zoYdLlevfrvU43vb0YZwZVfxOgxWrLXXXpyugL0hNZc9G6BiB5B3nUug==", + "dev": true + }, + "astral-regex": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/astral-regex/-/astral-regex-1.0.0.tgz", + "integrity": "sha512-+Ryf6g3BKoRc7jfp7ad8tM4TtMiaWvbF/1/sQcZPkkS7ag3D5nMBCe2UfOTONtAkaG0tO0ij3C5Lwmf1EiyjHg==", + "dev": true + }, + "ava": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/ava/-/ava-2.3.0.tgz", + "integrity": "sha512-4VaaSnl13vpTZmqW3aMqioSolT0/ozRkjQxTLi3p8wtyRONuX/uLKL3uF0j50w2BNRoLsJqztnkX2h8xeVp2lg==", + "dev": true, + "requires": { + "@ava/babel-preset-stage-4": "^4.0.0", + "@ava/babel-preset-transform-test-files": "^6.0.0", + "@babel/core": "^7.5.5", + "@babel/generator": "^7.5.5", + "@concordance/react": "^2.0.0", + "ansi-escapes": "^4.2.1", + "ansi-styles": "^4.0.0", + "arr-flatten": "^1.1.0", + "array-union": "^2.1.0", + "array-uniq": "^2.1.0", + "arrify": "^2.0.1", + "bluebird": "^3.5.5", + "chalk": "^2.4.2", + "chokidar": "^3.0.2", + "chunkd": "^1.0.0", + "ci-parallel-vars": "^1.0.0", + "clean-stack": "^2.2.0", + "clean-yaml-object": "^0.1.0", + "cli-cursor": "^3.1.0", + "cli-truncate": "^2.0.0", + "code-excerpt": "^2.1.1", + "common-path-prefix": "^1.0.0", + "concordance": "^4.0.0", + "convert-source-map": "^1.6.0", + "currently-unhandled": "^0.4.1", + "debug": "^4.1.1", + "del": "^4.1.1", + "dot-prop": "^5.1.0", + "emittery": "^0.4.1", + "empower-core": "^1.2.0", + "equal-length": "^1.0.0", + 
"escape-string-regexp": "^2.0.0", + "esm": "^3.2.25", + "figures": "^3.0.0", + "find-up": "^4.1.0", + "get-port": "^5.0.0", + "globby": "^10.0.1", + "ignore-by-default": "^1.0.0", + "import-local": "^3.0.2", + "indent-string": "^4.0.0", + "is-ci": "^2.0.0", + "is-error": "^2.2.2", + "is-observable": "^2.0.0", + "is-plain-object": "^3.0.0", + "is-promise": "^2.1.0", + "lodash": "^4.17.15", + "loud-rejection": "^2.1.0", + "make-dir": "^3.0.0", + "matcher": "^2.0.0", + "md5-hex": "^3.0.1", + "meow": "^5.0.0", + "micromatch": "^4.0.2", + "ms": "^2.1.2", + "observable-to-promise": "^1.0.0", + "ora": "^3.4.0", + "package-hash": "^4.0.0", + "pkg-conf": "^3.1.0", + "plur": "^3.1.1", + "pretty-ms": "^5.0.0", + "require-precompiled": "^0.1.0", + "resolve-cwd": "^3.0.0", + "slash": "^3.0.0", + "source-map-support": "^0.5.13", + "stack-utils": "^1.0.2", + "strip-ansi": "^5.2.0", + "strip-bom-buf": "^2.0.0", + "supertap": "^1.0.0", + "supports-color": "^7.0.0", + "trim-off-newlines": "^1.0.1", + "trim-right": "^1.0.1", + "unique-temp-dir": "^1.0.0", + "update-notifier": "^3.0.1", + "write-file-atomic": "^3.0.0" + } + }, + "babel-plugin-dynamic-import-node": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/babel-plugin-dynamic-import-node/-/babel-plugin-dynamic-import-node-2.3.0.tgz", + "integrity": "sha512-o6qFkpeQEBxcqt0XYlWzAVxNCSCZdUgcR8IRlhD/8DylxjjO4foPcvTW0GGKa/cVt3rvxZ7o5ippJ+/0nvLhlQ==", + "dev": true, + "requires": { + "object.assign": "^4.1.0" + } + }, + "babel-plugin-espower": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/babel-plugin-espower/-/babel-plugin-espower-3.0.1.tgz", + "integrity": "sha512-Ms49U7VIAtQ/TtcqRbD6UBmJBUCSxiC3+zPc+eGqxKUIFO1lTshyEDRUjhoAbd2rWfwYf3cZ62oXozrd8W6J0A==", + "dev": true, + "requires": { + "@babel/generator": "^7.0.0", + "@babel/parser": "^7.0.0", + "call-matcher": "^1.0.0", + "core-js": "^2.0.0", + "espower-location-detector": "^1.0.0", + "espurify": "^1.6.0", + "estraverse": "^4.1.1" + } + }, + 
"balanced-match": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.0.tgz", + "integrity": "sha1-ibTRmasr7kneFk6gK4nORi1xt2c=", + "dev": true + }, + "binary-extensions": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.0.0.tgz", + "integrity": "sha512-Phlt0plgpIIBOGTT/ehfFnbNlfsDEiqmzE2KRXoX1bLIlir4X/MR+zSyBEkL05ffWgnRSf/DXv+WrUAVr93/ow==", + "dev": true + }, + "bluebird": { + "version": "3.5.5", + "resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.5.5.tgz", + "integrity": "sha512-5am6HnnfN+urzt4yfg7IgTbotDjIT/u8AJpEt0sIU9FtXfVeezXAPKswrG+xKUCOYAINpSdgZVDU6QFh+cuH3w==", + "dev": true + }, + "blueimp-md5": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/blueimp-md5/-/blueimp-md5-2.12.0.tgz", + "integrity": "sha512-zo+HIdIhzojv6F1siQPqPFROyVy7C50KzHv/k/Iz+BtvtVzSHXiMXOpq2wCfNkeBqdCv+V8XOV96tsEt2W/3rQ==", + "dev": true + }, + "boxen": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/boxen/-/boxen-3.2.0.tgz", + "integrity": "sha512-cU4J/+NodM3IHdSL2yN8bqYqnmlBTidDR4RC7nJs61ZmtGz8VZzM3HLQX0zY5mrSmPtR3xWwsq2jOUQqFZN8+A==", + "dev": true, + "requires": { + "ansi-align": "^3.0.0", + "camelcase": "^5.3.1", + "chalk": "^2.4.2", + "cli-boxes": "^2.2.0", + "string-width": "^3.0.0", + "term-size": "^1.2.0", + "type-fest": "^0.3.0", + "widest-line": "^2.0.0" + }, + "dependencies": { + "camelcase": { + "version": "5.3.1", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.3.1.tgz", + "integrity": "sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==", + "dev": true + }, + "emoji-regex": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-7.0.3.tgz", + "integrity": "sha512-CwBLREIQ7LvYFB0WyRvwhq5N5qPhc6PMjD6bYggFlI5YyDgl+0vxq5VHbMOFqLg7hfWzmu8T5Z1QofhmTIhItA==", + "dev": true + }, + "string-width": { + "version": "3.1.0", + 
"resolved": "https://registry.npmjs.org/string-width/-/string-width-3.1.0.tgz", + "integrity": "sha512-vafcv6KjVZKSgz06oM/H6GDBrAtz8vdhQakGjFIvNrHA6y3HCF1CInLy+QLq8dTJPQ1b+KDUqDFctkdRW44e1w==", + "dev": true, + "requires": { + "emoji-regex": "^7.0.1", + "is-fullwidth-code-point": "^2.0.0", + "strip-ansi": "^5.1.0" + } + }, + "type-fest": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.3.1.tgz", + "integrity": "sha512-cUGJnCdr4STbePCgqNFbpVNCepa+kAVohJs1sLhxzdH+gnEoOd8VhbYa7pD3zZYGiURWM2xzEII3fQcRizDkYQ==", + "dev": true + } + } + }, + "brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dev": true, + "requires": { + "fill-range": "^7.0.1" + } + }, + "buffer-from": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.1.tgz", + "integrity": "sha512-MQcXEUbCKtEo7bhqEs6560Hyd4XaovZlO/k9V3hjVUF/zwW7KBVdSK4gIt/bzwS9MbR5qob+F5jusZsb0YQK2A==", + "dev": true + }, + "cacheable-request": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/cacheable-request/-/cacheable-request-6.1.0.tgz", + "integrity": "sha512-Oj3cAGPCqOZX7Rz64Uny2GYAZNliQSqfbePrgAQ1wKAihYmCUnraBtJtKcGR4xz7wF+LoJC+ssFZvv5BgF9Igg==", + "dev": true, + "requires": { + "clone-response": "^1.0.2", + "get-stream": "^5.1.0", + "http-cache-semantics": "^4.0.0", + "keyv": "^3.0.0", + "lowercase-keys": "^2.0.0", + "normalize-url": "^4.1.0", + "responselike": "^1.0.2" + }, + "dependencies": { + "get-stream": { + 
"version": "5.1.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.1.0.tgz", + "integrity": "sha512-EXr1FOzrzTfGeL0gQdeFEvOMm2mzMOglyiOXSTpPC+iAjAKftbr3jpCMWynogwYnM+eSj9sHGc6wjIcDvYiygw==", + "dev": true, + "requires": { + "pump": "^3.0.0" + } + }, + "lowercase-keys": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-2.0.0.tgz", + "integrity": "sha512-tqNXrS78oMOE73NMxK4EMLQsQowWf8jKooH9g7xPavRT706R6bkQJ6DY2Te7QukaZsulxa30wQ7bk0pm4XiHmA==", + "dev": true + } + } + }, + "call-matcher": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/call-matcher/-/call-matcher-1.1.0.tgz", + "integrity": "sha512-IoQLeNwwf9KTNbtSA7aEBb1yfDbdnzwjCetjkC8io5oGeOmK2CBNdg0xr+tadRYKO0p7uQyZzvon0kXlZbvGrw==", + "dev": true, + "requires": { + "core-js": "^2.0.0", + "deep-equal": "^1.0.0", + "espurify": "^1.6.0", + "estraverse": "^4.0.0" + } + }, + "call-signature": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/call-signature/-/call-signature-0.0.2.tgz", + "integrity": "sha1-qEq8glpV70yysCi9dOIFpluaSZY=", + "dev": true + }, + "camelcase": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-4.1.0.tgz", + "integrity": "sha1-1UVjW+HjPFQmScaRc+Xeas+uNN0=", + "dev": true + }, + "camelcase-keys": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-4.2.0.tgz", + "integrity": "sha1-oqpfsa9oh1glnDLBQUJteJI7m3c=", + "dev": true, + "requires": { + "camelcase": "^4.1.0", + "map-obj": "^2.0.0", + "quick-lru": "^1.0.0" + } + }, + "chalk": { + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz", + "integrity": "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==", + "dev": true, + "requires": { + "ansi-styles": "^3.2.1", + "escape-string-regexp": "^1.0.5", + "supports-color": "^5.3.0" + }, + "dependencies": { + "ansi-styles": { + "version": 
"3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "requires": { + "color-convert": "^1.9.0" + } + }, + "escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ=", + "dev": true + }, + "supports-color": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-5.5.0.tgz", + "integrity": "sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==", + "dev": true, + "requires": { + "has-flag": "^3.0.0" + } + } + } + }, + "chokidar": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.0.2.tgz", + "integrity": "sha512-c4PR2egjNjI1um6bamCQ6bUNPDiyofNQruHvKgHQ4gDUP/ITSVSzNsiI5OWtHOsX323i5ha/kk4YmOZ1Ktg7KA==", + "dev": true, + "requires": { + "anymatch": "^3.0.1", + "braces": "^3.0.2", + "fsevents": "^2.0.6", + "glob-parent": "^5.0.0", + "is-binary-path": "^2.1.0", + "is-glob": "^4.0.1", + "normalize-path": "^3.0.0", + "readdirp": "^3.1.1" + } + }, + "chunkd": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/chunkd/-/chunkd-1.0.0.tgz", + "integrity": "sha512-xx3Pb5VF9QaqCotolyZ1ywFBgyuJmu6+9dLiqBxgelEse9Xsr3yUlpoX3O4Oh11M00GT2kYMsRByTKIMJW2Lkg==", + "dev": true + }, + "ci-info": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ci-info/-/ci-info-2.0.0.tgz", + "integrity": "sha512-5tK7EtrZ0N+OLFMthtqOj4fI2Jeb88C4CAZPu25LDVUgXJ0A3Js4PMGqrn0JU1W0Mh1/Z8wZzYPxqUrXeBboCQ==", + "dev": true + }, + "ci-parallel-vars": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/ci-parallel-vars/-/ci-parallel-vars-1.0.0.tgz", + "integrity": 
"sha512-u6dx20FBXm+apMi+5x7UVm6EH7BL1gc4XrcnQewjcB7HWRcor/V5qWc3RG2HwpgDJ26gIi2DSEu3B7sXynAw/g==", + "dev": true + }, + "clean-stack": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", + "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", + "dev": true + }, + "clean-yaml-object": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/clean-yaml-object/-/clean-yaml-object-0.1.0.tgz", + "integrity": "sha1-Y/sRDcLOGoTcIfbZM0h20BCui2g=", + "dev": true + }, + "cli-boxes": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/cli-boxes/-/cli-boxes-2.2.0.tgz", + "integrity": "sha512-gpaBrMAizVEANOpfZp/EEUixTXDyGt7DFzdK5hU+UbWt/J0lB0w20ncZj59Z9a93xHb9u12zF5BS6i9RKbtg4w==", + "dev": true + }, + "cli-cursor": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-3.1.0.tgz", + "integrity": "sha512-I/zHAwsKf9FqGoXM4WWRACob9+SNukZTd94DWF57E4toouRulbCxcUh6RKUEOQlYTHJnzkPMySvPNaaSLNfLZw==", + "dev": true, + "requires": { + "restore-cursor": "^3.1.0" + } + }, + "cli-spinners": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/cli-spinners/-/cli-spinners-2.2.0.tgz", + "integrity": "sha512-tgU3fKwzYjiLEQgPMD9Jt+JjHVL9kW93FiIMX/l7rivvOD4/LL0Mf7gda3+4U2KJBloybwgj5KEoQgGRioMiKQ==", + "dev": true + }, + "cli-truncate": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/cli-truncate/-/cli-truncate-2.0.0.tgz", + "integrity": "sha512-C4hp+8GCIFVsUUiXcw+ce+7wexVWImw8rQrgMBFsqerx9LvvcGlwm6sMjQYAEmV/Xb87xc1b5Ttx505MSpZVqg==", + "dev": true, + "requires": { + "slice-ansi": "^2.1.0", + "string-width": "^4.1.0" + } + }, + "clone": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/clone/-/clone-1.0.4.tgz", + "integrity": "sha1-2jCcwmPfFZlMaIypAheco8fNfH4=", + "dev": true + }, + "clone-response": { + "version": "1.0.2", + "resolved": 
"https://registry.npmjs.org/clone-response/-/clone-response-1.0.2.tgz", + "integrity": "sha1-0dyXOSAxTfZ/vrlCI7TuNQI56Ws=", + "dev": true, + "requires": { + "mimic-response": "^1.0.0" + } + }, + "code-excerpt": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/code-excerpt/-/code-excerpt-2.1.1.tgz", + "integrity": "sha512-tJLhH3EpFm/1x7heIW0hemXJTUU5EWl2V0EIX558jp05Mt1U6DVryCgkp3l37cxqs+DNbNgxG43SkwJXpQ14Jw==", + "dev": true, + "requires": { + "convert-to-spaces": "^1.0.1" + } + }, + "color-convert": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz", + "integrity": "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==", + "dev": true, + "requires": { + "color-name": "1.1.3" + } + }, + "color-name": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz", + "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=", + "dev": true + }, + "common-path-prefix": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/common-path-prefix/-/common-path-prefix-1.0.0.tgz", + "integrity": "sha1-zVL28HEuC6q5fW+XModPIvR3UsA=", + "dev": true + }, + "concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha1-2Klr13/Wjfd5OnMDajug1UBdR3s=", + "dev": true + }, + "concordance": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/concordance/-/concordance-4.0.0.tgz", + "integrity": "sha512-l0RFuB8RLfCS0Pt2Id39/oCPykE01pyxgAFypWTlaGRgvLkZrtczZ8atEHpTeEIW+zYWXTBuA9cCSeEOScxReQ==", + "dev": true, + "requires": { + "date-time": "^2.1.0", + "esutils": "^2.0.2", + "fast-diff": "^1.1.2", + "js-string-escape": "^1.0.1", + "lodash.clonedeep": "^4.5.0", + "lodash.flattendeep": "^4.4.0", + "lodash.islength": "^4.0.1", + "lodash.merge": "^4.6.1", + "md5-hex": "^2.0.0", + "semver": "^5.5.1", + "well-known-symbols": "^2.0.0" + }, + "dependencies": { + 
"md5-hex": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/md5-hex/-/md5-hex-2.0.0.tgz", + "integrity": "sha1-0FiOnxx0lUSS7NJKwKxs6ZfZLjM=", + "dev": true, + "requires": { + "md5-o-matic": "^0.1.1" + } + } + } + }, + "configstore": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/configstore/-/configstore-4.0.0.tgz", + "integrity": "sha512-CmquAXFBocrzaSM8mtGPMM/HiWmyIpr4CcJl/rgY2uCObZ/S7cKU0silxslqJejl+t/T9HS8E0PUNQD81JGUEQ==", + "dev": true, + "requires": { + "dot-prop": "^4.1.0", + "graceful-fs": "^4.1.2", + "make-dir": "^1.0.0", + "unique-string": "^1.0.0", + "write-file-atomic": "^2.0.0", + "xdg-basedir": "^3.0.0" + }, + "dependencies": { + "dot-prop": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/dot-prop/-/dot-prop-4.2.0.tgz", + "integrity": "sha512-tUMXrxlExSW6U2EXiiKGSBVdYgtV8qlHL+C10TsW4PURY/ic+eaysnSkwB4kA/mBlCyy/IKDJ+Lc3wbWeaXtuQ==", + "dev": true, + "requires": { + "is-obj": "^1.0.0" + } + }, + "is-obj": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-obj/-/is-obj-1.0.1.tgz", + "integrity": "sha1-PkcprB9f3gJc19g6iW2rn09n2w8=", + "dev": true + }, + "make-dir": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-1.3.0.tgz", + "integrity": "sha512-2w31R7SJtieJJnQtGc7RVL2StM2vGYVfqUOvUDxH6bC6aJTxPxTF0GnIgCyu7tjockiUWAYQRbxa7vKn34s5sQ==", + "dev": true, + "requires": { + "pify": "^3.0.0" + } + }, + "pify": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-3.0.0.tgz", + "integrity": "sha1-5aSs0sEB/fPZpNB/DbxNtJ3SgXY=", + "dev": true + }, + "write-file-atomic": { + "version": "2.4.3", + "resolved": "https://registry.npmjs.org/write-file-atomic/-/write-file-atomic-2.4.3.tgz", + "integrity": "sha512-GaETH5wwsX+GcnzhPgKcKjJ6M2Cq3/iZp1WyY/X1CSqrW+jVNM9Y7D8EC2sM4ZG/V8wZlSniJnCKWPmBYAucRQ==", + "dev": true, + "requires": { + "graceful-fs": "^4.1.11", + "imurmurhash": "^0.1.4", + "signal-exit": "^3.0.2" + } + } + } + }, + 
"convert-source-map": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-1.6.0.tgz", + "integrity": "sha512-eFu7XigvxdZ1ETfbgPBohgyQ/Z++C0eEhTor0qRwBw9unw+L0/6V8wkSuGgzdThkiS5lSpdptOQPD8Ak40a+7A==", + "dev": true, + "requires": { + "safe-buffer": "~5.1.1" + } + }, + "convert-to-spaces": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/convert-to-spaces/-/convert-to-spaces-1.0.2.tgz", + "integrity": "sha1-fj5Iu+bZl7FBfdyihoIEtNPYVxU=", + "dev": true + }, + "core-js": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/core-js/-/core-js-2.6.9.tgz", + "integrity": "sha512-HOpZf6eXmnl7la+cUdMnLvUxKNqLUzJvgIziQ0DiF3JwSImNphIqdGqzj6hIKyX04MmV0poclQ7+wjWvxQyR2A==", + "dev": true + }, + "core-util-is": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", + "integrity": "sha1-tf1UIgqivFq1eqtxQMlAdUUDwac=" + }, + "cross-spawn": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-5.1.0.tgz", + "integrity": "sha1-6L0O/uWPz/b4+UUQoKVUu/ojVEk=", + "dev": true, + "requires": { + "lru-cache": "^4.0.1", + "shebang-command": "^1.2.0", + "which": "^1.2.9" + } + }, + "crypto-random-string": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/crypto-random-string/-/crypto-random-string-1.0.0.tgz", + "integrity": "sha1-ojD2T1aDEOFJgAmUB5DsmVRbyn4=", + "dev": true + }, + "currently-unhandled": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/currently-unhandled/-/currently-unhandled-0.4.1.tgz", + "integrity": "sha1-mI3zP+qxke95mmE2nddsF635V+o=", + "dev": true, + "requires": { + "array-find-index": "^1.0.1" + } + }, + "date-time": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/date-time/-/date-time-2.1.0.tgz", + "integrity": "sha512-/9+C44X7lot0IeiyfgJmETtRMhBidBYM2QFFIkGa0U1k+hSyY87Nw7PY3eDqpvCBm7I3WCSfPeZskW/YYq6m4g==", + "dev": true, + "requires": { + 
"time-zone": "^1.0.0" + } + }, + "debug": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.1.1.tgz", + "integrity": "sha512-pYAIzeRo8J6KPEaJ0VWOh5Pzkbw/RetuzehGM7QRRX5he4fPHx2rdKMB256ehJCkX+XRQm16eZLqLNS8RSZXZw==", + "dev": true, + "requires": { + "ms": "^2.1.1" + } + }, + "decamelize": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/decamelize/-/decamelize-1.2.0.tgz", + "integrity": "sha1-9lNNFRSCabIDUue+4m9QH5oZEpA=", + "dev": true + }, + "decamelize-keys": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/decamelize-keys/-/decamelize-keys-1.1.0.tgz", + "integrity": "sha1-0XGoeTMlKAfrPLYdwcFEXQeN8tk=", + "dev": true, + "requires": { + "decamelize": "^1.1.0", + "map-obj": "^1.0.0" + }, + "dependencies": { + "map-obj": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/map-obj/-/map-obj-1.0.1.tgz", + "integrity": "sha1-2TPOuSBdgr3PSIb2dCvcK03qFG0=", + "dev": true + } + } + }, + "decompress-response": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-3.3.0.tgz", + "integrity": "sha1-gKTdMjdIOEv6JICDYirt7Jgq3/M=", + "dev": true, + "requires": { + "mimic-response": "^1.0.0" + } + }, + "deep-equal": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/deep-equal/-/deep-equal-1.1.0.tgz", + "integrity": "sha512-ZbfWJq/wN1Z273o7mUSjILYqehAktR2NVoSrOukDkU9kg2v/Uv89yU4Cvz8seJeAmtN5oqiefKq8FPuXOboqLw==", + "dev": true, + "requires": { + "is-arguments": "^1.0.4", + "is-date-object": "^1.0.1", + "is-regex": "^1.0.4", + "object-is": "^1.0.1", + "object-keys": "^1.1.1", + "regexp.prototype.flags": "^1.2.0" + } + }, + "deep-extend": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz", + "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==", + "dev": true + }, + "defaults": { + "version": "1.0.3", + "resolved": 
"https://registry.npmjs.org/defaults/-/defaults-1.0.3.tgz", + "integrity": "sha1-xlYFHpgX2f8I7YgUd/P+QBnz730=", + "dev": true, + "requires": { + "clone": "^1.0.2" + } + }, + "defer-to-connect": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/defer-to-connect/-/defer-to-connect-1.0.2.tgz", + "integrity": "sha512-k09hcQcTDY+cwgiwa6PYKLm3jlagNzQ+RSvhjzESOGOx+MNOuXkxTfEvPrO1IOQ81tArCFYQgi631clB70RpQw==", + "dev": true + }, + "define-properties": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.1.3.tgz", + "integrity": "sha512-3MqfYKj2lLzdMSf8ZIZE/V+Zuy+BgD6f164e8K2w7dgnpKArBDerGYpM46IYYcjnkdPNMjPk9A6VFB8+3SKlXQ==", + "dev": true, + "requires": { + "object-keys": "^1.0.12" + } + }, + "del": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/del/-/del-4.1.1.tgz", + "integrity": "sha512-QwGuEUouP2kVwQenAsOof5Fv8K9t3D8Ca8NxcXKrIpEHjTXK5J2nXLdP+ALI1cgv8wj7KuwBhTwBkOZSJKM5XQ==", + "dev": true, + "requires": { + "@types/glob": "^7.1.1", + "globby": "^6.1.0", + "is-path-cwd": "^2.0.0", + "is-path-in-cwd": "^2.0.0", + "p-map": "^2.0.0", + "pify": "^4.0.1", + "rimraf": "^2.6.3" + }, + "dependencies": { + "array-union": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-1.0.2.tgz", + "integrity": "sha1-mjRBDk9OPaI96jdb5b5w8kd47Dk=", + "dev": true, + "requires": { + "array-uniq": "^1.0.1" + } + }, + "array-uniq": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/array-uniq/-/array-uniq-1.0.3.tgz", + "integrity": "sha1-r2rId6Jcx/dOBYiUdThY39sk/bY=", + "dev": true + }, + "globby": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-6.1.0.tgz", + "integrity": "sha1-9abXDoOV4hyFj7BInWTfAkJNUGw=", + "dev": true, + "requires": { + "array-union": "^1.0.1", + "glob": "^7.0.3", + "object-assign": "^4.0.1", + "pify": "^2.0.0", + "pinkie-promise": "^2.0.0" + }, + "dependencies": { + "pify": { + "version": "2.3.0", + 
"resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha1-7RQaasBDqEnqWISY59yosVMw6Qw=", + "dev": true + } + } + } + } + }, + "dir-glob": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", + "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", + "dev": true, + "requires": { + "path-type": "^4.0.0" + } + }, + "dot-prop": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/dot-prop/-/dot-prop-5.1.0.tgz", + "integrity": "sha512-n1oC6NBF+KM9oVXtjmen4Yo7HyAVWV2UUl50dCYJdw2924K6dX9bf9TTTWaKtYlRn0FEtxG27KS80ayVLixxJA==", + "dev": true, + "requires": { + "is-obj": "^2.0.0" + } + }, + "duplexer3": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/duplexer3/-/duplexer3-0.1.4.tgz", + "integrity": "sha1-7gHdHKwO08vH/b6jfcCo8c4ALOI=", + "dev": true + }, + "emittery": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/emittery/-/emittery-0.4.1.tgz", + "integrity": "sha512-r4eRSeStEGf6M5SKdrQhhLK5bOwOBxQhIE3YSTnZE3GpKiLfnnhE+tPtrJE79+eDJgm39BM6LSoI8SCx4HbwlQ==", + "dev": true + }, + "emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "empower-core": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/empower-core/-/empower-core-1.2.0.tgz", + "integrity": "sha512-g6+K6Geyc1o6FdXs9HwrXleCFan7d66G5xSCfSF7x1mJDCes6t0om9lFQG3zOrzh3Bkb/45N0cZ5Gqsf7YrzGQ==", + "dev": true, + "requires": { + "call-signature": "0.0.2", + "core-js": "^2.0.0" + } + }, + "end-of-stream": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.1.tgz", + "integrity": "sha512-1MkrZNvWTKCaigbn+W15elq2BB/L22nqrSY5DKlo3X6+vclJm8Bb5djXJBmEX6fS3+zCh/F4VBK5Z2KxJt4s2Q==", + "dev": true, + 
"requires": { + "once": "^1.4.0" + } + }, + "equal-length": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/equal-length/-/equal-length-1.0.1.tgz", + "integrity": "sha1-IcoRLUirJLTh5//A5TOdMf38J0w=", + "dev": true + }, + "error-ex": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", + "integrity": "sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==", + "dev": true, + "requires": { + "is-arrayish": "^0.2.1" + } + }, + "es6-error": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", + "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==", + "dev": true + }, + "escape-string-regexp": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz", + "integrity": "sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==", + "dev": true + }, + "esm": { + "version": "3.2.25", + "resolved": "https://registry.npmjs.org/esm/-/esm-3.2.25.tgz", + "integrity": "sha512-U1suiZ2oDVWv4zPO56S0NcR5QriEahGtdN2OR6FiOG4WJvcjBVFB0qI4+eKoWFH483PKGuLuu6V8Z4T5g63UVA==", + "dev": true + }, + "espower-location-detector": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/espower-location-detector/-/espower-location-detector-1.0.0.tgz", + "integrity": "sha1-oXt+zFnTDheeK+9z+0E3cEyzMbU=", + "dev": true, + "requires": { + "is-url": "^1.2.1", + "path-is-absolute": "^1.0.0", + "source-map": "^0.5.0", + "xtend": "^4.0.0" + } + }, + "esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "dev": true + }, + "espurify": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/espurify/-/espurify-1.8.1.tgz", 
+ "integrity": "sha512-ZDko6eY/o+D/gHCWyHTU85mKDgYcS4FJj7S+YD6WIInm7GQ6AnOjmcL4+buFV/JOztVLELi/7MmuGU5NHta0Mg==", + "dev": true, + "requires": { + "core-js": "^2.0.0" + } + }, + "estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true + }, + "esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true + }, + "execa": { + "version": "0.7.0", + "resolved": "https://registry.npmjs.org/execa/-/execa-0.7.0.tgz", + "integrity": "sha1-lEvs00zEHuMqY6n68nrVpl/Fl3c=", + "dev": true, + "requires": { + "cross-spawn": "^5.0.1", + "get-stream": "^3.0.0", + "is-stream": "^1.1.0", + "npm-run-path": "^2.0.0", + "p-finally": "^1.0.0", + "signal-exit": "^3.0.0", + "strip-eof": "^1.0.0" + } + }, + "fast-diff": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/fast-diff/-/fast-diff-1.2.0.tgz", + "integrity": "sha512-xJuoT5+L99XlZ8twedaRf6Ax2TgQVxvgZOYoPKqZufmJib0tL2tegPBOZb1pVNgIhlqDlA0eO0c3wBvQcmzx4w==", + "dev": true + }, + "fast-glob": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.0.4.tgz", + "integrity": "sha512-wkIbV6qg37xTJwqSsdnIphL1e+LaGz4AIQqr00mIubMaEhv1/HEmJ0uuCGZRNRUkZZmOB5mJKO0ZUTVq+SxMQg==", + "dev": true, + "requires": { + "@nodelib/fs.stat": "^2.0.1", + "@nodelib/fs.walk": "^1.2.1", + "glob-parent": "^5.0.0", + "is-glob": "^4.0.1", + "merge2": "^1.2.3", + "micromatch": "^4.0.2" + } + }, + "fastq": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.6.0.tgz", + "integrity": "sha512-jmxqQ3Z/nXoeyDmWAzF9kH1aGZSis6e/SbfPmJpUnyZ0ogr6iscHQaml4wsEepEWSdtmpy+eVXmCRIMpxaXqOA==", + "dev": true, + "requires": { + "reusify": "^1.0.0" + } + 
}, + "figures": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/figures/-/figures-3.0.0.tgz", + "integrity": "sha512-HKri+WoWoUgr83pehn/SIgLOMZ9nAWC6dcGj26RY2R4F50u4+RTUz0RCrUlOV3nKRAICW1UGzyb+kcX2qK1S/g==", + "dev": true, + "requires": { + "escape-string-regexp": "^1.0.5" + }, + "dependencies": { + "escape-string-regexp": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", + "integrity": "sha1-G2HAViGQqN/2rjuyzwIAyhMLhtQ=", + "dev": true + } + } + }, + "fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dev": true, + "requires": { + "to-regex-range": "^5.0.1" + } + }, + "find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "requires": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + } + }, + "fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha1-FQStJSMVjKpA20onh8sBQRmU6k8=", + "dev": true + }, + "fsevents": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.0.7.tgz", + "integrity": "sha512-a7YT0SV3RB+DjYcppwVDLtn13UQnmg0SWZS7ezZD0UjnLwXmy8Zm21GMVGLaFGimIqcvyMQaOJBrop8MyOp1kQ==", + "dev": true, + "optional": true + }, + "function-bind": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", + "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "dev": true + }, + "get-port": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/get-port/-/get-port-5.0.0.tgz", + 
"integrity": "sha512-imzMU0FjsZqNa6BqOjbbW6w5BivHIuQKopjpPqcnx0AVHJQKCxK1O+Ab3OrVXhrekqfVMjwA9ZYu062R+KcIsQ==", + "dev": true, + "requires": { + "type-fest": "^0.3.0" + }, + "dependencies": { + "type-fest": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.3.1.tgz", + "integrity": "sha512-cUGJnCdr4STbePCgqNFbpVNCepa+kAVohJs1sLhxzdH+gnEoOd8VhbYa7pD3zZYGiURWM2xzEII3fQcRizDkYQ==", + "dev": true + } + } + }, + "get-stream": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-3.0.0.tgz", + "integrity": "sha1-jpQ9E1jcN1VQVOy+LtsFqhdO3hQ=", + "dev": true + }, + "glob": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.1.4.tgz", + "integrity": "sha512-hkLPepehmnKk41pUGm3sYxoFs/umurYfYJCerbXEyFIWcAzvpipAgVkBqqT9RBKMGjnq6kMuyYwha6csxbiM1A==", + "dev": true, + "requires": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.0.4", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + } + }, + "glob-parent": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.0.0.tgz", + "integrity": "sha512-Z2RwiujPRGluePM6j699ktJYxmPpJKCfpGA13jz2hmFZC7gKetzrWvg5KN3+OsIFmydGyZ1AVwERCq1w/ZZwRg==", + "dev": true, + "requires": { + "is-glob": "^4.0.1" + } + }, + "global-dirs": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/global-dirs/-/global-dirs-0.1.1.tgz", + "integrity": "sha1-sxnA3UYH81PzvpzKTHL8FIxJ9EU=", + "dev": true, + "requires": { + "ini": "^1.3.4" + } + }, + "globals": { + "version": "11.12.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz", + "integrity": "sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==", + "dev": true + }, + "globby": { + "version": "10.0.1", + "resolved": "https://registry.npmjs.org/globby/-/globby-10.0.1.tgz", + "integrity": 
"sha512-sSs4inE1FB2YQiymcmTv6NWENryABjUNPeWhOvmn4SjtKybglsyPZxFB3U1/+L1bYi0rNZDqCLlHyLYDl1Pq5A==", + "dev": true, + "requires": { + "@types/glob": "^7.1.1", + "array-union": "^2.1.0", + "dir-glob": "^3.0.1", + "fast-glob": "^3.0.3", + "glob": "^7.1.3", + "ignore": "^5.1.1", + "merge2": "^1.2.3", + "slash": "^3.0.0" + } + }, + "got": { + "version": "9.6.0", + "resolved": "https://registry.npmjs.org/got/-/got-9.6.0.tgz", + "integrity": "sha512-R7eWptXuGYxwijs0eV+v3o6+XH1IqVK8dJOEecQfTmkncw9AV4dcw/Dhxi8MdlqPthxxpZyizMzyg8RTmEsG+Q==", + "dev": true, + "requires": { + "@sindresorhus/is": "^0.14.0", + "@szmarczak/http-timer": "^1.1.2", + "cacheable-request": "^6.0.0", + "decompress-response": "^3.3.0", + "duplexer3": "^0.1.4", + "get-stream": "^4.1.0", + "lowercase-keys": "^1.0.1", + "mimic-response": "^1.0.1", + "p-cancelable": "^1.0.0", + "to-readable-stream": "^1.0.0", + "url-parse-lax": "^3.0.0" + }, + "dependencies": { + "get-stream": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-4.1.0.tgz", + "integrity": "sha512-GMat4EJ5161kIy2HevLlr4luNjBgvmj413KaQA7jt4V8B4RDsfpHk7WQ9GVqfYyyx8OS/L66Kox+rJRNklLK7w==", + "dev": true, + "requires": { + "pump": "^3.0.0" + } + } + } + }, + "graceful-fs": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.2.tgz", + "integrity": "sha512-IItsdsea19BoLC7ELy13q1iJFNmd7ofZH5+X/pJr90/nRoPEX0DJo1dHDbgtYWOhJhcCgMDTOw84RZ72q6lB+Q==", + "dev": true + }, + "has": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", + "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", + "dev": true, + "requires": { + "function-bind": "^1.1.1" + } + }, + "has-flag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", + "integrity": "sha1-tdRU3CGZriJWmfNGfloH87lVuv0=", + "dev": true + }, + "has-symbols": { + "version": "1.0.0", + 
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.0.tgz", + "integrity": "sha1-uhqPGvKg/DllD1yFA2dwQSIGO0Q=", + "dev": true + }, + "has-yarn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/has-yarn/-/has-yarn-2.1.0.tgz", + "integrity": "sha512-UqBRqi4ju7T+TqGNdqAO0PaSVGsDGJUBQvk9eUWNGRY1CFGDzYhLWoM7JQEemnlvVcv/YEmc2wNW8BC24EnUsw==", + "dev": true + }, + "hasha": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/hasha/-/hasha-5.0.0.tgz", + "integrity": "sha512-PqWdhnQhq6tqD32hZv+l1e5mJHNSudjnaAzgAHfkGiU0ABN6lmbZF8abJIulQHbZ7oiHhP8yL6O910ICMc+5pw==", + "dev": true, + "requires": { + "is-stream": "^1.1.0", + "type-fest": "^0.3.0" + }, + "dependencies": { + "type-fest": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.3.1.tgz", + "integrity": "sha512-cUGJnCdr4STbePCgqNFbpVNCepa+kAVohJs1sLhxzdH+gnEoOd8VhbYa7pD3zZYGiURWM2xzEII3fQcRizDkYQ==", + "dev": true + } + } + }, + "hosted-git-info": { + "version": "2.8.4", + "resolved": "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.4.tgz", + "integrity": "sha512-pzXIvANXEFrc5oFFXRMkbLPQ2rXRoDERwDLyrcUxGhaZhgP54BBSl9Oheh7Vv0T090cszWBxPjkQQ5Sq1PbBRQ==", + "dev": true + }, + "http-cache-semantics": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.0.3.tgz", + "integrity": "sha512-TcIMG3qeVLgDr1TEd2XvHaTnMPwYQUQMIBLy+5pLSDKYFc7UIqj39w8EGzZkaxoLv/l2K8HaI0t5AVA+YYgUew==", + "dev": true + }, + "ignore": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.1.4.tgz", + "integrity": "sha512-MzbUSahkTW1u7JpKKjY7LCARd1fU5W2rLdxlM4kdkayuCwZImjkpluF9CM1aLewYJguPDqewLam18Y6AU69A8A==", + "dev": true + }, + "ignore-by-default": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/ignore-by-default/-/ignore-by-default-1.0.1.tgz", + "integrity": "sha1-SMptcvbGo68Aqa1K5odr44ieKwk=", + "dev": true + }, + "import-lazy": { + "version": 
"2.1.0", + "resolved": "https://registry.npmjs.org/import-lazy/-/import-lazy-2.1.0.tgz", + "integrity": "sha1-BWmOPUXIjo1+nZLLBYTnfwlvPkM=", + "dev": true + }, + "import-local": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.0.2.tgz", + "integrity": "sha512-vjL3+w0oulAVZ0hBHnxa/Nm5TAurf9YLQJDhqRZyqb+VKGOB6LU8t9H1Nr5CIo16vh9XfJTOoHwU0B71S557gA==", + "dev": true, + "requires": { + "pkg-dir": "^4.2.0", + "resolve-cwd": "^3.0.0" + } + }, + "imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha1-khi5srkoojixPcT7a21XbyMUU+o=", + "dev": true + }, + "indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "dev": true + }, + "inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk=", + "dev": true, + "requires": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=" + }, + "ini": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.5.tgz", + "integrity": "sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw==", + "dev": true + }, + "irregular-plurals": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/irregular-plurals/-/irregular-plurals-2.0.0.tgz", + "integrity": "sha512-Y75zBYLkh0lJ9qxeHlMjQ7bSbyiSqNW/UOPWDmzC7cXskL1hekSITh1Oc6JV0XCWWZ9DE8VYSB71xocLk3gmGw==", + "dev": true + }, + "is-arguments": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/is-arguments/-/is-arguments-1.0.4.tgz", + "integrity": 
"sha512-xPh0Rmt8NE65sNzvyUmWgI1tz3mKq74lGA0mL8LYZcoIzKOzDh6HmrYm3d18k60nHerC8A9Km8kYu87zfSFnLA==", + "dev": true + }, + "is-arrayish": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", + "integrity": "sha1-d8mYQFJ6qOyxqLppe4BkWnqSap0=", + "dev": true + }, + "is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "requires": { + "binary-extensions": "^2.0.0" + } + }, + "is-ci": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-ci/-/is-ci-2.0.0.tgz", + "integrity": "sha512-YfJT7rkpQB0updsdHLGWrvhBJfcfzNNawYDNIyQXJz0IViGf75O8EBPKSdvw2rF+LGCsX4FZ8tcr3b19LcZq4w==", + "dev": true, + "requires": { + "ci-info": "^2.0.0" + } + }, + "is-date-object": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-date-object/-/is-date-object-1.0.1.tgz", + "integrity": "sha1-mqIOtq7rv/d/vTPnTKAbM1gdOhY=", + "dev": true + }, + "is-error": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/is-error/-/is-error-2.2.2.tgz", + "integrity": "sha512-IOQqts/aHWbiisY5DuPJQ0gcbvaLFCa7fBa9xoLfxBZvQ+ZI/Zh9xoI7Gk+G64N0FdK4AbibytHht2tWgpJWLg==", + "dev": true + }, + "is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=", + "dev": true + }, + "is-fullwidth-code-point": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-2.0.0.tgz", + "integrity": "sha1-o7MKXE8ZkYMWeqq5O+764937ZU8=", + "dev": true + }, + "is-glob": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.1.tgz", + "integrity": "sha512-5G0tKtBTFImOqDnLB2hG6Bp2qcKEFduo4tZu9MT/H6NQv/ghhy30o55ufafxJ/LdH79LLs2Kfrn85TLKyA7BUg==", + "dev": 
true, + "requires": { + "is-extglob": "^2.1.1" + } + }, + "is-installed-globally": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/is-installed-globally/-/is-installed-globally-0.1.0.tgz", + "integrity": "sha1-Df2Y9akRFxbdU13aZJL2e/PSWoA=", + "dev": true, + "requires": { + "global-dirs": "^0.1.0", + "is-path-inside": "^1.0.0" + }, + "dependencies": { + "is-path-inside": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-1.0.1.tgz", + "integrity": "sha1-jvW33lBDej/cprToZe96pVy0gDY=", + "dev": true, + "requires": { + "path-is-inside": "^1.0.1" + } + } + } + }, + "is-npm": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-npm/-/is-npm-3.0.0.tgz", + "integrity": "sha512-wsigDr1Kkschp2opC4G3yA6r9EgVA6NjRpWzIi9axXqeIaAATPRJc4uLujXe3Nd9uO8KoDyA4MD6aZSeXTADhA==", + "dev": true + }, + "is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true + }, + "is-obj": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-obj/-/is-obj-2.0.0.tgz", + "integrity": "sha512-drqDG3cbczxxEJRoOXcOjtdp1J/lyp1mNn0xaznRs8+muBhgQcrnbspox5X5fOw0HnMnbfDzvnEMEtqDEJEo8w==", + "dev": true + }, + "is-observable": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-observable/-/is-observable-2.0.0.tgz", + "integrity": "sha512-fhBZv3eFKUbyHXZ1oHujdo2tZ+CNbdpdzzlENgCGZUC8keoGxUew2jYFLYcUB4qo7LDD03o4KK11m/QYD7kEjg==", + "dev": true + }, + "is-path-cwd": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/is-path-cwd/-/is-path-cwd-2.2.0.tgz", + "integrity": "sha512-w942bTcih8fdJPJmQHFzkS76NEP8Kzzvmw92cXsazb8intwLqPibPPdXf4ANdKV3rYMuuQYGIWtvz9JilB3NFQ==", + "dev": true + }, + "is-path-in-cwd": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-path-in-cwd/-/is-path-in-cwd-2.1.0.tgz", 
+ "integrity": "sha512-rNocXHgipO+rvnP6dk3zI20RpOtrAM/kzbB258Uw5BWr3TpXi861yzjo16Dn4hUox07iw5AyeMLHWsujkjzvRQ==", + "dev": true, + "requires": { + "is-path-inside": "^2.1.0" + } + }, + "is-path-inside": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-2.1.0.tgz", + "integrity": "sha512-wiyhTzfDWsvwAW53OBWF5zuvaOGlZ6PwYxAbPVDhpm+gM09xKQGjBq/8uYN12aDvMxnAnq3dxTyoSoRNmg5YFg==", + "dev": true, + "requires": { + "path-is-inside": "^1.0.2" + } + }, + "is-plain-obj": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-plain-obj/-/is-plain-obj-1.1.0.tgz", + "integrity": "sha1-caUMhCnfync8kqOQpKA7OfzVHT4=", + "dev": true + }, + "is-plain-object": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-3.0.0.tgz", + "integrity": "sha512-tZIpofR+P05k8Aocp7UI/2UTa9lTJSebCXpFFoR9aibpokDj/uXBsJ8luUu0tTVYKkMU6URDUuOfJZ7koewXvg==", + "dev": true, + "requires": { + "isobject": "^4.0.0" + } + }, + "is-promise": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-2.1.0.tgz", + "integrity": "sha1-eaKp7OfwlugPNtKy87wWwf9L8/o=", + "dev": true + }, + "is-regex": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.0.4.tgz", + "integrity": "sha1-VRdIm1RwkbCTDglWVM7SXul+lJE=", + "dev": true, + "requires": { + "has": "^1.0.1" + } + }, + "is-stream": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-1.1.0.tgz", + "integrity": "sha1-EtSj3U5o4Lec6428hBc66A2RykQ=", + "dev": true + }, + "is-typedarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-typedarray/-/is-typedarray-1.0.0.tgz", + "integrity": "sha1-5HnICFjfDBsR3dppQPlgEfzaSpo=", + "dev": true + }, + "is-url": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/is-url/-/is-url-1.2.4.tgz", + "integrity": 
"sha512-ITvGim8FhRiYe4IQ5uHSkj7pVaPDrCTkNd3yq3cV7iZAcJdHTUMPMEHcqSOy9xZ9qFenQCvi+2wjH9a1nXqHww==", + "dev": true + }, + "is-utf8": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-utf8/-/is-utf8-0.2.1.tgz", + "integrity": "sha1-Sw2hRCEE0bM2NA6AeX6GXPOffXI=", + "dev": true + }, + "is-yarn-global": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/is-yarn-global/-/is-yarn-global-0.3.0.tgz", + "integrity": "sha512-VjSeb/lHmkoyd8ryPVIKvOCn4D1koMqY+vqyjjUfc3xyKtP4dYOxM44sZrnqQSzSds3xyOrUTLTC9LVCVgLngw==", + "dev": true + }, + "isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" + }, + "isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=", + "dev": true + }, + "isobject": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/isobject/-/isobject-4.0.0.tgz", + "integrity": "sha512-S/2fF5wH8SJA/kmwr6HYhK/RI/OkhD84k8ntalo0iJjZikgq1XFvR5M8NPT1x5F7fBwCG3qHfnzeP/Vh/ZxCUA==", + "dev": true + }, + "js-string-escape": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/js-string-escape/-/js-string-escape-1.0.1.tgz", + "integrity": "sha1-4mJbrbwNZ8dTPp7cEGjFh65BN+8=", + "dev": true + }, + "js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "dev": true + }, + "js-yaml": { + "version": "3.13.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.13.1.tgz", + "integrity": "sha512-YfbcO7jXDdyj0DGxYVSlSeQNHbD7XPWvrVWeVUujrQEoZzWJIRrCPoyk6kL6IAjAG2IolMK4T0hNUe0HOUs5Jw==", + "dev": true, + "requires": { + "argparse": "^1.0.7", + "esprima": "^4.0.0" + } + }, + "jsesc": { + "version": "2.5.2", + "resolved": 
"https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz", + "integrity": "sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==", + "dev": true + }, + "json-buffer": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.0.tgz", + "integrity": "sha1-Wx85evx11ne96Lz8Dkfh+aPZqJg=", + "dev": true + }, + "json-parse-better-errors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/json-parse-better-errors/-/json-parse-better-errors-1.0.2.tgz", + "integrity": "sha512-mrqyZKfX5EhL7hvqcV6WG1yYjnjeuYDzDhhcAAUrq8Po85NBQBJP+ZDUT75qZQ98IkUoBqdkExkukOU7Ts2wrw==", + "dev": true + }, + "json5": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.1.0.tgz", + "integrity": "sha512-8Mh9h6xViijj36g7Dxi+Y4S6hNGV96vcJZr/SrlHh1LR/pEn/8j/+qIBbs44YKl69Lrfctp4QD+AdWLTMqEZAQ==", + "dev": true, + "requires": { + "minimist": "^1.2.0" + } + }, + "keyv": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-3.1.0.tgz", + "integrity": "sha512-9ykJ/46SN/9KPM/sichzQ7OvXyGDYKGTaDlKMGCAlg2UK8KRy4jb0d8sFc+0Tt0YYnThq8X2RZgCg74RPxgcVA==", + "dev": true, + "requires": { + "json-buffer": "3.0.0" + } + }, + "latest-version": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/latest-version/-/latest-version-5.1.0.tgz", + "integrity": "sha512-weT+r0kTkRQdCdYCNtkMwWXQTMEswKrFBkm4ckQOMVhhqhIMI1UT2hMj+1iigIhgSZm5gTmrRXBNoGUgaTY1xA==", + "dev": true, + "requires": { + "package-json": "^6.3.0" + } + }, + "load-json-file": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/load-json-file/-/load-json-file-4.0.0.tgz", + "integrity": "sha1-L19Fq5HjMhYjT9U62rZo607AmTs=", + "dev": true, + "requires": { + "graceful-fs": "^4.1.2", + "parse-json": "^4.0.0", + "pify": "^3.0.0", + "strip-bom": "^3.0.0" + }, + "dependencies": { + "pify": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-3.0.0.tgz", + "integrity": 
"sha1-5aSs0sEB/fPZpNB/DbxNtJ3SgXY=", + "dev": true + } + } + }, + "locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "requires": { + "p-locate": "^4.1.0" + } + }, + "lodash": { + "version": "4.17.15", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz", + "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A==", + "dev": true + }, + "lodash.clonedeep": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", + "integrity": "sha1-4j8/nE+Pvd6HJSnBBxhXoIblzO8=", + "dev": true + }, + "lodash.flattendeep": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/lodash.flattendeep/-/lodash.flattendeep-4.4.0.tgz", + "integrity": "sha1-+wMJF/hqMTTlvJvsDWngAT3f7bI=", + "dev": true + }, + "lodash.islength": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/lodash.islength/-/lodash.islength-4.0.1.tgz", + "integrity": "sha1-Tpho1FJXXXUK/9NYyXlUPcIO1Xc=", + "dev": true + }, + "lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true + }, + "log-symbols": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/log-symbols/-/log-symbols-2.2.0.tgz", + "integrity": "sha512-VeIAFslyIerEJLXHziedo2basKbMKtTw3vfn5IzG0XTjhAVEJyNHnL2p7vc+wBDSdQuUpNw3M2u6xb9QsAY5Eg==", + "dev": true, + "requires": { + "chalk": "^2.0.1" + } + }, + "loud-rejection": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/loud-rejection/-/loud-rejection-2.1.0.tgz", + "integrity": 
"sha512-g/6MQxUXYHeVqZ4PGpPL1fS1fOvlXoi7bay0pizmjAd/3JhyXwxzwrnr74yzdmhuerlslbRJ3x7IOXzFz0cE5w==", + "dev": true, + "requires": { + "currently-unhandled": "^0.4.1", + "signal-exit": "^3.0.2" + } + }, + "lowercase-keys": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/lowercase-keys/-/lowercase-keys-1.0.1.tgz", + "integrity": "sha512-G2Lj61tXDnVFFOi8VZds+SoQjtQC3dgokKdDG2mTm1tx4m50NUHBOZSBwQQHyy0V12A0JTG4icfZQH+xPyh8VA==", + "dev": true + }, + "lru-cache": { + "version": "4.1.5", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-4.1.5.tgz", + "integrity": "sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==", + "dev": true, + "requires": { + "pseudomap": "^1.0.2", + "yallist": "^2.1.2" + } + }, + "make-dir": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.0.0.tgz", + "integrity": "sha512-grNJDhb8b1Jm1qeqW5R/O63wUo4UXo2v2HMic6YT9i/HBlF93S8jkMgH7yugvY9ABDShH4VZMn8I+U8+fCNegw==", + "dev": true, + "requires": { + "semver": "^6.0.0" + }, + "dependencies": { + "semver": { + "version": "6.3.0", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", + "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "dev": true + } + } + }, + "map-obj": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/map-obj/-/map-obj-2.0.0.tgz", + "integrity": "sha1-plzSkIepJZi4eRJXpSPgISIqwfk=", + "dev": true + }, + "matcher": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/matcher/-/matcher-2.0.0.tgz", + "integrity": "sha512-nlmfSlgHBFx36j/Pl/KQPbIaqE8Zf0TqmSMjsuddHDg6PMSVgmyW9HpkLs0o0M1n2GIZ/S2BZBLIww/xjhiGng==", + "dev": true, + "requires": { + "escape-string-regexp": "^2.0.0" + } + }, + "md5-hex": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/md5-hex/-/md5-hex-3.0.1.tgz", + "integrity": 
"sha512-BUiRtTtV39LIJwinWBjqVsU9xhdnz7/i889V859IBFpuqGAj6LuOvHv5XLbgZ2R7ptJoJaEcxkv88/h25T7Ciw==", + "dev": true, + "requires": { + "blueimp-md5": "^2.10.0" + } + }, + "md5-o-matic": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/md5-o-matic/-/md5-o-matic-0.1.1.tgz", + "integrity": "sha1-givM1l4RfFFPqxdrJZRdVBAKA8M=", + "dev": true + }, + "meow": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/meow/-/meow-5.0.0.tgz", + "integrity": "sha512-CbTqYU17ABaLefO8vCU153ZZlprKYWDljcndKKDCFcYQITzWCXZAVk4QMFZPgvzrnUQ3uItnIE/LoUOwrT15Ig==", + "dev": true, + "requires": { + "camelcase-keys": "^4.0.0", + "decamelize-keys": "^1.0.0", + "loud-rejection": "^1.0.0", + "minimist-options": "^3.0.1", + "normalize-package-data": "^2.3.4", + "read-pkg-up": "^3.0.0", + "redent": "^2.0.0", + "trim-newlines": "^2.0.0", + "yargs-parser": "^10.0.0" + }, + "dependencies": { + "loud-rejection": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/loud-rejection/-/loud-rejection-1.6.0.tgz", + "integrity": "sha1-W0b4AUft7leIcPCG0Eghz5mOVR8=", + "dev": true, + "requires": { + "currently-unhandled": "^0.4.1", + "signal-exit": "^3.0.0" + } + } + } + }, + "merge2": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.2.4.tgz", + "integrity": "sha512-FYE8xI+6pjFOhokZu0We3S5NKCirLbCzSh2Usf3qEyr4X8U+0jNg9P8RZ4qz+V2UoECLVwSyzU3LxXBaLGtD3A==", + "dev": true + }, + "micromatch": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.2.tgz", + "integrity": "sha512-y7FpHSbMUMoyPbYUSzO6PaZ6FyRnQOpHuKwbo1G+Knck95XVU4QAiKdGEnj5wwoS7PlOgthX/09u5iFJ+aYf5Q==", + "dev": true, + "requires": { + "braces": "^3.0.1", + "picomatch": "^2.0.5" + } + }, + "mimic-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", + "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", + "dev": true + }, + 
"mimic-response": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-1.0.1.tgz", + "integrity": "sha512-j5EctnkH7amfV/q5Hgmoal1g2QHFJRraOtmx0JpIqkxhBhI/lJSl1nMpQ45hVarwNETOoWEimndZ4QK0RHxuxQ==", + "dev": true + }, + "minimatch": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.0.4.tgz", + "integrity": "sha512-yJHVQEhyqPLUTgt9B83PXu6W3rx4MvvHvSUvToogpwoGDOUQ+yDrR0HRot+yOCdCO7u4hX3pWft6kWBBcqh0UA==", + "dev": true, + "requires": { + "brace-expansion": "^1.1.7" + } + }, + "minimist": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.0.tgz", + "integrity": "sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ=", + "dev": true + }, + "minimist-options": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/minimist-options/-/minimist-options-3.0.2.tgz", + "integrity": "sha512-FyBrT/d0d4+uiZRbqznPXqw3IpZZG3gl3wKWiX784FycUKVwBt0uLBFkQrtE4tZOrgo78nZp2jnKz3L65T5LdQ==", + "dev": true, + "requires": { + "arrify": "^1.0.1", + "is-plain-obj": "^1.1.0" + }, + "dependencies": { + "arrify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/arrify/-/arrify-1.0.1.tgz", + "integrity": "sha1-iYUI2iIm84DfkEcoRWhJwVAaSw0=", + "dev": true + } + } + }, + "mkdirp": { + "version": "0.5.1", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.1.tgz", + "integrity": "sha1-MAV0OOrGz3+MR2fzhkjWaX11yQM=", + "dev": true, + "requires": { + "minimist": "0.0.8" + }, + "dependencies": { + "minimist": { + "version": "0.0.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.8.tgz", + "integrity": "sha1-hX/Kv8M5fSYluCKCYuhqp6ARsF0=", + "dev": true + } + } + }, + "ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + }, + "mysql": { + "version": "2.17.1", + "resolved": 
"https://registry.npmjs.org/mysql/-/mysql-2.17.1.tgz", + "integrity": "sha512-7vMqHQ673SAk5C8fOzTG2LpPcf3bNt0oL3sFpxPEEFp1mdlDcrLK0On7z8ZYKaaHrHwNcQ/MTUz7/oobZ2OyyA==", + "requires": { + "bignumber.js": "7.2.1", + "readable-stream": "2.3.6", + "safe-buffer": "5.1.2", + "sqlstring": "2.3.1" + }, + "dependencies": { + "bignumber.js": { + "version": "7.2.1", + "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-7.2.1.tgz", + "integrity": "sha512-S4XzBk5sMB+Rcb/LNcpzXr57VRTxgAvaAEDAl1AwRx27j00hT84O6OkteE7u8UB3NuaaygCRrEpqox4uDOrbdQ==" + } + } + }, + "normalize-package-data": { + "version": "2.5.0", + "resolved": "https://registry.npmjs.org/normalize-package-data/-/normalize-package-data-2.5.0.tgz", + "integrity": "sha512-/5CMN3T0R4XTj4DcGaexo+roZSdSFW/0AOOTROrjxzCG1wrWXEsGbRKevjlIL+ZDE4sZlJr5ED4YW0yqmkK+eA==", + "dev": true, + "requires": { + "hosted-git-info": "^2.1.4", + "resolve": "^1.10.0", + "semver": "2 || 3 || 4 || 5", + "validate-npm-package-license": "^3.0.1" + } + }, + "normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true + }, + "normalize-url": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-4.3.0.tgz", + "integrity": "sha512-0NLtR71o4k6GLP+mr6Ty34c5GA6CMoEsncKJxvQd8NzPxaHRJNnb5gZE8R1XF4CPIS7QPHLJ74IFszwtNVAHVQ==", + "dev": true + }, + "npm-run-path": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-2.0.2.tgz", + "integrity": "sha1-NakjLfo11wZ7TLLd8jV7GHFTbF8=", + "dev": true, + "requires": { + "path-key": "^2.0.0" + } + }, + "object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha1-IQmtx5ZYh8/AXLvUQsrIv7s2CGM=", + "dev": true + }, + "object-is": { + "version": 
"1.0.1", + "resolved": "https://registry.npmjs.org/object-is/-/object-is-1.0.1.tgz", + "integrity": "sha1-CqYOyZiaCz7Xlc9NBvYs8a1lObY=", + "dev": true + }, + "object-keys": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", + "integrity": "sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==", + "dev": true + }, + "object.assign": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/object.assign/-/object.assign-4.1.0.tgz", + "integrity": "sha512-exHJeq6kBKj58mqGyTQ9DFvrZC/eR6OwxzoM9YRoGBqrXYonaFyGiFMuc9VZrXf7DarreEwMpurG3dd+CNyW5w==", + "dev": true, + "requires": { + "define-properties": "^1.1.2", + "function-bind": "^1.1.1", + "has-symbols": "^1.0.0", + "object-keys": "^1.0.11" + } + }, + "observable-to-promise": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/observable-to-promise/-/observable-to-promise-1.0.0.tgz", + "integrity": "sha512-cqnGUrNsE6vdVDTPAX9/WeVzwy/z37vdxupdQXU8vgTXRFH72KCZiZga8aca2ulRPIeem8W3vW9rQHBwfIl2WA==", + "dev": true, + "requires": { + "is-observable": "^2.0.0", + "symbol-observable": "^1.0.4" + } + }, + "once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha1-WDsap3WWHUsROsF9nFC6753Xa9E=", + "dev": true, + "requires": { + "wrappy": "1" + } + }, + "onetime": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.0.tgz", + "integrity": "sha512-5NcSkPHhwTVFIQN+TUqXoS5+dlElHXdpAWu9I0HP20YOtIi+aZ0Ct82jdlILDxjLEAWwvm+qj1m6aEtsDVmm6Q==", + "dev": true, + "requires": { + "mimic-fn": "^2.1.0" + } + }, + "ora": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/ora/-/ora-3.4.0.tgz", + "integrity": "sha512-eNwHudNbO1folBP3JsZ19v9azXWtQZjICdr3Q0TDPIaeBQ3mXLrh54wM+er0+hSp+dWKf+Z8KM58CYzEyIYxYg==", + "dev": true, + "requires": { + "chalk": "^2.4.2", + "cli-cursor": "^2.1.0", + "cli-spinners": "^2.0.0", + 
"log-symbols": "^2.2.0", + "strip-ansi": "^5.2.0", + "wcwidth": "^1.0.1" + }, + "dependencies": { + "cli-cursor": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-2.1.0.tgz", + "integrity": "sha1-s12sN2R5+sw+lHR9QdDQ9SOP/LU=", + "dev": true, + "requires": { + "restore-cursor": "^2.0.0" + } + }, + "mimic-fn": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-1.2.0.tgz", + "integrity": "sha512-jf84uxzwiuiIVKiOLpfYk7N46TSy8ubTonmneY9vrpHNAnp0QBt2BxWV9dO3/j+BoVAb+a5G6YDPW3M5HOdMWQ==", + "dev": true + }, + "onetime": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-2.0.1.tgz", + "integrity": "sha1-BnQoIw/WdEOyeUsiu6UotoZ5YtQ=", + "dev": true, + "requires": { + "mimic-fn": "^1.0.0" + } + }, + "restore-cursor": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/restore-cursor/-/restore-cursor-2.0.0.tgz", + "integrity": "sha1-n37ih/gv0ybU/RYpI9YhKe7g368=", + "dev": true, + "requires": { + "onetime": "^2.0.0", + "signal-exit": "^3.0.2" + } + } + } + }, + "os-tmpdir": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/os-tmpdir/-/os-tmpdir-1.0.2.tgz", + "integrity": "sha1-u+Z0BseaqFxc/sdm/lc0VV36EnQ=", + "dev": true + }, + "p-cancelable": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/p-cancelable/-/p-cancelable-1.1.0.tgz", + "integrity": "sha512-s73XxOZ4zpt1edZYZzvhqFa6uvQc1vwUa0K0BdtIZgQMAJj9IbebH+JkgKZc9h+B05PKHLOTl4ajG1BmNrVZlw==", + "dev": true + }, + "p-finally": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/p-finally/-/p-finally-1.0.0.tgz", + "integrity": "sha1-P7z7FbiZpEEjs0ttzBi3JDNqLK4=", + "dev": true + }, + "p-limit": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.2.1.tgz", + "integrity": "sha512-85Tk+90UCVWvbDavCLKPOLC9vvY8OwEX/RtKF+/1OADJMVlFfEHOiMTPVyxg7mk/dKa+ipdHm0OUkTvCpMTuwg==", + "dev": true, + "requires": { + "p-try": "^2.0.0" + } + }, + 
"p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "requires": { + "p-limit": "^2.2.0" + } + }, + "p-map": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/p-map/-/p-map-2.1.0.tgz", + "integrity": "sha512-y3b8Kpd8OAN444hxfBbFfj1FY/RjtTd8tzYwhUqNYXx0fXx2iX4maP4Qr6qhIKbQXI02wTLAda4fYUbDagTUFw==", + "dev": true + }, + "p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "dev": true + }, + "package-hash": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/package-hash/-/package-hash-4.0.0.tgz", + "integrity": "sha512-whdkPIooSu/bASggZ96BWVvZTRMOFxnyUG5PnTSGKoJE2gd5mbVNmR2Nj20QFzxYYgAXpoqC+AiXzl+UMRh7zQ==", + "dev": true, + "requires": { + "graceful-fs": "^4.1.15", + "hasha": "^5.0.0", + "lodash.flattendeep": "^4.4.0", + "release-zalgo": "^1.0.0" + } + }, + "package-json": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/package-json/-/package-json-6.5.0.tgz", + "integrity": "sha512-k3bdm2n25tkyxcjSKzB5x8kfVxlMdgsbPr0GkZcwHsLpba6cBjqCt1KlcChKEvxHIcTB1FVMuwoijZ26xex5MQ==", + "dev": true, + "requires": { + "got": "^9.6.0", + "registry-auth-token": "^4.0.0", + "registry-url": "^5.0.0", + "semver": "^6.2.0" + }, + "dependencies": { + "semver": { + "version": "6.3.0", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", + "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "dev": true + } + } + }, + "parse-json": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-4.0.0.tgz", + "integrity": "sha1-vjX1Qlvh9/bHRxhPmKeIy5lHfuA=", + "dev": true, + "requires": { + 
"error-ex": "^1.3.1", + "json-parse-better-errors": "^1.0.1" + } + }, + "parse-ms": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/parse-ms/-/parse-ms-2.1.0.tgz", + "integrity": "sha512-kHt7kzLoS9VBZfUsiKjv43mr91ea+U05EyKkEtqp7vNbHxmaVuEqN7XxeEVnGrMtYOAxGrDElSi96K7EgO1zCA==", + "dev": true + }, + "path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true + }, + "path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha1-F0uSaHNVNP+8es5r9TpanhtcX18=", + "dev": true + }, + "path-is-inside": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/path-is-inside/-/path-is-inside-1.0.2.tgz", + "integrity": "sha1-NlQX3t5EQw0cEa9hAn+s8HS9/FM=", + "dev": true + }, + "path-key": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-2.0.1.tgz", + "integrity": "sha1-QRyttXTFoUDTpLGRDUDYDMn0C0A=", + "dev": true + }, + "path-parse": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.6.tgz", + "integrity": "sha512-GSmOT2EbHrINBf9SR7CDELwlJ8AENk3Qn7OikK4nFYAu3Ote2+JYNVvkpAEQm3/TLNEJFD/xZJjzyxg3KBWOzw==", + "dev": true + }, + "path-type": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "dev": true + }, + "picomatch": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.0.7.tgz", + "integrity": "sha512-oLHIdio3tZ0qH76NybpeneBhYVj0QFTfXEFTc/B3zKQspYfYYkWYgFsmzo+4kvId/bQRcNkVeguI3y+CD22BtA==", + "dev": true + }, + "pify": { + "version": "4.0.1", + "resolved": 
"https://registry.npmjs.org/pify/-/pify-4.0.1.tgz", + "integrity": "sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g==", + "dev": true + }, + "pinkie": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/pinkie/-/pinkie-2.0.4.tgz", + "integrity": "sha1-clVrgM+g1IqXToDnckjoDtT3+HA=", + "dev": true + }, + "pinkie-promise": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/pinkie-promise/-/pinkie-promise-2.0.1.tgz", + "integrity": "sha1-ITXW36ejWMBprJsXh3YogihFD/o=", + "dev": true, + "requires": { + "pinkie": "^2.0.0" + } + }, + "pkg-conf": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/pkg-conf/-/pkg-conf-3.1.0.tgz", + "integrity": "sha512-m0OTbR/5VPNPqO1ph6Fqbj7Hv6QU7gR/tQW40ZqrL1rjgCU85W6C1bJn0BItuJqnR98PWzw7Z8hHeChD1WrgdQ==", + "dev": true, + "requires": { + "find-up": "^3.0.0", + "load-json-file": "^5.2.0" + }, + "dependencies": { + "find-up": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-3.0.0.tgz", + "integrity": "sha512-1yD6RmLI1XBfxugvORwlck6f75tYL+iR0jqwsOrOxMZyGYqUuDhJ0l4AXdO1iX/FTs9cBAMEk1gWSEx1kSbylg==", + "dev": true, + "requires": { + "locate-path": "^3.0.0" + } + }, + "load-json-file": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/load-json-file/-/load-json-file-5.3.0.tgz", + "integrity": "sha512-cJGP40Jc/VXUsp8/OrnyKyTZ1y6v/dphm3bioS+RrKXjK2BB6wHUd6JptZEFDGgGahMT+InnZO5i1Ei9mpC8Bw==", + "dev": true, + "requires": { + "graceful-fs": "^4.1.15", + "parse-json": "^4.0.0", + "pify": "^4.0.1", + "strip-bom": "^3.0.0", + "type-fest": "^0.3.0" + } + }, + "locate-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-3.0.0.tgz", + "integrity": "sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==", + "dev": true, + "requires": { + "p-locate": "^3.0.0", + "path-exists": "^3.0.0" + } + }, + "p-locate": { + "version": "3.0.0", + 
"resolved": "https://registry.npmjs.org/p-locate/-/p-locate-3.0.0.tgz", + "integrity": "sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==", + "dev": true, + "requires": { + "p-limit": "^2.0.0" + } + }, + "path-exists": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-3.0.0.tgz", + "integrity": "sha1-zg6+ql94yxiSXqfYENe1mwEP1RU=", + "dev": true + }, + "type-fest": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.3.1.tgz", + "integrity": "sha512-cUGJnCdr4STbePCgqNFbpVNCepa+kAVohJs1sLhxzdH+gnEoOd8VhbYa7pD3zZYGiURWM2xzEII3fQcRizDkYQ==", + "dev": true + } + } + }, + "pkg-dir": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", + "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", + "dev": true, + "requires": { + "find-up": "^4.0.0" + } + }, + "plur": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/plur/-/plur-3.1.1.tgz", + "integrity": "sha512-t1Ax8KUvV3FFII8ltczPn2tJdjqbd1sIzu6t4JL7nQ3EyeL/lTrj5PWKb06ic5/6XYDr65rQ4uzQEGN70/6X5w==", + "dev": true, + "requires": { + "irregular-plurals": "^2.0.0" + } + }, + "prepend-http": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/prepend-http/-/prepend-http-2.0.0.tgz", + "integrity": "sha1-6SQ0v6XqjBn0HN/UAddBo8gZ2Jc=", + "dev": true + }, + "pretty-ms": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-5.0.0.tgz", + "integrity": "sha512-94VRYjL9k33RzfKiGokPBPpsmloBYSf5Ri+Pq19zlsEcUKFob+admeXr5eFDRuPjFmEOcjJvPGdillYOJyvZ7Q==", + "dev": true, + "requires": { + "parse-ms": "^2.1.0" + } + }, + "process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": 
"sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" + }, + "pseudomap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/pseudomap/-/pseudomap-1.0.2.tgz", + "integrity": "sha1-8FKijacOYYkX7wqKw0wa5aaChrM=", + "dev": true + }, + "pump": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", + "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", + "dev": true, + "requires": { + "end-of-stream": "^1.1.0", + "once": "^1.3.1" + } + }, + "quick-lru": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/quick-lru/-/quick-lru-1.1.0.tgz", + "integrity": "sha1-Q2CxfGETatOAeDl/8RQW4Ybc+7g=", + "dev": true + }, + "rc": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", + "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==", + "dev": true, + "requires": { + "deep-extend": "^0.6.0", + "ini": "~1.3.0", + "minimist": "^1.2.0", + "strip-json-comments": "~2.0.1" + } + }, + "read-pkg": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-3.0.0.tgz", + "integrity": "sha1-nLxoaXj+5l0WwA4rGcI3/Pbjg4k=", + "dev": true, + "requires": { + "load-json-file": "^4.0.0", + "normalize-package-data": "^2.3.2", + "path-type": "^3.0.0" + }, + "dependencies": { + "path-type": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-3.0.0.tgz", + "integrity": "sha512-T2ZUsdZFHgA3u4e5PfPbjd7HDDpxPnQb5jN0SrDsjNSuVXHJqtwTnWqG0B1jZrgmJ/7lj1EmVIByWt1gxGkWvg==", + "dev": true, + "requires": { + "pify": "^3.0.0" + } + }, + "pify": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-3.0.0.tgz", + "integrity": "sha1-5aSs0sEB/fPZpNB/DbxNtJ3SgXY=", + "dev": true + } + } + }, + "read-pkg-up": { + "version": "3.0.0", + "resolved": 
"https://registry.npmjs.org/read-pkg-up/-/read-pkg-up-3.0.0.tgz", + "integrity": "sha1-PtSWaF26D4/hGNBpHcUfSh/5bwc=", + "dev": true, + "requires": { + "find-up": "^2.0.0", + "read-pkg": "^3.0.0" + }, + "dependencies": { + "find-up": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-2.1.0.tgz", + "integrity": "sha1-RdG35QbHF93UgndaK3eSCjwMV6c=", + "dev": true, + "requires": { + "locate-path": "^2.0.0" + } + }, + "locate-path": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-2.0.0.tgz", + "integrity": "sha1-K1aLJl7slExtnA3pw9u7ygNUzY4=", + "dev": true, + "requires": { + "p-locate": "^2.0.0", + "path-exists": "^3.0.0" + } + }, + "p-limit": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-1.3.0.tgz", + "integrity": "sha512-vvcXsLAJ9Dr5rQOPk7toZQZJApBl2K4J6dANSsEuh6QI41JYcsS/qhTGa9ErIUUgK3WNQoJYvylxvjqmiqEA9Q==", + "dev": true, + "requires": { + "p-try": "^1.0.0" + } + }, + "p-locate": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-2.0.0.tgz", + "integrity": "sha1-IKAQOyIqcMj9OcwuWAaA893l7EM=", + "dev": true, + "requires": { + "p-limit": "^1.1.0" + } + }, + "p-try": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/p-try/-/p-try-1.0.0.tgz", + "integrity": "sha1-y8ec26+P1CKOE/Yh8rGiN8GyB7M=", + "dev": true + }, + "path-exists": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-3.0.0.tgz", + "integrity": "sha1-zg6+ql94yxiSXqfYENe1mwEP1RU=", + "dev": true + } + } + }, + "readable-stream": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", + "requires": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + 
"string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "readdirp": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.1.2.tgz", + "integrity": "sha512-8rhl0xs2cxfVsqzreYCvs8EwBfn/DhVdqtoLmw19uI3SC5avYX9teCurlErfpPXGmYtMHReGaP2RsLnFvz/lnw==", + "dev": true, + "requires": { + "picomatch": "^2.0.4" + } + }, + "redent": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/redent/-/redent-2.0.0.tgz", + "integrity": "sha1-wbIAe0LVfrE4kHmzyDM2OdXhzKo=", + "dev": true, + "requires": { + "indent-string": "^3.0.0", + "strip-indent": "^2.0.0" + }, + "dependencies": { + "indent-string": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-3.2.0.tgz", + "integrity": "sha1-Sl/W0nzDMvN+VBmlBNu4NxBckok=", + "dev": true + } + } + }, + "regenerate": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/regenerate/-/regenerate-1.4.0.tgz", + "integrity": "sha512-1G6jJVDWrt0rK99kBjvEtziZNCICAuvIPkSiUFIQxVP06RCVpq3dmDo2oi6ABpYaDYaTRr67BEhL8r1wgEZZKg==", + "dev": true + }, + "regenerate-unicode-properties": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/regenerate-unicode-properties/-/regenerate-unicode-properties-8.1.0.tgz", + "integrity": "sha512-LGZzkgtLY79GeXLm8Dp0BVLdQlWICzBnJz/ipWUgo59qBaZ+BHtq51P2q1uVZlppMuUAT37SDk39qUbjTWB7bA==", + "dev": true, + "requires": { + "regenerate": "^1.4.0" + } + }, + "regexp.prototype.flags": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.2.0.tgz", + "integrity": "sha512-ztaw4M1VqgMwl9HlPpOuiYgItcHlunW0He2fE6eNfT6E/CF2FtYi9ofOYe4mKntstYk0Fyh/rDRBdS3AnxjlrA==", + "dev": true, + "requires": { + "define-properties": "^1.1.2" + } + }, + "regexpu-core": { + "version": "4.5.5", + "resolved": "https://registry.npmjs.org/regexpu-core/-/regexpu-core-4.5.5.tgz", + "integrity": 
"sha512-FpI67+ky9J+cDizQUJlIlNZFKual/lUkFr1AG6zOCpwZ9cLrg8UUVakyUQJD7fCDIe9Z2nwTQJNPyonatNmDFQ==", + "dev": true, + "requires": { + "regenerate": "^1.4.0", + "regenerate-unicode-properties": "^8.1.0", + "regjsgen": "^0.5.0", + "regjsparser": "^0.6.0", + "unicode-match-property-ecmascript": "^1.0.4", + "unicode-match-property-value-ecmascript": "^1.1.0" + } + }, + "registry-auth-token": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/registry-auth-token/-/registry-auth-token-4.0.0.tgz", + "integrity": "sha512-lpQkHxd9UL6tb3k/aHAVfnVtn+Bcs9ob5InuFLLEDqSqeq+AljB8GZW9xY0x7F+xYwEcjKe07nyoxzEYz6yvkw==", + "dev": true, + "requires": { + "rc": "^1.2.8", + "safe-buffer": "^5.0.1" + } + }, + "registry-url": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/registry-url/-/registry-url-5.1.0.tgz", + "integrity": "sha512-8acYXXTI0AkQv6RAOjE3vOaIXZkT9wo4LOFbBKYQEEnnMNBpKqdUrI6S4NT0KPIo/WVvJ5tE/X5LF/TQUf0ekw==", + "dev": true, + "requires": { + "rc": "^1.2.8" + } + }, + "regjsgen": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/regjsgen/-/regjsgen-0.5.0.tgz", + "integrity": "sha512-RnIrLhrXCX5ow/E5/Mh2O4e/oa1/jW0eaBKTSy3LaCj+M3Bqvm97GWDp2yUtzIs4LEn65zR2yiYGFqb2ApnzDA==", + "dev": true + }, + "regjsparser": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/regjsparser/-/regjsparser-0.6.0.tgz", + "integrity": "sha512-RQ7YyokLiQBomUJuUG8iGVvkgOLxwyZM8k6d3q5SAXpg4r5TZJZigKFvC6PpD+qQ98bCDC5YelPeA3EucDoNeQ==", + "dev": true, + "requires": { + "jsesc": "~0.5.0" + }, + "dependencies": { + "jsesc": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-0.5.0.tgz", + "integrity": "sha1-597mbjXW/Bb3EP6R1c9p9w8IkR0=", + "dev": true + } + } + }, + "release-zalgo": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/release-zalgo/-/release-zalgo-1.0.0.tgz", + "integrity": "sha1-CXALflB0Mpc5Mw5TXFqQ+2eFFzA=", + "dev": true, + "requires": { + "es6-error": "^4.0.1" + } + }, + 
"require-precompiled": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/require-precompiled/-/require-precompiled-0.1.0.tgz", + "integrity": "sha1-WhtS63Dr7UPrmC6XTIWrWVceVvo=", + "dev": true + }, + "resolve": { + "version": "1.12.0", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.12.0.tgz", + "integrity": "sha512-B/dOmuoAik5bKcD6s6nXDCjzUKnaDvdkRyAk6rsmsKLipWj4797iothd7jmmUhWTfinVMU+wc56rYKsit2Qy4w==", + "dev": true, + "requires": { + "path-parse": "^1.0.6" + } + }, + "resolve-cwd": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", + "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", + "dev": true, + "requires": { + "resolve-from": "^5.0.0" + } + }, + "resolve-from": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", + "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "dev": true + }, + "responselike": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/responselike/-/responselike-1.0.2.tgz", + "integrity": "sha1-kYcg7ztjHFZCvgaPFa3lpG9Loec=", + "dev": true, + "requires": { + "lowercase-keys": "^1.0.0" + } + }, + "restore-cursor": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/restore-cursor/-/restore-cursor-3.1.0.tgz", + "integrity": "sha512-l+sSefzHpj5qimhFSE5a8nufZYAM3sBSVMAPtYkmC+4EH2anSGaEMXSD0izRQbu9nfyQ9y5JrVmp7E8oZrUjvA==", + "dev": true, + "requires": { + "onetime": "^5.1.0", + "signal-exit": "^3.0.2" + } + }, + "reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true + }, + "rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + 
"integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "dev": true, + "requires": { + "glob": "^7.1.3" + } + }, + "run-parallel": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.1.9.tgz", + "integrity": "sha512-DEqnSRTDw/Tc3FXf49zedI638Z9onwUotBMiUFKmrO2sdFKIbXamXGQ3Axd4qgphxKB4kw/qP1w5kTxnfU1B9Q==", + "dev": true + }, + "safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "semver": { + "version": "5.7.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", + "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "dev": true + }, + "semver-diff": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/semver-diff/-/semver-diff-2.1.0.tgz", + "integrity": "sha1-S7uEN8jTfksM8aaP1ybsbWRdbTY=", + "dev": true, + "requires": { + "semver": "^5.0.3" + } + }, + "serialize-error": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/serialize-error/-/serialize-error-2.1.0.tgz", + "integrity": "sha1-ULZ51WNc34Rme9yOWa9OW4HV9go=", + "dev": true + }, + "shebang-command": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-1.2.0.tgz", + "integrity": "sha1-RKrGW2lbAzmJaMOfNj/uXer98eo=", + "dev": true, + "requires": { + "shebang-regex": "^1.0.0" + } + }, + "shebang-regex": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-1.0.0.tgz", + "integrity": "sha1-2kL0l0DAtC2yypcoVxyxkMmO/qM=", + "dev": true + }, + "signal-exit": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.2.tgz", + "integrity": "sha1-tf3AjxKH6hF4Yo5BXiUTK3NkbG0=", + "dev": true + }, + "slash": { 
+ "version": "3.0.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", + "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "dev": true + }, + "slice-ansi": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-2.1.0.tgz", + "integrity": "sha512-Qu+VC3EwYLldKa1fCxuuvULvSJOKEgk9pi8dZeCVK7TqBfUNTH4sFkk4joj8afVSfAYgJoSOetjx9QWOJ5mYoQ==", + "dev": true, + "requires": { + "ansi-styles": "^3.2.0", + "astral-regex": "^1.0.0", + "is-fullwidth-code-point": "^2.0.0" + }, + "dependencies": { + "ansi-styles": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-3.2.1.tgz", + "integrity": "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==", + "dev": true, + "requires": { + "color-convert": "^1.9.0" + } + } + } + }, + "source-map": { + "version": "0.5.7", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.5.7.tgz", + "integrity": "sha1-igOdLRAh0i0eoUyA2OpGi6LvP8w=", + "dev": true + }, + "source-map-support": { + "version": "0.5.13", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.13.tgz", + "integrity": "sha512-SHSKFHadjVA5oR4PPqhtAVdcBWwRYVd6g6cAXnIbRiIwc2EhPrTuKUBdSLvlEKyIP3GCf89fltvcZiP9MMFA1w==", + "dev": true, + "requires": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + }, + "dependencies": { + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true + } + } + }, + "spdx-correct": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/spdx-correct/-/spdx-correct-3.1.0.tgz", + "integrity": "sha512-lr2EZCctC2BNR7j7WzJ2FpDznxky1sjfxvvYEyzxNyb6lZXHODmEoJeFu4JupYlkfha1KZpJyoqiJ7pgA1qq8Q==", + "dev": true, + "requires": { + 
"spdx-expression-parse": "^3.0.0", + "spdx-license-ids": "^3.0.0" + } + }, + "spdx-exceptions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.2.0.tgz", + "integrity": "sha512-2XQACfElKi9SlVb1CYadKDXvoajPgBVPn/gOQLrTvHdElaVhr7ZEbqJaRnJLVNeaI4cMEAgVCeBMKF6MWRDCRA==", + "dev": true + }, + "spdx-expression-parse": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/spdx-expression-parse/-/spdx-expression-parse-3.0.0.tgz", + "integrity": "sha512-Yg6D3XpRD4kkOmTpdgbUiEJFKghJH03fiC1OPll5h/0sO6neh2jqRDVHOQ4o/LMea0tgCkbMgea5ip/e+MkWyg==", + "dev": true, + "requires": { + "spdx-exceptions": "^2.1.0", + "spdx-license-ids": "^3.0.0" + } + }, + "spdx-license-ids": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.5.tgz", + "integrity": "sha512-J+FWzZoynJEXGphVIS+XEh3kFSjZX/1i9gFBaWQcB+/tmpe2qUsSBABpcxqxnAxFdiUFEgAX1bjYGQvIZmoz9Q==", + "dev": true + }, + "sprintf-js": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", + "integrity": "sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw=", + "dev": true + }, + "sqlstring": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/sqlstring/-/sqlstring-2.3.1.tgz", + "integrity": "sha1-R1OT/56RR5rqYtyvDKPRSYOn+0A=" + }, + "stack-utils": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/stack-utils/-/stack-utils-1.0.2.tgz", + "integrity": "sha512-MTX+MeG5U994cazkjd/9KNAapsHnibjMLnfXodlkXw76JEea0UiNzrqidzo1emMwk7w5Qhc9jd4Bn9TBb1MFwA==", + "dev": true + }, + "string-width": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.1.0.tgz", + "integrity": "sha512-NrX+1dVVh+6Y9dnQ19pR0pP4FiEIlUvdTGn8pw6CKTNq5sgib2nIhmUNT5TAmhWmvKr3WcxBcP3E8nWezuipuQ==", + "dev": true, + "requires": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^5.2.0" + }, + "dependencies": { + 
"is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true + } + } + }, + "string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "requires": { + "safe-buffer": "~5.1.0" + } + }, + "strip-ansi": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-5.2.0.tgz", + "integrity": "sha512-DuRs1gKbBqsMKIZlrffwlug8MHkcnpjs5VPmL1PAh+mA30U0DTotfDZ0d2UUsXpPmPmMMJ6W773MaA3J+lbiWA==", + "dev": true, + "requires": { + "ansi-regex": "^4.1.0" + } + }, + "strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha1-IzTBjpx1n3vdVv3vfprj1YjmjtM=", + "dev": true + }, + "strip-bom-buf": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-bom-buf/-/strip-bom-buf-2.0.0.tgz", + "integrity": "sha512-gLFNHucd6gzb8jMsl5QmZ3QgnUJmp7qn4uUSHNwEXumAp7YizoGYw19ZUVfuq4aBOQUtyn2k8X/CwzWB73W2lQ==", + "dev": true, + "requires": { + "is-utf8": "^0.2.1" + } + }, + "strip-eof": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/strip-eof/-/strip-eof-1.0.0.tgz", + "integrity": "sha1-u0P/VZim6wXYm1n80SnJgzE2Br8=", + "dev": true + }, + "strip-indent": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-indent/-/strip-indent-2.0.0.tgz", + "integrity": "sha1-XvjbKV0B5u1sv3qrlpmNeCJSe2g=", + "dev": true + }, + "strip-json-comments": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz", + "integrity": "sha1-PFMZQukIwml8DsNEhYwobHygpgo=", + "dev": true + }, + "supertap": { + 
"version": "1.0.0", + "resolved": "https://registry.npmjs.org/supertap/-/supertap-1.0.0.tgz", + "integrity": "sha512-HZJ3geIMPgVwKk2VsmO5YHqnnJYl6bV5A9JW2uzqV43WmpgliNEYbuvukfor7URpaqpxuw3CfZ3ONdVbZjCgIA==", + "dev": true, + "requires": { + "arrify": "^1.0.1", + "indent-string": "^3.2.0", + "js-yaml": "^3.10.0", + "serialize-error": "^2.1.0", + "strip-ansi": "^4.0.0" + }, + "dependencies": { + "ansi-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-3.0.0.tgz", + "integrity": "sha1-7QMXwyIGT3lGbAKWa922Bas32Zg=", + "dev": true + }, + "arrify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/arrify/-/arrify-1.0.1.tgz", + "integrity": "sha1-iYUI2iIm84DfkEcoRWhJwVAaSw0=", + "dev": true + }, + "indent-string": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-3.2.0.tgz", + "integrity": "sha1-Sl/W0nzDMvN+VBmlBNu4NxBckok=", + "dev": true + }, + "strip-ansi": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-4.0.0.tgz", + "integrity": "sha1-qEeQIusaw2iocTibY1JixQXuNo8=", + "dev": true, + "requires": { + "ansi-regex": "^3.0.0" + } + } + } + }, + "supports-color": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.0.0.tgz", + "integrity": "sha512-WRt32iTpYEZWYOpcetGm0NPeSvaebccx7hhS/5M6sAiqnhedtFCHFxkjzZlJvFNCPowiKSFGiZk5USQDFy83vQ==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + }, + "dependencies": { + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true + } + } + }, + "symbol-observable": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/symbol-observable/-/symbol-observable-1.2.0.tgz", + "integrity": 
"sha512-e900nM8RRtGhlV36KGEU9k65K3mPb1WV70OdjfxlG2EAuM1noi/E/BaW/uMhL7bPEssK8QV57vN3esixjUvcXQ==", + "dev": true + }, + "term-size": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/term-size/-/term-size-1.2.0.tgz", + "integrity": "sha1-RYuDiH8oj8Vtb/+/rSYuJmOO+mk=", + "dev": true, + "requires": { + "execa": "^0.7.0" + } + }, + "time-zone": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/time-zone/-/time-zone-1.0.0.tgz", + "integrity": "sha1-mcW/VZWJZq9tBtg73zgA3IL67F0=", + "dev": true + }, + "to-fast-properties": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz", + "integrity": "sha1-3F5pjL0HkmW8c+A3doGk5Og/YW4=", + "dev": true + }, + "to-readable-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/to-readable-stream/-/to-readable-stream-1.0.0.tgz", + "integrity": "sha512-Iq25XBt6zD5npPhlLVXGFN3/gyR2/qODcKNNyTMd4vbm39HUaOiAM4PMq0eMVC/Tkxz+Zjdsc55g9yyz+Yq00Q==", + "dev": true + }, + "to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "requires": { + "is-number": "^7.0.0" + } + }, + "trim-newlines": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/trim-newlines/-/trim-newlines-2.0.0.tgz", + "integrity": "sha1-tAPQuRvlDDMd/EuC7s6yLD3hbSA=", + "dev": true + }, + "trim-off-newlines": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/trim-off-newlines/-/trim-off-newlines-1.0.1.tgz", + "integrity": "sha1-n5up2e+odkw4dpi8v+sshI8RrbM=", + "dev": true + }, + "trim-right": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/trim-right/-/trim-right-1.0.1.tgz", + "integrity": "sha1-yy4SAwZ+DI3h9hQJS5/kVwTqYAM=", + "dev": true + }, + "type-fest": { + "version": "0.5.2", + "resolved": 
"https://registry.npmjs.org/type-fest/-/type-fest-0.5.2.tgz", + "integrity": "sha512-DWkS49EQKVX//Tbupb9TFa19c7+MK1XmzkrZUR8TAktmE/DizXoaoJV6TZ/tSIPXipqNiRI6CyAe7x69Jb6RSw==", + "dev": true + }, + "typedarray-to-buffer": { + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/typedarray-to-buffer/-/typedarray-to-buffer-3.1.5.tgz", + "integrity": "sha512-zdu8XMNEDepKKR+XYOXAVPtWui0ly0NtohUscw+UmaHiAWT8hrV1rr//H6V+0DvJ3OQ19S979M0laLfX8rm82Q==", + "dev": true, + "requires": { + "is-typedarray": "^1.0.0" + } + }, + "uid2": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/uid2/-/uid2-0.0.3.tgz", + "integrity": "sha1-SDEm4Rd03y9xuLY53NeZw3YWK4I=", + "dev": true + }, + "unicode-canonical-property-names-ecmascript": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/unicode-canonical-property-names-ecmascript/-/unicode-canonical-property-names-ecmascript-1.0.4.tgz", + "integrity": "sha512-jDrNnXWHd4oHiTZnx/ZG7gtUTVp+gCcTTKr8L0HjlwphROEW3+Him+IpvC+xcJEFegapiMZyZe02CyuOnRmbnQ==", + "dev": true + }, + "unicode-match-property-ecmascript": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/unicode-match-property-ecmascript/-/unicode-match-property-ecmascript-1.0.4.tgz", + "integrity": "sha512-L4Qoh15vTfntsn4P1zqnHulG0LdXgjSO035fEpdtp6YxXhMT51Q6vgM5lYdG/5X3MjS+k/Y9Xw4SFCY9IkR0rg==", + "dev": true, + "requires": { + "unicode-canonical-property-names-ecmascript": "^1.0.4", + "unicode-property-aliases-ecmascript": "^1.0.4" + } + }, + "unicode-match-property-value-ecmascript": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/unicode-match-property-value-ecmascript/-/unicode-match-property-value-ecmascript-1.1.0.tgz", + "integrity": "sha512-hDTHvaBk3RmFzvSl0UVrUmC3PuW9wKVnpoUDYH0JDkSIovzw+J5viQmeYHxVSBptubnr7PbH2e0fnpDRQnQl5g==", + "dev": true + }, + "unicode-property-aliases-ecmascript": { + "version": "1.0.5", + "resolved": 
"https://registry.npmjs.org/unicode-property-aliases-ecmascript/-/unicode-property-aliases-ecmascript-1.0.5.tgz", + "integrity": "sha512-L5RAqCfXqAwR3RriF8pM0lU0w4Ryf/GgzONwi6KnL1taJQa7x1TCxdJnILX59WIGOwR57IVxn7Nej0fz1Ny6fw==", + "dev": true + }, + "unique-string": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unique-string/-/unique-string-1.0.0.tgz", + "integrity": "sha1-nhBXzKhRq7kzmPizOuGHuZyuwRo=", + "dev": true, + "requires": { + "crypto-random-string": "^1.0.0" + } + }, + "unique-temp-dir": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unique-temp-dir/-/unique-temp-dir-1.0.0.tgz", + "integrity": "sha1-bc6VsmgcoAPuv7MEpBX5y6vMU4U=", + "dev": true, + "requires": { + "mkdirp": "^0.5.1", + "os-tmpdir": "^1.0.1", + "uid2": "0.0.3" + } + }, + "update-notifier": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/update-notifier/-/update-notifier-3.0.1.tgz", + "integrity": "sha512-grrmrB6Zb8DUiyDIaeRTBCkgISYUgETNe7NglEbVsrLWXeESnlCSP50WfRSj/GmzMPl6Uchj24S/p80nP/ZQrQ==", + "dev": true, + "requires": { + "boxen": "^3.0.0", + "chalk": "^2.0.1", + "configstore": "^4.0.0", + "has-yarn": "^2.1.0", + "import-lazy": "^2.1.0", + "is-ci": "^2.0.0", + "is-installed-globally": "^0.1.0", + "is-npm": "^3.0.0", + "is-yarn-global": "^0.3.0", + "latest-version": "^5.0.0", + "semver-diff": "^2.0.0", + "xdg-basedir": "^3.0.0" + } + }, + "url-parse-lax": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/url-parse-lax/-/url-parse-lax-3.0.0.tgz", + "integrity": "sha1-FrXK/Afb42dsGxmZF3gj1lA6yww=", + "dev": true, + "requires": { + "prepend-http": "^2.0.0" + } + }, + "util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=" + }, + "validate-npm-package-license": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz", + 
"integrity": "sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==", + "dev": true, + "requires": { + "spdx-correct": "^3.0.0", + "spdx-expression-parse": "^3.0.0" + } + }, + "wcwidth": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/wcwidth/-/wcwidth-1.0.1.tgz", + "integrity": "sha1-8LDc+RW8X/FSivrbLA4XtTLaL+g=", + "dev": true, + "requires": { + "defaults": "^1.0.3" + } + }, + "well-known-symbols": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/well-known-symbols/-/well-known-symbols-2.0.0.tgz", + "integrity": "sha512-ZMjC3ho+KXo0BfJb7JgtQ5IBuvnShdlACNkKkdsqBmYw3bPAaJfPeYUo6tLUaT5tG/Gkh7xkpBhKRQ9e7pyg9Q==", + "dev": true + }, + "which": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/which/-/which-1.3.1.tgz", + "integrity": "sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==", + "dev": true, + "requires": { + "isexe": "^2.0.0" + } + }, + "widest-line": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/widest-line/-/widest-line-2.0.1.tgz", + "integrity": "sha512-Ba5m9/Fa4Xt9eb2ELXt77JxVDV8w7qQrH0zS/TWSJdLyAwQjWoOzpzj5lwVftDz6n/EOu3tNACS84v509qwnJA==", + "dev": true, + "requires": { + "string-width": "^2.1.1" + }, + "dependencies": { + "ansi-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-3.0.0.tgz", + "integrity": "sha1-7QMXwyIGT3lGbAKWa922Bas32Zg=", + "dev": true + }, + "string-width": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-2.1.1.tgz", + "integrity": "sha512-nOqH59deCq9SRHlxq1Aw85Jnt4w6KvLKqWVik6oA9ZklXLNIOlqg4F2yrT1MVaTjAqvVwdfeZ7w7aCvJD7ugkw==", + "dev": true, + "requires": { + "is-fullwidth-code-point": "^2.0.0", + "strip-ansi": "^4.0.0" + } + }, + "strip-ansi": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-4.0.0.tgz", + "integrity": 
"sha1-qEeQIusaw2iocTibY1JixQXuNo8=", + "dev": true, + "requires": { + "ansi-regex": "^3.0.0" + } + } + } + }, + "wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=", + "dev": true + }, + "write-file-atomic": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/write-file-atomic/-/write-file-atomic-3.0.0.tgz", + "integrity": "sha512-EIgkf60l2oWsffja2Sf2AL384dx328c0B+cIYPTQq5q2rOYuDV00/iPFBOUiDKKwKMOhkymH8AidPaRvzfxY+Q==", + "dev": true, + "requires": { + "imurmurhash": "^0.1.4", + "is-typedarray": "^1.0.0", + "signal-exit": "^3.0.2", + "typedarray-to-buffer": "^3.1.5" + } + }, + "xdg-basedir": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/xdg-basedir/-/xdg-basedir-3.0.0.tgz", + "integrity": "sha1-SWsswQnsqNus/i3HK2A8F8WHCtQ=", + "dev": true + }, + "xtend": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", + "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", + "dev": true + }, + "yallist": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-2.1.2.tgz", + "integrity": "sha1-HBH5IY8HYImkfdUS+TxmmaaoHVI=", + "dev": true + }, + "yargs-parser": { + "version": "10.1.0", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-10.1.0.tgz", + "integrity": "sha512-VCIyR1wJoEBZUqk5PA+oOBF6ypbwh5aNB3I50guxAL/quggdfs4TtNHQrSazFA3fYZ+tEqfs0zIGlv0c/rgjbQ==", + "dev": true, + "requires": { + "camelcase": "^4.1.0" + } + } + } +} diff --git a/_integration/javascript/package.json b/_integration/javascript/package.json new file mode 100644 index 000000000..da6670c88 --- /dev/null +++ b/_integration/javascript/package.json @@ -0,0 +1,15 @@ +{ + "name": "go-mysql-server-js", + "version": "1.0.0", + "description": "go-mysql-server JS compat test", + "main": "index.js", + "scripts": { + "test": 
"./node_modules/.bin/ava" + }, + "dependencies": { + "mysql": "2.17.1" + }, + "devDependencies": { + "ava": "2.3.0" + } +} diff --git a/_integration/javascript/test.js b/_integration/javascript/test.js new file mode 100644 index 000000000..b4b4f0364 --- /dev/null +++ b/_integration/javascript/test.js @@ -0,0 +1,32 @@ +import test from 'ava'; +import mysql from 'mysql'; + +test.cb('can connect to go-mysql-server', t => { + const connection = mysql.createConnection({ + host: '127.0.0.1', + port: 3306, + user: 'root', + password: '', + database: 'mydb' + }); + + connection.connect(); + + const query = 'SELECT name, email FROM mytable ORDER BY name, email'; + const expected = [ + { name: "Evil Bob", email: "evilbob@gmail.com" }, + { name: "Jane Doe", email: "jane@doe.com" }, + { name: "John Doe", email: "john@doe.com" }, + { name: "John Doe", email: "johnalt@doe.com" }, + ]; + + connection.query(query, function (error, results, _) { + if (error) throw error; + + const rows = results.map(r => ({ name: r.name, email: r.email })); + t.deepEqual(rows, expected); + t.end(); + }); + + connection.end(); +}); diff --git a/_integration/jdbc-mariadb/.gitignore b/_integration/jdbc-mariadb/.gitignore new file mode 100644 index 000000000..624041115 --- /dev/null +++ b/_integration/jdbc-mariadb/.gitignore @@ -0,0 +1,3 @@ +*.iml +.idea +target diff --git a/_integration/jdbc-mariadb/Makefile b/_integration/jdbc-mariadb/Makefile new file mode 100644 index 000000000..a5737c29c --- /dev/null +++ b/_integration/jdbc-mariadb/Makefile @@ -0,0 +1,4 @@ +test: + mvn clean test + +.PHONY: test diff --git a/_integration/jdbc-mariadb/pom.xml b/_integration/jdbc-mariadb/pom.xml new file mode 100644 index 000000000..541ed64d5 --- /dev/null +++ b/_integration/jdbc-mariadb/pom.xml @@ -0,0 +1,58 @@ + + + 4.0.0 + + tech.sourced + jdbc-mariadb + 1.0-SNAPSHOT + + + UTF-8 + 1.8 + ${maven.compiler.source} + + 5.3.1 + + + + + org.junit.jupiter + junit-jupiter-api + ${junit.jupiter.version} + test + + + + 
org.junit.jupiter + junit-jupiter-params + ${junit.jupiter.version} + test + + + + org.junit.jupiter + junit-jupiter-engine + ${junit.jupiter.version} + test + + + + org.mariadb.jdbc + mariadb-java-client + 2.3.0 + + + + + + + + maven-surefire-plugin + 2.22.0 + + + + + \ No newline at end of file diff --git a/_integration/jdbc-mariadb/src/main/java/.gitkeep b/_integration/jdbc-mariadb/src/main/java/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/_integration/jdbc-mariadb/src/test/java/tech/sourced/jdbcmariadb/MySQLTest.java b/_integration/jdbc-mariadb/src/test/java/tech/sourced/jdbcmariadb/MySQLTest.java new file mode 100644 index 000000000..e3f41eb5a --- /dev/null +++ b/_integration/jdbc-mariadb/src/test/java/tech/sourced/jdbcmariadb/MySQLTest.java @@ -0,0 +1,68 @@ +package tech.sourced.jdbcmariadb; + +import org.junit.jupiter.api.Test; + +import java.sql.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +class MySQLTest { + + @Test + void test() { + String dbUrl = "jdbc:mariadb://127.0.0.1:3306/mydb?user=root&password="; + String query = "SELECT name, email FROM mytable ORDER BY name, email"; + List expected = new ArrayList<>(); + expected.add(new Result("Evil Bob", "evilbob@gmail.com")); + expected.add(new Result("Jane Doe", "jane@doe.com")); + expected.add(new Result("John Doe", "john@doe.com")); + expected.add(new Result("John Doe", "johnalt@doe.com")); + + List result = new ArrayList<>(); + + try (Connection connection = DriverManager.getConnection(dbUrl)) { + try (PreparedStatement stmt = connection.prepareStatement(query)) { + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + result.add(new Result(rs.getString(1), rs.getString(2))); + } + } + } + } catch (SQLException e) { + fail(e); + } + + assertEquals(expected, result); + } + + class Result { + String name; + String 
email; + + Result(String name, String email) { + this.name = name; + this.email = email; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result result = (Result) o; + return Objects.equals(name, result.name) && + Objects.equals(email, result.email); + } + + @Override + public String toString() { + return "Result{" + + "name='" + name + '\'' + + ", email='" + email + '\'' + + '}'; + } + } +} diff --git a/_integration/php/.gitignore b/_integration/php/.gitignore new file mode 100644 index 000000000..22d0d82f8 --- /dev/null +++ b/_integration/php/.gitignore @@ -0,0 +1 @@ +vendor diff --git a/_integration/php/Makefile b/_integration/php/Makefile new file mode 100644 index 000000000..adc95f160 --- /dev/null +++ b/_integration/php/Makefile @@ -0,0 +1,9 @@ +vendor/autoload.php: + composer install + +dependencies: vendor/autoload.php + +test: dependencies + ./vendor/bin/phpunit --bootstrap=vendor/autoload.php tests/MySQLTest + +.PHONY: test \ No newline at end of file diff --git a/_integration/php/composer.json b/_integration/php/composer.json new file mode 100644 index 000000000..5f3db0a0e --- /dev/null +++ b/_integration/php/composer.json @@ -0,0 +1,12 @@ +{ + "name": "go-mysql-server php test", + "description": "go-mysql-server php test for compatibility", + "autoload": { + "classmap": [ + "src/" + ] + }, + "require-dev": { + "phpunit/phpunit": "^7" + } +} \ No newline at end of file diff --git a/_integration/php/composer.lock b/_integration/php/composer.lock new file mode 100644 index 000000000..9227f483d --- /dev/null +++ b/_integration/php/composer.lock @@ -0,0 +1,1426 @@ +{ + "_readme": [ + "This file locks the dependencies of your project to a known state", + "Read more about it at https://getcomposer.org/doc/01-basic-usage.md#installing-dependencies", + "This file is @generated automatically" + ], + "content-hash": "71edcb471c3fab8d1e0dfa5838aa9a72", + "packages": 
[], + "packages-dev": [ + { + "name": "doctrine/instantiator", + "version": "1.1.0", + "source": { + "type": "git", + "url": "https://github.com/doctrine/instantiator.git", + "reference": "185b8868aa9bf7159f5f953ed5afb2d7fcdc3bda" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/doctrine/instantiator/zipball/185b8868aa9bf7159f5f953ed5afb2d7fcdc3bda", + "reference": "185b8868aa9bf7159f5f953ed5afb2d7fcdc3bda", + "shasum": "" + }, + "require": { + "php": "^7.1" + }, + "require-dev": { + "athletic/athletic": "~0.1.8", + "ext-pdo": "*", + "ext-phar": "*", + "phpunit/phpunit": "^6.2.3", + "squizlabs/php_codesniffer": "^3.0.2" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.2.x-dev" + } + }, + "autoload": { + "psr-4": { + "Doctrine\\Instantiator\\": "src/Doctrine/Instantiator/" + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Marco Pivetta", + "email": "ocramius@gmail.com", + "homepage": "http://ocramius.github.com/" + } + ], + "description": "A small, lightweight utility to instantiate objects in PHP without invoking their constructors", + "homepage": "https://github.com/doctrine/instantiator", + "keywords": [ + "constructor", + "instantiate" + ], + "time": "2017-07-22T11:58:36+00:00" + }, + { + "name": "myclabs/deep-copy", + "version": "1.8.1", + "source": { + "type": "git", + "url": "https://github.com/myclabs/DeepCopy.git", + "reference": "3e01bdad3e18354c3dce54466b7fbe33a9f9f7f8" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/myclabs/DeepCopy/zipball/3e01bdad3e18354c3dce54466b7fbe33a9f9f7f8", + "reference": "3e01bdad3e18354c3dce54466b7fbe33a9f9f7f8", + "shasum": "" + }, + "require": { + "php": "^7.1" + }, + "replace": { + "myclabs/deep-copy": "self.version" + }, + "require-dev": { + "doctrine/collections": "^1.0", + "doctrine/common": "^2.6", + "phpunit/phpunit": "^7.1" + }, + "type": "library", + "autoload": { + 
"psr-4": { + "DeepCopy\\": "src/DeepCopy/" + }, + "files": [ + "src/DeepCopy/deep_copy.php" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "description": "Create deep copies (clones) of your objects", + "keywords": [ + "clone", + "copy", + "duplicate", + "object", + "object graph" + ], + "time": "2018-06-11T23:09:50+00:00" + }, + { + "name": "phar-io/manifest", + "version": "1.0.3", + "source": { + "type": "git", + "url": "https://github.com/phar-io/manifest.git", + "reference": "7761fcacf03b4d4f16e7ccb606d4879ca431fcf4" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phar-io/manifest/zipball/7761fcacf03b4d4f16e7ccb606d4879ca431fcf4", + "reference": "7761fcacf03b4d4f16e7ccb606d4879ca431fcf4", + "shasum": "" + }, + "require": { + "ext-dom": "*", + "ext-phar": "*", + "phar-io/version": "^2.0", + "php": "^5.6 || ^7.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Arne Blankerts", + "email": "arne@blankerts.de", + "role": "Developer" + }, + { + "name": "Sebastian Heuer", + "email": "sebastian@phpeople.de", + "role": "Developer" + }, + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "Developer" + } + ], + "description": "Component for reading phar.io manifest information from a PHP Archive (PHAR)", + "time": "2018-07-08T19:23:20+00:00" + }, + { + "name": "phar-io/version", + "version": "2.0.1", + "source": { + "type": "git", + "url": "https://github.com/phar-io/version.git", + "reference": "45a2ec53a73c70ce41d55cedef9063630abaf1b6" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phar-io/version/zipball/45a2ec53a73c70ce41d55cedef9063630abaf1b6", + "reference": "45a2ec53a73c70ce41d55cedef9063630abaf1b6", + "shasum": "" + 
}, + "require": { + "php": "^5.6 || ^7.0" + }, + "type": "library", + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Arne Blankerts", + "email": "arne@blankerts.de", + "role": "Developer" + }, + { + "name": "Sebastian Heuer", + "email": "sebastian@phpeople.de", + "role": "Developer" + }, + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "Developer" + } + ], + "description": "Library for handling version information and constraints", + "time": "2018-07-08T19:19:57+00:00" + }, + { + "name": "phpdocumentor/reflection-common", + "version": "1.0.1", + "source": { + "type": "git", + "url": "https://github.com/phpDocumentor/ReflectionCommon.git", + "reference": "21bdeb5f65d7ebf9f43b1b25d404f87deab5bfb6" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phpDocumentor/ReflectionCommon/zipball/21bdeb5f65d7ebf9f43b1b25d404f87deab5bfb6", + "reference": "21bdeb5f65d7ebf9f43b1b25d404f87deab5bfb6", + "shasum": "" + }, + "require": { + "php": ">=5.5" + }, + "require-dev": { + "phpunit/phpunit": "^4.6" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.0.x-dev" + } + }, + "autoload": { + "psr-4": { + "phpDocumentor\\Reflection\\": [ + "src" + ] + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Jaap van Otterdijk", + "email": "opensource@ijaap.nl" + } + ], + "description": "Common reflection classes used by phpdocumentor to reflect the code structure", + "homepage": "http://www.phpdoc.org", + "keywords": [ + "FQSEN", + "phpDocumentor", + "phpdoc", + "reflection", + "static analysis" + ], + "time": "2017-09-11T18:02:19+00:00" + }, + { + "name": "phpdocumentor/reflection-docblock", + "version": "4.3.0", + "source": { + "type": "git", + "url": "https://github.com/phpDocumentor/ReflectionDocBlock.git", + 
"reference": "94fd0001232e47129dd3504189fa1c7225010d08" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phpDocumentor/ReflectionDocBlock/zipball/94fd0001232e47129dd3504189fa1c7225010d08", + "reference": "94fd0001232e47129dd3504189fa1c7225010d08", + "shasum": "" + }, + "require": { + "php": "^7.0", + "phpdocumentor/reflection-common": "^1.0.0", + "phpdocumentor/type-resolver": "^0.4.0", + "webmozart/assert": "^1.0" + }, + "require-dev": { + "doctrine/instantiator": "~1.0.5", + "mockery/mockery": "^1.0", + "phpunit/phpunit": "^6.4" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "4.x-dev" + } + }, + "autoload": { + "psr-4": { + "phpDocumentor\\Reflection\\": [ + "src/" + ] + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Mike van Riel", + "email": "me@mikevanriel.com" + } + ], + "description": "With this component, a library can provide support for annotations via DocBlocks or otherwise retrieve information that is embedded in a DocBlock.", + "time": "2017-11-30T07:14:17+00:00" + }, + { + "name": "phpdocumentor/type-resolver", + "version": "0.4.0", + "source": { + "type": "git", + "url": "https://github.com/phpDocumentor/TypeResolver.git", + "reference": "9c977708995954784726e25d0cd1dddf4e65b0f7" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phpDocumentor/TypeResolver/zipball/9c977708995954784726e25d0cd1dddf4e65b0f7", + "reference": "9c977708995954784726e25d0cd1dddf4e65b0f7", + "shasum": "" + }, + "require": { + "php": "^5.5 || ^7.0", + "phpdocumentor/reflection-common": "^1.0" + }, + "require-dev": { + "mockery/mockery": "^0.9.4", + "phpunit/phpunit": "^5.2||^4.8.24" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.0.x-dev" + } + }, + "autoload": { + "psr-4": { + "phpDocumentor\\Reflection\\": [ + "src/" + ] + } + }, + "notification-url": "https://packagist.org/downloads/", + 
"license": [ + "MIT" + ], + "authors": [ + { + "name": "Mike van Riel", + "email": "me@mikevanriel.com" + } + ], + "time": "2017-07-14T14:27:02+00:00" + }, + { + "name": "phpspec/prophecy", + "version": "1.8.0", + "source": { + "type": "git", + "url": "https://github.com/phpspec/prophecy.git", + "reference": "4ba436b55987b4bf311cb7c6ba82aa528aac0a06" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/phpspec/prophecy/zipball/4ba436b55987b4bf311cb7c6ba82aa528aac0a06", + "reference": "4ba436b55987b4bf311cb7c6ba82aa528aac0a06", + "shasum": "" + }, + "require": { + "doctrine/instantiator": "^1.0.2", + "php": "^5.3|^7.0", + "phpdocumentor/reflection-docblock": "^2.0|^3.0.2|^4.0", + "sebastian/comparator": "^1.1|^2.0|^3.0", + "sebastian/recursion-context": "^1.0|^2.0|^3.0" + }, + "require-dev": { + "phpspec/phpspec": "^2.5|^3.2", + "phpunit/phpunit": "^4.8.35 || ^5.7 || ^6.5 || ^7.1" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.8.x-dev" + } + }, + "autoload": { + "psr-0": { + "Prophecy\\": "src/" + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Konstantin Kudryashov", + "email": "ever.zet@gmail.com", + "homepage": "http://everzet.com" + }, + { + "name": "Marcello Duarte", + "email": "marcello.duarte@gmail.com" + } + ], + "description": "Highly opinionated mocking framework for PHP 5.3+", + "homepage": "https://github.com/phpspec/prophecy", + "keywords": [ + "Double", + "Dummy", + "fake", + "mock", + "spy", + "stub" + ], + "time": "2018-08-05T17:53:17+00:00" + }, + { + "name": "phpunit/php-code-coverage", + "version": "6.0.7", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/php-code-coverage.git", + "reference": "865662550c384bc1db7e51d29aeda1c2c161d69a" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/php-code-coverage/zipball/865662550c384bc1db7e51d29aeda1c2c161d69a", + 
"reference": "865662550c384bc1db7e51d29aeda1c2c161d69a", + "shasum": "" + }, + "require": { + "ext-dom": "*", + "ext-xmlwriter": "*", + "php": "^7.1", + "phpunit/php-file-iterator": "^2.0", + "phpunit/php-text-template": "^1.2.1", + "phpunit/php-token-stream": "^3.0", + "sebastian/code-unit-reverse-lookup": "^1.0.1", + "sebastian/environment": "^3.1", + "sebastian/version": "^2.0.1", + "theseer/tokenizer": "^1.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.0" + }, + "suggest": { + "ext-xdebug": "^2.6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "6.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "Library that provides collection, processing, and rendering functionality for PHP code coverage information.", + "homepage": "https://github.com/sebastianbergmann/php-code-coverage", + "keywords": [ + "coverage", + "testing", + "xunit" + ], + "time": "2018-06-01T07:51:50+00:00" + }, + { + "name": "phpunit/php-file-iterator", + "version": "2.0.2", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/php-file-iterator.git", + "reference": "050bedf145a257b1ff02746c31894800e5122946" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/php-file-iterator/zipball/050bedf145a257b1ff02746c31894800e5122946", + "reference": "050bedf145a257b1ff02746c31894800e5122946", + "shasum": "" + }, + "require": { + "php": "^7.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.1" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "2.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian 
Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "FilterIterator implementation that filters files based on a list of suffixes.", + "homepage": "https://github.com/sebastianbergmann/php-file-iterator/", + "keywords": [ + "filesystem", + "iterator" + ], + "time": "2018-09-13T20:33:42+00:00" + }, + { + "name": "phpunit/php-text-template", + "version": "1.2.1", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/php-text-template.git", + "reference": "31f8b717e51d9a2afca6c9f046f5d69fc27c8686" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/php-text-template/zipball/31f8b717e51d9a2afca6c9f046f5d69fc27c8686", + "reference": "31f8b717e51d9a2afca6c9f046f5d69fc27c8686", + "shasum": "" + }, + "require": { + "php": ">=5.3.3" + }, + "type": "library", + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "Simple template engine.", + "homepage": "https://github.com/sebastianbergmann/php-text-template/", + "keywords": [ + "template" + ], + "time": "2015-06-21T13:50:34+00:00" + }, + { + "name": "phpunit/php-timer", + "version": "2.0.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/php-timer.git", + "reference": "8b8454ea6958c3dee38453d3bd571e023108c91f" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/php-timer/zipball/8b8454ea6958c3dee38453d3bd571e023108c91f", + "reference": "8b8454ea6958c3dee38453d3bd571e023108c91f", + "shasum": "" + }, + "require": { + "php": "^7.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "2.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": 
"https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "Utility class for timing", + "homepage": "https://github.com/sebastianbergmann/php-timer/", + "keywords": [ + "timer" + ], + "time": "2018-02-01T13:07:23+00:00" + }, + { + "name": "phpunit/php-token-stream", + "version": "3.0.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/php-token-stream.git", + "reference": "21ad88bbba7c3d93530d93994e0a33cd45f02ace" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/php-token-stream/zipball/21ad88bbba7c3d93530d93994e0a33cd45f02ace", + "reference": "21ad88bbba7c3d93530d93994e0a33cd45f02ace", + "shasum": "" + }, + "require": { + "ext-tokenizer": "*", + "php": "^7.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Wrapper around PHP's tokenizer extension.", + "homepage": "https://github.com/sebastianbergmann/php-token-stream/", + "keywords": [ + "tokenizer" + ], + "time": "2018-02-01T13:16:43+00:00" + }, + { + "name": "phpunit/phpunit", + "version": "7.3.5", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/phpunit.git", + "reference": "7b331efabbb628c518c408fdfcaf571156775de2" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/phpunit/zipball/7b331efabbb628c518c408fdfcaf571156775de2", + "reference": "7b331efabbb628c518c408fdfcaf571156775de2", + "shasum": "" + }, + "require": { + "doctrine/instantiator": "^1.1", + "ext-dom": "*", + "ext-json": 
"*", + "ext-libxml": "*", + "ext-mbstring": "*", + "ext-xml": "*", + "myclabs/deep-copy": "^1.7", + "phar-io/manifest": "^1.0.2", + "phar-io/version": "^2.0", + "php": "^7.1", + "phpspec/prophecy": "^1.7", + "phpunit/php-code-coverage": "^6.0.7", + "phpunit/php-file-iterator": "^2.0.1", + "phpunit/php-text-template": "^1.2.1", + "phpunit/php-timer": "^2.0", + "sebastian/comparator": "^3.0", + "sebastian/diff": "^3.0", + "sebastian/environment": "^3.1", + "sebastian/exporter": "^3.1", + "sebastian/global-state": "^2.0", + "sebastian/object-enumerator": "^3.0.3", + "sebastian/resource-operations": "^1.0", + "sebastian/version": "^2.0.1" + }, + "conflict": { + "phpunit/phpunit-mock-objects": "*" + }, + "require-dev": { + "ext-pdo": "*" + }, + "suggest": { + "ext-soap": "*", + "ext-xdebug": "*", + "phpunit/php-invoker": "^2.0" + }, + "bin": [ + "phpunit" + ], + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "7.3-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "The PHP Unit Testing framework.", + "homepage": "https://phpunit.de/", + "keywords": [ + "phpunit", + "testing", + "xunit" + ], + "time": "2018-09-08T15:14:29+00:00" + }, + { + "name": "sebastian/code-unit-reverse-lookup", + "version": "1.0.1", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/code-unit-reverse-lookup.git", + "reference": "4419fcdb5eabb9caa61a27c7a1db532a6b55dd18" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/code-unit-reverse-lookup/zipball/4419fcdb5eabb9caa61a27c7a1db532a6b55dd18", + "reference": "4419fcdb5eabb9caa61a27c7a1db532a6b55dd18", + "shasum": "" + }, + "require": { + "php": "^5.6 || ^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^5.7 || ^6.0" + }, + 
"type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Looks up which function or method a line of code belongs to", + "homepage": "https://github.com/sebastianbergmann/code-unit-reverse-lookup/", + "time": "2017-03-04T06:30:41+00:00" + }, + { + "name": "sebastian/comparator", + "version": "3.0.2", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/comparator.git", + "reference": "5de4fc177adf9bce8df98d8d141a7559d7ccf6da" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/comparator/zipball/5de4fc177adf9bce8df98d8d141a7559d7ccf6da", + "reference": "5de4fc177adf9bce8df98d8d141a7559d7ccf6da", + "shasum": "" + }, + "require": { + "php": "^7.1", + "sebastian/diff": "^3.0", + "sebastian/exporter": "^3.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.1" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Jeff Welch", + "email": "whatthejeff@gmail.com" + }, + { + "name": "Volker Dusch", + "email": "github@wallbash.com" + }, + { + "name": "Bernhard Schussek", + "email": "bschussek@2bepublished.at" + }, + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Provides the functionality to compare PHP values for equality", + "homepage": "https://github.com/sebastianbergmann/comparator", + "keywords": [ + "comparator", + "compare", + "equality" + ], + "time": "2018-07-12T15:12:46+00:00" + }, + { + "name": "sebastian/diff", + "version": "3.0.1", + "source": { + "type": 
"git", + "url": "https://github.com/sebastianbergmann/diff.git", + "reference": "366541b989927187c4ca70490a35615d3fef2dce" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/diff/zipball/366541b989927187c4ca70490a35615d3fef2dce", + "reference": "366541b989927187c4ca70490a35615d3fef2dce", + "shasum": "" + }, + "require": { + "php": "^7.1" + }, + "require-dev": { + "phpunit/phpunit": "^7.0", + "symfony/process": "^2 || ^3.3 || ^4" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Kore Nordmann", + "email": "mail@kore-nordmann.de" + }, + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Diff implementation", + "homepage": "https://github.com/sebastianbergmann/diff", + "keywords": [ + "diff", + "udiff", + "unidiff", + "unified diff" + ], + "time": "2018-06-10T07:54:39+00:00" + }, + { + "name": "sebastian/environment", + "version": "3.1.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/environment.git", + "reference": "cd0871b3975fb7fc44d11314fd1ee20925fce4f5" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/environment/zipball/cd0871b3975fb7fc44d11314fd1ee20925fce4f5", + "reference": "cd0871b3975fb7fc44d11314fd1ee20925fce4f5", + "shasum": "" + }, + "require": { + "php": "^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^6.1" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.1.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Provides functionality 
to handle HHVM/PHP environments", + "homepage": "http://www.github.com/sebastianbergmann/environment", + "keywords": [ + "Xdebug", + "environment", + "hhvm" + ], + "time": "2017-07-01T08:51:00+00:00" + }, + { + "name": "sebastian/exporter", + "version": "3.1.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/exporter.git", + "reference": "234199f4528de6d12aaa58b612e98f7d36adb937" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/exporter/zipball/234199f4528de6d12aaa58b612e98f7d36adb937", + "reference": "234199f4528de6d12aaa58b612e98f7d36adb937", + "shasum": "" + }, + "require": { + "php": "^7.0", + "sebastian/recursion-context": "^3.0" + }, + "require-dev": { + "ext-mbstring": "*", + "phpunit/phpunit": "^6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.1.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Jeff Welch", + "email": "whatthejeff@gmail.com" + }, + { + "name": "Volker Dusch", + "email": "github@wallbash.com" + }, + { + "name": "Bernhard Schussek", + "email": "bschussek@2bepublished.at" + }, + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + }, + { + "name": "Adam Harvey", + "email": "aharvey@php.net" + } + ], + "description": "Provides the functionality to export PHP variables for visualization", + "homepage": "http://www.github.com/sebastianbergmann/exporter", + "keywords": [ + "export", + "exporter" + ], + "time": "2017-04-03T13:19:02+00:00" + }, + { + "name": "sebastian/global-state", + "version": "2.0.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/global-state.git", + "reference": "e8ba02eed7bbbb9e59e43dedd3dddeff4a56b0c4" + }, + "dist": { + "type": "zip", + "url": 
"https://api.github.com/repos/sebastianbergmann/global-state/zipball/e8ba02eed7bbbb9e59e43dedd3dddeff4a56b0c4", + "reference": "e8ba02eed7bbbb9e59e43dedd3dddeff4a56b0c4", + "shasum": "" + }, + "require": { + "php": "^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^6.0" + }, + "suggest": { + "ext-uopz": "*" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "2.0-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Snapshotting of global state", + "homepage": "http://www.github.com/sebastianbergmann/global-state", + "keywords": [ + "global state" + ], + "time": "2017-04-27T15:39:26+00:00" + }, + { + "name": "sebastian/object-enumerator", + "version": "3.0.3", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/object-enumerator.git", + "reference": "7cfd9e65d11ffb5af41198476395774d4c8a84c5" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/object-enumerator/zipball/7cfd9e65d11ffb5af41198476395774d4c8a84c5", + "reference": "7cfd9e65d11ffb5af41198476395774d4c8a84c5", + "shasum": "" + }, + "require": { + "php": "^7.0", + "sebastian/object-reflector": "^1.1.1", + "sebastian/recursion-context": "^3.0" + }, + "require-dev": { + "phpunit/phpunit": "^6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Traverses array structures and object graphs to enumerate all referenced objects", + "homepage": "https://github.com/sebastianbergmann/object-enumerator/", + "time": 
"2017-08-03T12:35:26+00:00" + }, + { + "name": "sebastian/object-reflector", + "version": "1.1.1", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/object-reflector.git", + "reference": "773f97c67f28de00d397be301821b06708fca0be" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/object-reflector/zipball/773f97c67f28de00d397be301821b06708fca0be", + "reference": "773f97c67f28de00d397be301821b06708fca0be", + "shasum": "" + }, + "require": { + "php": "^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.1-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Allows reflection of object attributes, including inherited and non-public ones", + "homepage": "https://github.com/sebastianbergmann/object-reflector/", + "time": "2017-03-29T09:07:27+00:00" + }, + { + "name": "sebastian/recursion-context", + "version": "3.0.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/recursion-context.git", + "reference": "5b0cd723502bac3b006cbf3dbf7a1e3fcefe4fa8" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/recursion-context/zipball/5b0cd723502bac3b006cbf3dbf7a1e3fcefe4fa8", + "reference": "5b0cd723502bac3b006cbf3dbf7a1e3fcefe4fa8", + "shasum": "" + }, + "require": { + "php": "^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "3.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Jeff Welch", + "email": "whatthejeff@gmail.com" + }, 
+ { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + }, + { + "name": "Adam Harvey", + "email": "aharvey@php.net" + } + ], + "description": "Provides functionality to recursively process PHP variables", + "homepage": "http://www.github.com/sebastianbergmann/recursion-context", + "time": "2017-03-03T06:23:57+00:00" + }, + { + "name": "sebastian/resource-operations", + "version": "1.0.0", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/resource-operations.git", + "reference": "ce990bb21759f94aeafd30209e8cfcdfa8bc3f52" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/resource-operations/zipball/ce990bb21759f94aeafd30209e8cfcdfa8bc3f52", + "reference": "ce990bb21759f94aeafd30209e8cfcdfa8bc3f52", + "shasum": "" + }, + "require": { + "php": ">=5.6.0" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "1.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de" + } + ], + "description": "Provides a list of PHP built-in functions that operate on resources", + "homepage": "https://www.github.com/sebastianbergmann/resource-operations", + "time": "2015-07-28T20:34:47+00:00" + }, + { + "name": "sebastian/version", + "version": "2.0.1", + "source": { + "type": "git", + "url": "https://github.com/sebastianbergmann/version.git", + "reference": "99732be0ddb3361e16ad77b68ba41efc8e979019" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/sebastianbergmann/version/zipball/99732be0ddb3361e16ad77b68ba41efc8e979019", + "reference": "99732be0ddb3361e16ad77b68ba41efc8e979019", + "shasum": "" + }, + "require": { + "php": ">=5.6" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": "2.0.x-dev" + } + }, + "autoload": { + "classmap": [ + "src/" + ] 
+ }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Sebastian Bergmann", + "email": "sebastian@phpunit.de", + "role": "lead" + } + ], + "description": "Library that helps with managing the version number of Git-hosted PHP projects", + "homepage": "https://github.com/sebastianbergmann/version", + "time": "2016-10-03T07:35:21+00:00" + }, + { + "name": "theseer/tokenizer", + "version": "1.1.0", + "source": { + "type": "git", + "url": "https://github.com/theseer/tokenizer.git", + "reference": "cb2f008f3f05af2893a87208fe6a6c4985483f8b" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/theseer/tokenizer/zipball/cb2f008f3f05af2893a87208fe6a6c4985483f8b", + "reference": "cb2f008f3f05af2893a87208fe6a6c4985483f8b", + "shasum": "" + }, + "require": { + "ext-dom": "*", + "ext-tokenizer": "*", + "ext-xmlwriter": "*", + "php": "^7.0" + }, + "type": "library", + "autoload": { + "classmap": [ + "src/" + ] + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "BSD-3-Clause" + ], + "authors": [ + { + "name": "Arne Blankerts", + "email": "arne@blankerts.de", + "role": "Developer" + } + ], + "description": "A small library for converting tokenized PHP source code into XML and potentially other formats", + "time": "2017-04-07T12:08:54+00:00" + }, + { + "name": "webmozart/assert", + "version": "1.3.0", + "source": { + "type": "git", + "url": "https://github.com/webmozart/assert.git", + "reference": "0df1908962e7a3071564e857d86874dad1ef204a" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/webmozart/assert/zipball/0df1908962e7a3071564e857d86874dad1ef204a", + "reference": "0df1908962e7a3071564e857d86874dad1ef204a", + "shasum": "" + }, + "require": { + "php": "^5.3.3 || ^7.0" + }, + "require-dev": { + "phpunit/phpunit": "^4.6", + "sebastian/version": "^1.0.1" + }, + "type": "library", + "extra": { + "branch-alias": { + "dev-master": 
"1.3-dev" + } + }, + "autoload": { + "psr-4": { + "Webmozart\\Assert\\": "src/" + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "Bernhard Schussek", + "email": "bschussek@gmail.com" + } + ], + "description": "Assertions to validate method input/output with nice error messages.", + "keywords": [ + "assert", + "check", + "validate" + ], + "time": "2018-01-29T19:49:41+00:00" + } + ], + "aliases": [], + "minimum-stability": "stable", + "stability-flags": [], + "prefer-stable": false, + "prefer-lowest": false, + "platform": [], + "platform-dev": [] +} diff --git a/_integration/php/src/.gitkeep b/_integration/php/src/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/_integration/php/tests/MySQLTest.php b/_integration/php/tests/MySQLTest.php new file mode 100644 index 000000000..dd20b0922 --- /dev/null +++ b/_integration/php/tests/MySQLTest.php @@ -0,0 +1,28 @@ +setAttribute(PDO::ATTR_ERRMODE, PDO::ERRMODE_EXCEPTION); + + $stmt = $conn->query('SELECT name, email FROM mytable ORDER BY name, email'); + $result = $stmt->fetchAll(PDO::FETCH_ASSOC); + + $expected = [ + ["name" => "Evil Bob", "email" => "evilbob@gmail.com"], + ["name" => "Jane Doe", "email" => "jane@doe.com"], + ["name" => "John Doe", "email" => "john@doe.com"], + ["name" => "John Doe", "email" => "johnalt@doe.com"], + ]; + + $this->assertEquals($expected, $result); + } catch (\PDOException $e) { + $this->assertFalse(true, $e->getMessage()); + } + } +} diff --git a/_integration/python-mysql/.gitignore b/_integration/python-mysql/.gitignore new file mode 100644 index 000000000..7e99e367f --- /dev/null +++ b/_integration/python-mysql/.gitignore @@ -0,0 +1 @@ +*.pyc \ No newline at end of file diff --git a/_integration/python-mysql/Makefile b/_integration/python-mysql/Makefile new file mode 100644 index 000000000..756633e99 --- /dev/null +++ b/_integration/python-mysql/Makefile @@ -0,0 +1,7 @@ +dependencies: + python -m 
pip install -r requirements.txt + +test: dependencies + python -m unittest discover + +.PHONY: dependencies test \ No newline at end of file diff --git a/_integration/python-mysql/requirements.txt b/_integration/python-mysql/requirements.txt new file mode 100644 index 000000000..a6dc36278 --- /dev/null +++ b/_integration/python-mysql/requirements.txt @@ -0,0 +1 @@ +mysql-connector \ No newline at end of file diff --git a/_integration/python-mysql/test.py b/_integration/python-mysql/test.py new file mode 100644 index 000000000..56a09e136 --- /dev/null +++ b/_integration/python-mysql/test.py @@ -0,0 +1,30 @@ +import unittest +import mysql.connector + +class TestMySQL(unittest.TestCase): + + def test_connect(self): + connection = mysql.connector.connect(host='127.0.0.1', + user='root', + passwd='') + + try: + cursor = connection.cursor() + sql = "SELECT name, email FROM mytable ORDER BY name, email" + cursor.execute(sql) + rows = cursor.fetchall() + + expected = [ + ("Evil Bob", "evilbob@gmail.com"), + ("Jane Doe", "jane@doe.com"), + ("John Doe", "john@doe.com"), + ("John Doe", "johnalt@doe.com") + ] + + self.assertEqual(expected, rows) + finally: + connection.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/_integration/python-pymysql/.gitignore b/_integration/python-pymysql/.gitignore new file mode 100644 index 000000000..7e99e367f --- /dev/null +++ b/_integration/python-pymysql/.gitignore @@ -0,0 +1 @@ +*.pyc \ No newline at end of file diff --git a/_integration/python-pymysql/Makefile b/_integration/python-pymysql/Makefile new file mode 100644 index 000000000..756633e99 --- /dev/null +++ b/_integration/python-pymysql/Makefile @@ -0,0 +1,7 @@ +dependencies: + python -m pip install -r requirements.txt + +test: dependencies + python -m unittest discover + +.PHONY: dependencies test \ No newline at end of file diff --git a/_integration/python-pymysql/requirements.txt b/_integration/python-pymysql/requirements.txt new file mode 100644 index 
000000000..1b198419d --- /dev/null +++ b/_integration/python-pymysql/requirements.txt @@ -0,0 +1 @@ +PyMySQL \ No newline at end of file diff --git a/_integration/python-pymysql/test.py b/_integration/python-pymysql/test.py new file mode 100644 index 000000000..ee2090811 --- /dev/null +++ b/_integration/python-pymysql/test.py @@ -0,0 +1,32 @@ +import unittest +import pymysql.cursors + +class TestMySQL(unittest.TestCase): + + def test_connect(self): + connection = pymysql.connect(host='127.0.0.1', + user='root', + password='', + db='', + cursorclass=pymysql.cursors.DictCursor) + + try: + with connection.cursor() as cursor: + sql = "SELECT name, email FROM mytable ORDER BY name, email" + cursor.execute(sql) + rows = cursor.fetchall() + + expected = [ + {"name": "Evil Bob", "email": "evilbob@gmail.com"}, + {"name": "Jane Doe", "email": "jane@doe.com"}, + {"name": "John Doe", "email": "john@doe.com"}, + {"name": "John Doe", "email": "johnalt@doe.com"} + ] + + self.assertEqual(expected, rows) + finally: + connection.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/_integration/python-sqlalchemy/Makefile b/_integration/python-sqlalchemy/Makefile new file mode 100644 index 000000000..756633e99 --- /dev/null +++ b/_integration/python-sqlalchemy/Makefile @@ -0,0 +1,7 @@ +dependencies: + python -m pip install -r requirements.txt + +test: dependencies + python -m unittest discover + +.PHONY: dependencies test \ No newline at end of file diff --git a/_integration/python-sqlalchemy/requirements.txt b/_integration/python-sqlalchemy/requirements.txt new file mode 100644 index 000000000..4a2392f06 --- /dev/null +++ b/_integration/python-sqlalchemy/requirements.txt @@ -0,0 +1,3 @@ +pandas +sqlalchemy +mysqlclient diff --git a/_integration/python-sqlalchemy/test.py b/_integration/python-sqlalchemy/test.py new file mode 100644 index 000000000..e1baa052d --- /dev/null +++ b/_integration/python-sqlalchemy/test.py @@ -0,0 +1,23 @@ +import unittest +import pandas as 
pd +import sqlalchemy + + +class TestMySQL(unittest.TestCase): + + def test_connect(self): + engine = sqlalchemy.create_engine('mysql+mysqldb://root:@127.0.0.1:3306/mydb') + with engine.connect() as conn: + expected = { + "name": {0: 'John Doe', 1: 'John Doe', 2: 'Jane Doe', 3: 'Evil Bob'}, + "email": {0: 'john@doe.com', 1: 'johnalt@doe.com', 2: 'jane@doe.com', 3: 'evilbob@gmail.com'}, + "phone_numbers": {0: ['555-555-555'], 1: [], 2: [], 3: ['555-666-555', '666-666-666']}, + } + repo_df = pd.read_sql_table("mytable", con=conn) + d = repo_df.to_dict() + del d["created_at"] + self.assertEqual(expected, d) + + +if __name__ == '__main__': + unittest.main() diff --git a/_integration/ruby/.gitignore b/_integration/ruby/.gitignore new file mode 100644 index 000000000..06de90aa1 --- /dev/null +++ b/_integration/ruby/.gitignore @@ -0,0 +1 @@ +.bundle \ No newline at end of file diff --git a/_integration/ruby/Gemfile b/_integration/ruby/Gemfile new file mode 100644 index 000000000..50f9a64eb --- /dev/null +++ b/_integration/ruby/Gemfile @@ -0,0 +1,3 @@ +source 'https://rubygems.org' +gem 'ruby-mysql', '~> 2.9', '>= 2.9.14' +gem 'minitest', '~> 5.11', '>= 5.11.3' \ No newline at end of file diff --git a/_integration/ruby/Gemfile.lock b/_integration/ruby/Gemfile.lock new file mode 100644 index 000000000..6939f04b3 --- /dev/null +++ b/_integration/ruby/Gemfile.lock @@ -0,0 +1,17 @@ +GEM + remote: https://rubygems.org/ + specs: + minitest (5.11.3) + rake (12.3.1) + ruby-mysql (2.9.14) + +PLATFORMS + ruby + +DEPENDENCIES + minitest (~> 5.11, >= 5.11.3) + rake + ruby-mysql (~> 2.9, >= 2.9.14) + +BUNDLED WITH + 1.16.5 diff --git a/_integration/ruby/Makefile b/_integration/ruby/Makefile new file mode 100644 index 000000000..5a52ad20b --- /dev/null +++ b/_integration/ruby/Makefile @@ -0,0 +1,10 @@ +vendor/bundle: + gem install bundler --version=1.16.5 + bundler install --path vendor/bundle + +dependencies: vendor/bundle + +test: dependencies + bundler exec ruby mysql_test.rb + 
+.PHONY: test diff --git a/_integration/ruby/mysql_test.rb b/_integration/ruby/mysql_test.rb new file mode 100644 index 000000000..191b7de5d --- /dev/null +++ b/_integration/ruby/mysql_test.rb @@ -0,0 +1,21 @@ +require "minitest/autorun" +require "mysql" + +class TestMySQL < Minitest::Test + def test_can_connect + conn = Mysql::new("127.0.0.1", "root", "") + res = conn.query "SELECT name, email FROM mytable ORDER BY name, email" + + expected = [ + ["Evil Bob", "evilbob@gmail.com"], + ["Jane Doe", "jane@doe.com"], + ["John Doe", "john@doe.com"], + ["John Doe", "johnalt@doe.com"] + ] + + rows = res.map do |row| [row[0], row[1]] end + assert_equal rows, expected + + conn.close() + end +end diff --git a/_integration/run b/_integration/run new file mode 100755 index 000000000..e6e419f05 --- /dev/null +++ b/_integration/run @@ -0,0 +1,10 @@ +#!/bin/sh + +go build -o test-server ./_example/main.go +./test-server > /dev/null 2>&1 & +SERVER_PID=$! +sleep 5 +cd "./_integration/$1" \ + && make test \ + && kill -9 $SERVER_PID \ + && rm -rf ../../test-server \ No newline at end of file diff --git a/_scripts/go-vitess/Makefile b/_scripts/go-vitess/Makefile deleted file mode 100644 index 23f9038aa..000000000 --- a/_scripts/go-vitess/Makefile +++ /dev/null @@ -1,56 +0,0 @@ -# Tooling to create the package `gopkg.in/src-d/go-vitess.v0`. 
- -# config -PACKAGE := gopkg.in/src-d/go-vitess.v0 -REMOTE := git@github.com:src-d/go-vitess.git -VITESS_GIT := https://github.com/youtube/vitess -VITESS_PKG := github.com/youtube/vitess/go/mysql -DEPENDENCIES := \ - github.com/novalagung/gorep \ - github.com/youtube/vitess - -VITESS_SRC := ${GOPATH}/src/${PACKAGE} -PACKAGES := $(VITESS_PKG) $(shell go list -f '{{ join .Deps "\n" }}' ${VITESS_PKG} | grep -i vitess) -FOLDERS := $(shell echo ${PACKAGES} | sed -e 's/github.com\/youtube\/vitess\///g') -GIT_COMMIT := $(shell cd ${VITESS_SRC} && git show-ref refs/original/refs/heads/master --hash) -ETC_PATH := $(PWD)/etc - -all: prepare-package -prepare-package: | filter-branch rename-packages replace-glog prepare-git commit - -$(VITESS_SRC): - git clone --single-branch --no-tags ${VITESS_GIT} $@ - -filter-branch: $(VITESS_SRC) - cd ${VITESS_SRC} && \ - git filter-branch --index-filter ' \ - git rm --cached -qr --ignore-unmatch -- . && \ - git reset -q $$GIT_COMMIT -- ${FOLDERS} \ - ' \ - --prune-empty \ - -- --all - -commit: - cd ${VITESS_SRC} && \ - cp -rf ${ETC_PATH}/* . && git add * && \ - git commit -m "update from upstream ${VITESS_GIT}/commit/${GIT_COMMIT}" -a -s - -rename-packages: - cd ${VITESS_SRC} && \ - git mv go/* . && \ - gorep -from=github.com/youtube/vitess/go -to=${PACKAGE} - -replace-glog: - cd ${VITESS_SRC} && \ - gorep -from=github.com/golang/glog -to=github.com/sirupsen/logrus && \ - grep -lr --exclude-dir=".git" -e "Exitf" . 
| xargs sed -i 's/log\.Exitf/log\.Panicf/g' - -prepare-git: - cd ${VITESS_SRC} && \ - git remote rm origin && \ - git remote add origin $(REMOTE) - -clean: - rm -rf ${VITESS_SRC} - -.PHONY: $(PACKAGES) test clean diff --git a/_scripts/go-vitess/etc/README.md b/_scripts/go-vitess/etc/README.md deleted file mode 100644 index 725a52960..000000000 --- a/_scripts/go-vitess/etc/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# go-vitess [![GoDoc](https://godoc.org/gopkg.in/src-d/go-vitess.v0?status.svg)](https://godoc.org/gopkg.in/src-d/go-vitess.v0) - -`go-vitess` is an automatic filter-branch done by an [script](https://github.com/src-d/go-mysql-server/blob/master/_scripts/go-vitess/Makefile), of the great [Vitess](github.com/youtube/vitess) project. - -The goal is keeping the `github.com/youtube/vitess/go/mysql` package and all the dependent packages as a standalone versioned golang library, to be used by other projects. - -It holds all the packages to create your own MySQL server and a full SQL parser. - -## Installation - -```sh -go get -v -u gopkg.in/src-d/go-vitess.v0/... -``` - -## Contributions - -Since the code belongs to the upstream of [Vitess](github.com/youtube/vitess), -the issue neither pull requests aren't accepted to this repository. - -## License - -Apache License 2.0, see [LICENSE.md](LICENSE.md). diff --git a/_scripts/go-vitess/etc/doc.go b/_scripts/go-vitess/etc/doc.go deleted file mode 100644 index 82b11de32..000000000 --- a/_scripts/go-vitess/etc/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package vitess is an automatic filter-branch, of the great Vitess project. -// -// The goal is keeping the `github.com/youtube/vitess/go/mysql` package and all -// the dependent packages as a standalone versioned golang library, to be used -// by other projects. -// -// It holds all the packages to create your own MySQL server and a full SQL -// parser. 
-package vitess // import "gopkg.in/src-d/go-vitess.v0" diff --git a/auth/audit.go b/auth/audit.go new file mode 100644 index 000000000..1bcd9eec5 --- /dev/null +++ b/auth/audit.go @@ -0,0 +1,149 @@ +package auth + +import ( + "net" + "time" + + "github.com/src-d/go-mysql-server/sql" + "vitess.io/vitess/go/mysql" + + "github.com/sirupsen/logrus" +) + +// AuditMethod is called to log the audit trail of actions. +type AuditMethod interface { + // Authentication logs an authentication event. + Authentication(user, address string, err error) + // Authorization logs an authorization event. + Authorization(ctx *sql.Context, p Permission, err error) + // Query logs a query execution. + Query(ctx *sql.Context, d time.Duration, err error) +} + +// MysqlAudit wraps mysql.AuthServer to emit audit trails. +type MysqlAudit struct { + mysql.AuthServer + audit AuditMethod +} + +// ValidateHash sends authentication calls to an AuditMethod. +func (m *MysqlAudit) ValidateHash( + salt []byte, + user string, + resp []byte, + addr net.Addr, +) (mysql.Getter, error) { + getter, err := m.AuthServer.ValidateHash(salt, user, resp, addr) + m.audit.Authentication(user, addr.String(), err) + + return getter, err +} + +// NewAudit creates a wrapped Auth that sends audit trails to the specified +// method. +func NewAudit(auth Auth, method AuditMethod) Auth { + return &Audit{ + auth: auth, + method: method, + } +} + +// Audit is an Auth method proxy that sends audit trails to the specified +// AuditMethod. +type Audit struct { + auth Auth + method AuditMethod +} + +// Mysql implements Auth interface. +func (a *Audit) Mysql() mysql.AuthServer { + return &MysqlAudit{ + AuthServer: a.auth.Mysql(), + audit: a.method, + } +} + +// Allowed implements Auth interface. +func (a *Audit) Allowed(ctx *sql.Context, permission Permission) error { + err := a.auth.Allowed(ctx, permission) + a.method.Authorization(ctx, permission, err) + + return err +} + +// Query implements AuditQuery interface. 
+func (a *Audit) Query(ctx *sql.Context, d time.Duration, err error) { + if q, ok := a.auth.(*Audit); ok { + q.Query(ctx, d, err) + } + + a.method.Query(ctx, d, err) +} + +// NewAuditLog creates a new AuditMethod that logs to a logrus.Logger. +func NewAuditLog(l *logrus.Logger) AuditMethod { + la := l.WithField("system", "audit") + + return &AuditLog{ + log: la, + } +} + +const auditLogMessage = "audit trail" + +// AuditLog logs audit trails to a logrus.Logger. +type AuditLog struct { + log *logrus.Entry +} + +// Authentication implements AuditMethod interface. +func (a *AuditLog) Authentication(user string, address string, err error) { + fields := logrus.Fields{ + "action": "authentication", + "user": user, + "address": address, + "success": true, + } + + if err != nil { + fields["success"] = false + fields["err"] = err + } + + a.log.WithFields(fields).Info(auditLogMessage) +} + +func auditInfo(ctx *sql.Context, err error) logrus.Fields { + fields := logrus.Fields{ + "user": ctx.Client().User, + "query": ctx.Query(), + "address": ctx.Client().Address, + "connection_id": ctx.Session.ID(), + "pid": ctx.Pid(), + "success": true, + } + + if err != nil { + fields["success"] = false + fields["err"] = err + } + + return fields +} + +// Authorization implements AuditMethod interface. 
+func (a *AuditLog) Authorization(ctx *sql.Context, p Permission, err error) { + fields := auditInfo(ctx, err) + fields["action"] = "authorization" + fields["permission"] = p.String() + + a.log.WithFields(fields).Info(auditLogMessage) +} + +func (a *AuditLog) Query(ctx *sql.Context, d time.Duration, err error) { + fields := auditInfo(ctx, err) + fields["action"] = "query" + fields["duration"] = d + + a.log.WithFields(fields).Info(auditLogMessage) +} diff --git a/auth/audit_test.go b/auth/audit_test.go new file mode 100644 index 000000000..82e777be2 --- /dev/null +++ b/auth/audit_test.go @@ -0,0 +1,232 @@ +// +build !windows + +package auth_test + +import ( + "context" + "testing" + "time" + + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/sql" + + "github.com/sanity-io/litter" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +type Authentication struct { + user string + address string + err error +} + +type Authorization struct { + ctx *sql.Context + p auth.Permission + err error +} + +type Query struct { + ctx *sql.Context + d time.Duration + err error +} + +type auditTest struct { + authentication Authentication + authorization Authorization + query Query +} + +func (a *auditTest) Authentication(user string, address string, err error) { + a.authentication = Authentication{ + user: user, + address: address, + err: err, + } +} + +func (a *auditTest) Authorization(ctx *sql.Context, p auth.Permission, err error) { + a.authorization = Authorization{ + ctx: ctx, + p: p, + err: err, + } +} + +func (a *auditTest) Query(ctx *sql.Context, d time.Duration, err error) { + println("query!") + a.query = Query{ + ctx: ctx, + d: d, + err: err, + } +} + +func (a *auditTest) Clean() { + a.authorization = Authorization{} + a.authentication = Authentication{} + a.query = Query{} +} + +func TestAuditAuthentication(t *testing.T) { + a := auth.NewNativeSingle("user", "password", 
auth.AllPermissions) + at := new(auditTest) + audit := auth.NewAudit(a, at) + + extra := func(t *testing.T, c authenticationTest) { + a := at.authentication + + require.Equal(t, c.user, a.user) + require.NotEmpty(t, a.address) + if c.success { + require.NoError(t, a.err) + } else { + require.Error(t, a.err) + require.Nil(t, at.authorization.ctx) + require.Nil(t, at.query.ctx) + } + + at.Clean() + } + + testAuthentication(t, audit, nativeSingleTests, extra) +} + +func TestAuditAuthorization(t *testing.T) { + a := auth.NewNativeSingle("user", "", auth.ReadPerm) + at := new(auditTest) + audit := auth.NewAudit(a, at) + + tests := []authorizationTest{ + {"user", "invalid query", false}, + {"user", queries["select"], true}, + + {"user", queries["create_index"], false}, + {"user", queries["drop_index"], false}, + {"user", queries["insert"], false}, + {"user", queries["lock"], false}, + {"user", queries["unlock"], false}, + } + + extra := func(t *testing.T, c authorizationTest) { + a := at.authorization + q := at.query + + litter.Dump(q) + require.NotNil(t, q.ctx) + require.Equal(t, c.user, q.ctx.Client().User) + require.NotEmpty(t, q.ctx.Client().Address) + require.NotZero(t, q.d) + require.Equal(t, c.user, at.authentication.user) + + if c.success { + require.Equal(t, c.user, a.ctx.Client().User) + require.NotEmpty(t, a.ctx.Client().Address) + require.NoError(t, a.err) + require.NoError(t, q.err) + } else { + require.Error(t, q.err) + + // if there's a syntax error authorization is not triggered + if auth.ErrNotAuthorized.Is(q.err) { + require.Equal(t, q.err, a.err) + require.NotNil(t, a.ctx) + require.Equal(t, c.user, a.ctx.Client().User) + require.NotEmpty(t, a.ctx.Client().Address) + } else { + require.NoError(t, a.err) + require.Nil(t, a.ctx) + } + } + + at.Clean() + } + + testAudit(t, audit, tests, extra) +} + +func TestAuditLog(t *testing.T) { + require := require.New(t) + + logger, hook := test.NewNullLogger() + l := auth.NewAuditLog(logger) + + pid := uint64(303) 
+ id := uint32(42) + + l.Authentication("user", "client", nil) + e := hook.LastEntry() + require.NotNil(e) + require.Equal(logrus.InfoLevel, e.Level) + m := logrus.Fields{ + "system": "audit", + "action": "authentication", + "user": "user", + "address": "client", + "success": true, + } + require.Equal(m, e.Data) + + err := auth.ErrNoPermission.New(auth.ReadPerm) + l.Authentication("user", "client", err) + e = hook.LastEntry() + m["success"] = false + m["err"] = err + require.Equal(m, e.Data) + + s := sql.NewSession("server", "client", "user", id) + ctx := sql.NewContext(context.TODO(), + sql.WithSession(s), + sql.WithPid(pid), + sql.WithQuery("query"), + ) + + l.Authorization(ctx, auth.ReadPerm, nil) + e = hook.LastEntry() + require.NotNil(e) + require.Equal(logrus.InfoLevel, e.Level) + m = logrus.Fields{ + "system": "audit", + "action": "authorization", + "permission": auth.ReadPerm.String(), + "user": "user", + "query": "query", + "address": "client", + "connection_id": id, + "pid": pid, + "success": true, + } + require.Equal(m, e.Data) + + l.Authorization(ctx, auth.ReadPerm, err) + e = hook.LastEntry() + m["success"] = false + m["err"] = err + require.Equal(m, e.Data) + + l.Query(ctx, 808*time.Second, nil) + e = hook.LastEntry() + require.NotNil(e) + require.Equal(logrus.InfoLevel, e.Level) + m = logrus.Fields{ + "system": "audit", + "action": "query", + "duration": 808 * time.Second, + "user": "user", + "query": "query", + "address": "client", + "connection_id": id, + "pid": pid, + "success": true, + } + require.Equal(m, e.Data) + + l.Query(ctx, 808*time.Second, err) + e = hook.LastEntry() + m["success"] = false + m["err"] = err + require.Equal(m, e.Data) +} diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 000000000..d2ea68d5e --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,62 @@ +package auth + +import ( + "strings" + + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" + "vitess.io/vitess/go/mysql" +) + +// Permission holds 
permissions required by a query or granted to a user. +type Permission int + +const ( + // ReadPerm means that it reads. + ReadPerm Permission = 1 << iota + // WritePerm means that it writes. + WritePerm +) + +var ( + // AllPermissions holds all defined permissions. + AllPermissions = ReadPerm | WritePerm + // DefaultPermissions are the permissions granted to a user if not defined. + DefaultPermissions = ReadPerm + + // PermissionNames is used to translate from human to machine + // representations. + PermissionNames = map[string]Permission{ + "read": ReadPerm, + "write": WritePerm, + } + + // ErrNotAuthorized is returned when the user is not allowed to use a + // permission. + ErrNotAuthorized = errors.NewKind("not authorized") + // ErrNoPermission is returned when the user lacks needed permissions. + ErrNoPermission = errors.NewKind("user does not have permission: %s") +) + +// String returns all the permissions set to on. +func (p Permission) String() string { + var str []string + for k, v := range PermissionNames { + if p&v != 0 { + str = append(str, k) + } + } + + return strings.Join(str, ", ") +} + +// Auth interface provides mysql authentication methods and permission checking +// for users. +type Auth interface { + // Mysql returns a configured authentication method used by server.Server. + Mysql() mysql.AuthServer + // Allowed checks user's permissions with needed permission. If the user + // does not have enough permissions it returns ErrNotAuthorized. + // Otherwise, the error comes from the authentication method. 
+ Allowed(ctx *sql.Context, permission Permission) error +} diff --git a/auth/common_test.go b/auth/common_test.go new file mode 100644 index 000000000..492b6881b --- /dev/null +++ b/auth/common_test.go @@ -0,0 +1,226 @@ +// +build !windows + +package auth_test + +import ( + "context" + dsql "database/sql" + "fmt" + "io/ioutil" + "os" + "testing" + + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/server" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/analyzer" + "github.com/src-d/go-mysql-server/sql/index/pilosa" + "github.com/stretchr/testify/require" +) + +const port = 3336 + +func authEngine(au auth.Auth) (string, *sqle.Engine, error) { + db := memory.NewDatabase("test") + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + tblName := "test" + + table := memory.NewTable(tblName, sql.Schema{ + {Name: "id", Type: sql.Text, Nullable: false, Source: tblName}, + {Name: "name", Type: sql.Text, Nullable: false, Source: tblName}, + }) + + db.AddTable(tblName, table) + + tmpDir, err := ioutil.TempDir(os.TempDir(), "pilosa-test") + if err != nil { + return "", nil, err + } + + err = os.MkdirAll(tmpDir, 0644) + if err != nil { + return "", nil, err + } + + catalog.RegisterIndexDriver(pilosa.NewDriver(tmpDir)) + + a := analyzer.NewBuilder(catalog).Build() + config := &sqle.Config{Auth: au} + + return tmpDir, sqle.New(catalog, a, config), nil +} + +func authServer(a auth.Auth) (string, *server.Server, error) { + tmpDir, engine, err := authEngine(a) + if err != nil { + os.RemoveAll(tmpDir) + return "", nil, err + } + + config := server.Config{ + Protocol: "tcp", + Address: fmt.Sprintf("localhost:%d", port), + Auth: a, + } + + s, err := server.NewDefaultServer(config, engine) + if err != nil { + os.RemoveAll(tmpDir) + return "", nil, err + } + + go s.Start() + + return tmpDir, s, nil +} + +func connString(user, password 
string) string { + return fmt.Sprintf("%s:%s@tcp(127.0.0.1:%d)/test", user, password, port) +} + +type authenticationTest struct { + user string + password string + success bool +} + +func testAuthentication( + t *testing.T, + a auth.Auth, + tests []authenticationTest, + extra func(t *testing.T, c authenticationTest), +) { + t.Helper() + req := require.New(t) + + tmpDir, s, err := authServer(a) + req.NoError(err) + defer os.RemoveAll(tmpDir) + + for _, c := range tests { + t.Run(fmt.Sprintf("%s-%s", c.user, c.password), func(t *testing.T) { + r := require.New(t) + + var db *dsql.DB + db, err = dsql.Open("mysql", connString(c.user, c.password)) + r.NoError(err) + _, err = db.Query("SELECT 1") + + if c.success { + r.NoError(err) + } else { + r.Error(err) + r.Contains(err.Error(), "Access denied") + } + + err = db.Close() + r.NoError(err) + + if extra != nil { + extra(t, c) + } + }) + } + + err = s.Close() + req.NoError(err) +} + +var queries = map[string]string{ + "select": "select * from test", + "create_index": "create index t on test using pilosa (name) with (async = false)", + "drop_index": "drop index t on test", + "insert": "insert into test (id, name) values ('id', 'name')", + "lock": "lock tables test read", + "unlock": "unlock tables", +} + +type authorizationTest struct { + user string + query string + success bool +} + +func testAuthorization( + t *testing.T, + a auth.Auth, + tests []authorizationTest, + extra func(t *testing.T, c authorizationTest), +) { + t.Helper() + req := require.New(t) + + tmpDir, e, err := authEngine(a) + req.NoError(err) + defer os.RemoveAll(tmpDir) + + for i, c := range tests { + t.Run(fmt.Sprintf("%s-%s", c.user, c.query), func(t *testing.T) { + req := require.New(t) + + session := sql.NewSession("localhost", "client", c.user, uint32(i)) + ctx := sql.NewContext(context.TODO(), + sql.WithSession(session), + sql.WithPid(uint64(i))) + + _, _, err := e.Query(ctx, c.query) + + if c.success { + req.NoError(err) + return + } + + 
req.Error(err) + if extra != nil { + extra(t, c) + } else { + req.True(auth.ErrNotAuthorized.Is(err)) + } + }) + } +} + +func testAudit( + t *testing.T, + a auth.Auth, + tests []authorizationTest, + extra func(t *testing.T, c authorizationTest), +) { + t.Helper() + req := require.New(t) + + tmpDir, s, err := authServer(a) + req.NoError(err) + defer os.RemoveAll(tmpDir) + + for _, c := range tests { + t.Run(c.user, func(t *testing.T) { + r := require.New(t) + + var db *dsql.DB + db, err = dsql.Open("mysql", connString(c.user, "")) + r.NoError(err) + _, err = db.Query(c.query) + + if c.success { + r.NoError(err) + } else { + r.Error(err) + } + + err = db.Close() + r.NoError(err) + + if extra != nil { + extra(t, c) + } + }) + } + + err = s.Close() + req.NoError(err) +} diff --git a/auth/native.go b/auth/native.go new file mode 100644 index 000000000..6d0744e12 --- /dev/null +++ b/auth/native.go @@ -0,0 +1,155 @@ +package auth + +import ( + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "regexp" + "strings" + + "github.com/src-d/go-mysql-server/sql" + + "gopkg.in/src-d/go-errors.v1" + "vitess.io/vitess/go/mysql" +) + +var ( + regNative = regexp.MustCompile(`^\*[0-9A-F]{40}$`) + + // ErrParseUserFile is given when user file is malformed. + ErrParseUserFile = errors.NewKind("error parsing user file") + // ErrUnknownPermission happens when a user permission is not defined. + ErrUnknownPermission = errors.NewKind("unknown permission, %s") + // ErrDuplicateUser happens when a user appears more than once. + ErrDuplicateUser = errors.NewKind("duplicate user, %s") +) + +// nativeUser holds information about credentials and permissions for a user. +type nativeUser struct { + Name string + Password string + JSONPermissions []string `json:"Permissions"` + Permissions Permission +} + +// Allowed checks if the user has certain permission. 
+func (u nativeUser) Allowed(p Permission) error { + if u.Permissions&p == p { + return nil + } + + // permissions needed but not granted to the user + p2 := (^u.Permissions) & p + + return ErrNotAuthorized.Wrap(ErrNoPermission.New(p2)) +} + +// NativePassword generates a mysql_native_password string. +func NativePassword(password string) string { + if len(password) == 0 { + return "" + } + + // native = sha1(sha1(password)) + + hash := sha1.New() + hash.Write([]byte(password)) + s1 := hash.Sum(nil) + + hash.Reset() + hash.Write(s1) + s2 := hash.Sum(nil) + + s := strings.ToUpper(hex.EncodeToString(s2)) + + return fmt.Sprintf("*%s", s) +} + +// Native holds mysql_native_password users. +type Native struct { + users map[string]nativeUser +} + +// NewNativeSingle creates a NativeAuth with a single user with given +// permissions. +func NewNativeSingle(name, password string, perm Permission) *Native { + users := make(map[string]nativeUser) + users[name] = nativeUser{ + Name: name, + Password: NativePassword(password), + Permissions: perm, + } + + return &Native{users} +} + +// NewNativeFile creates a NativeAuth and loads users from a JSON file. 
+func NewNativeFile(file string) (*Native, error) { + var data []nativeUser + + raw, err := ioutil.ReadFile(file) + if err != nil { + return nil, ErrParseUserFile.New(err) + } + + if err := json.Unmarshal(raw, &data); err != nil { + return nil, ErrParseUserFile.New(err) + } + + users := make(map[string]nativeUser) + for _, u := range data { + _, ok := users[u.Name] + if ok { + return nil, ErrParseUserFile.Wrap(ErrDuplicateUser.New(u.Name)) + } + + if !regNative.MatchString(u.Password) { + u.Password = NativePassword(u.Password) + } + + if len(u.JSONPermissions) == 0 { + u.Permissions = DefaultPermissions + } + + for _, p := range u.JSONPermissions { + perm, ok := PermissionNames[strings.ToLower(p)] + if !ok { + return nil, ErrParseUserFile.Wrap(ErrUnknownPermission.New(p)) + } + + u.Permissions |= perm + } + + users[u.Name] = u + } + + return &Native{users}, nil +} + +// Mysql implements Auth interface. +func (s *Native) Mysql() mysql.AuthServer { + auth := mysql.NewAuthServerStatic() + + for k, v := range s.users { + auth.Entries[k] = []*mysql.AuthServerStaticEntry{ + { + MysqlNativePassword: v.Password, + Password: v.Password}, + } + } + + return auth +} + +// Allowed implements Auth interface. 
+func (s *Native) Allowed(ctx *sql.Context, permission Permission) error { + name := ctx.Client().User + u, ok := s.users[name] + if !ok { + return ErrNotAuthorized.Wrap(ErrNoPermission.New(permission)) + } + + return u.Allowed(permission) +} diff --git a/auth/native_test.go b/auth/native_test.go new file mode 100644 index 000000000..10cb43255 --- /dev/null +++ b/auth/native_test.go @@ -0,0 +1,252 @@ +// +build !windows + +package auth_test + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/src-d/go-mysql-server/auth" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +const ( + baseConfig = ` +[ + { + "name": "root", + "password": "*9E128DA0C64A6FCCCDCFBDD0FC0A2C967C6DB36F", + "permissions": ["read", "write"] + }, + { + "name": "user", + "password": "password", + "permissions": ["read"] + }, + { + "name": "no_password" + }, + { + "name": "empty_password", + "password": "" + }, + { + "name": "no_permissions", + "permissions": [] + } +]` + duplicateUser = ` +[ + { "name": "user" }, + { "name": "user" } +]` + badPermission = ` +[ + { "permissions": ["read", "write", "admin"] } +]` + badJSON = "I,am{not}JSON" +) + +func writeConfig(config string) (string, error) { + tmp, err := ioutil.TempFile("", "native-config") + if err != nil { + return "", err + } + + _, err = tmp.WriteString(config) + if err != nil { + os.Remove(tmp.Name()) + return "", err + } + + return tmp.Name(), nil +} + +var nativeSingleTests = []authenticationTest{ + {"root", "", false}, + {"root", "password", false}, + {"root", "mysql_password", false}, + {"user", "password", true}, + {"user", "other_password", false}, + {"user", "", false}, + {"", "", false}, + {"", "password", false}, +} + +func TestNativeAuthenticationSingle(t *testing.T) { + a := auth.NewNativeSingle("user", "password", auth.AllPermissions) + testAuthentication(t, a, nativeSingleTests, nil) +} + +func TestNativeAuthentication(t *testing.T) { + req := 
require.New(t) + + conf, err := writeConfig(baseConfig) + req.NoError(err) + defer os.Remove(conf) + + a, err := auth.NewNativeFile(conf) + req.NoError(err) + + tests := []authenticationTest{ + {"root", "", false}, + {"root", "password", false}, + {"root", "mysql_password", true}, + {"user", "password", true}, + {"user", "other_password", false}, + {"user", "", false}, + {"no_password", "", true}, + {"no_password", "password", false}, + {"empty_password", "", true}, + {"empty_password", "password", false}, + {"nonexistent", "", false}, + {"nonexistent", "password", false}, + } + + testAuthentication(t, a, tests, nil) +} + +func TestNativeAuthorizationSingleAll(t *testing.T) { + a := auth.NewNativeSingle("user", "password", auth.AllPermissions) + + tests := []authorizationTest{ + {"user", queries["select"], true}, + {"root", queries["select"], false}, + {"", queries["select"], false}, + + {"user", queries["create_index"], true}, + {"root", queries["create_index"], false}, + {"", queries["create_index"], false}, + + {"user", queries["drop_index"], true}, + {"root", queries["drop_index"], false}, + {"", queries["drop_index"], false}, + + {"user", queries["insert"], true}, + {"root", queries["insert"], false}, + {"", queries["insert"], false}, + + {"user", queries["lock"], true}, + {"root", queries["lock"], false}, + {"", queries["lock"], false}, + + {"user", queries["unlock"], true}, + {"root", queries["unlock"], false}, + {"", queries["unlock"], false}, + } + + testAuthorization(t, a, tests, nil) +} + +func TestNativeAuthorizationSingleRead(t *testing.T) { + a := auth.NewNativeSingle("user", "password", auth.ReadPerm) + + tests := []authorizationTest{ + {"user", queries["select"], true}, + {"root", queries["select"], false}, + {"", queries["select"], false}, + + {"user", queries["create_index"], false}, + {"root", queries["create_index"], false}, + {"", queries["create_index"], false}, + + {"user", queries["drop_index"], false}, + {"root", queries["drop_index"], 
false}, + {"", queries["drop_index"], false}, + + {"user", queries["insert"], false}, + {"root", queries["insert"], false}, + {"", queries["insert"], false}, + + {"user", queries["lock"], false}, + {"root", queries["lock"], false}, + {"", queries["lock"], false}, + + {"user", queries["unlock"], false}, + {"root", queries["unlock"], false}, + {"", queries["unlock"], false}, + } + + testAuthorization(t, a, tests, nil) +} + +func TestNativeAuthorization(t *testing.T) { + require := require.New(t) + + conf, err := writeConfig(baseConfig) + require.NoError(err) + defer os.Remove(conf) + + a, err := auth.NewNativeFile(conf) + require.NoError(err) + + tests := []authorizationTest{ + {"", queries["select"], false}, + {"user", queries["select"], true}, + {"no_password", queries["select"], true}, + {"no_permissions", queries["select"], true}, + {"root", queries["select"], true}, + + {"", queries["create_index"], false}, + {"user", queries["create_index"], false}, + {"no_password", queries["create_index"], false}, + {"no_permissions", queries["create_index"], false}, + {"root", queries["create_index"], true}, + + {"", queries["drop_index"], false}, + {"user", queries["drop_index"], false}, + {"no_password", queries["drop_index"], false}, + {"no_permissions", queries["drop_index"], false}, + {"root", queries["drop_index"], true}, + + {"", queries["insert"], false}, + {"user", queries["insert"], false}, + {"no_password", queries["insert"], false}, + {"no_permissions", queries["insert"], false}, + {"root", queries["insert"], true}, + + {"", queries["lock"], false}, + {"user", queries["lock"], false}, + {"no_password", queries["lock"], false}, + {"no_permissions", queries["lock"], false}, + {"root", queries["lock"], true}, + + {"", queries["unlock"], false}, + {"user", queries["unlock"], false}, + {"no_password", queries["unlock"], false}, + {"no_permissions", queries["unlock"], false}, + {"root", queries["unlock"], true}, + } + + testAuthorization(t, a, tests, nil) +} + +func 
TestNativeErrors(t *testing.T) { + tests := []struct { + name string + config string + err *errors.Kind + }{ + {"duplicate_user", duplicateUser, auth.ErrDuplicateUser}, + {"bad_permission", badPermission, auth.ErrUnknownPermission}, + {"malformed", badJSON, auth.ErrParseUserFile}, + } + + for _, c := range tests { + t.Run(c.name, func(t *testing.T) { + require := require.New(t) + + conf, err := writeConfig(c.config) + require.NoError(err) + defer os.Remove(conf) + + _, err = auth.NewNativeFile(conf) + require.Error(err) + require.True(c.err.Is(err)) + }) + } +} diff --git a/auth/none.go b/auth/none.go new file mode 100644 index 000000000..8da092936 --- /dev/null +++ b/auth/none.go @@ -0,0 +1,20 @@ +package auth + +import ( + "github.com/src-d/go-mysql-server/sql" + + "vitess.io/vitess/go/mysql" +) + +// None is an Auth method that always succeeds. +type None struct{} + +// Mysql implements Auth interface. +func (n *None) Mysql() mysql.AuthServer { + return new(mysql.AuthServerNone) +} + +// Allowed implements Auth interface. 
+func (n *None) Allowed(ctx *sql.Context, permission Permission) error { + return nil +} diff --git a/auth/none_test.go b/auth/none_test.go new file mode 100644 index 000000000..36d2495dc --- /dev/null +++ b/auth/none_test.go @@ -0,0 +1,48 @@ +// +build !windows + +package auth_test + +import ( + "testing" + + "github.com/src-d/go-mysql-server/auth" +) + +func TestNoneAuthentication(t *testing.T) { + a := new(auth.None) + + tests := []authenticationTest{ + {"root", "", true}, + {"root", "password", true}, + {"root", "mysql_password", true}, + {"user", "password", true}, + {"user", "other_password", true}, + {"user", "", true}, + {"", "", true}, + {"", "password", true}, + } + + testAuthentication(t, a, tests, nil) +} + +func TestNoneAuthorization(t *testing.T) { + a := new(auth.None) + + tests := []authorizationTest{ + {"user", queries["select"], true}, + {"root", queries["select"], true}, + {"", queries["select"], true}, + + {"user", queries["create_index"], true}, + + {"root", queries["drop_index"], true}, + + {"", queries["insert"], true}, + + {"user", queries["lock"], true}, + + {"root", queries["unlock"], true}, + } + + testAuthorization(t, a, tests, nil) +} diff --git a/benchmark/metadata.go b/benchmark/metadata.go index b8867cf8f..16545eff1 100644 --- a/benchmark/metadata.go +++ b/benchmark/metadata.go @@ -1,6 +1,6 @@ package benchmark -import "gopkg.in/src-d/go-mysql-server.v0/sql" +import "github.com/src-d/go-mysql-server/sql" type tableMetadata struct { schema sql.Schema diff --git a/benchmark/tpc_h_test.go b/benchmark/tpc_h_test.go index f18b322db..5524069fa 100644 --- a/benchmark/tpc_h_test.go +++ b/benchmark/tpc_h_test.go @@ -10,10 +10,10 @@ import ( "path/filepath" "testing" - "gopkg.in/src-d/go-mysql-server.v0" + sqle "github.com/src-d/go-mysql-server" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" ) var scriptsPath = 
"../_scripts/tpc-h/" @@ -30,7 +30,7 @@ func BenchmarkTpch(b *testing.B) { b.Fatal(err) } - e := sqle.New() + e := sqle.NewDefault() e.AddDatabase(db) b.ResetTimer() @@ -83,12 +83,12 @@ func executeQueries(b *testing.B, e *sqle.Engine) error { } func genDB(b *testing.B) (sql.Database, error) { - db := mem.NewDatabase("tpch") + db := memory.NewDatabase("tpch") for _, m := range tpchTableMetadata { b.Log("generating table", m.name) - t := mem.NewTable(m.name, m.schema) - if err := insertDataToTable(t, len(m.schema)); err != nil { + t := memory.NewTable(m.name, m.schema) + if err := insertDataToTable(m.name, t, len(m.schema)); err != nil { return nil, err } @@ -98,8 +98,8 @@ func genDB(b *testing.B) (sql.Database, error) { return db, nil } -func insertDataToTable(t *mem.Table, columnCount int) error { - f, err := os.Open(t.Name() + ".tbl") +func insertDataToTable(name string, t *memory.Table, columnCount int) error { + f, err := os.Open(name + ".tbl") if err != nil { return err } @@ -122,7 +122,7 @@ func insertDataToTable(t *mem.Table, columnCount int) error { return err } - if err := t.Insert(row); err != nil { + if err := t.Insert(sql.NewEmptyContext(), row); err != nil { return err } } diff --git a/engine.go b/engine.go index 215e2ec7d..f09c167c5 100644 --- a/engine.go +++ b/engine.go @@ -1,47 +1,151 @@ package sqle import ( + "time" + + "github.com/go-kit/kit/metrics/discard" opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/analyzer" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function" - "gopkg.in/src-d/go-mysql-server.v0/sql/parse" + "github.com/sirupsen/logrus" + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/analyzer" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/parse" + "github.com/src-d/go-mysql-server/sql/plan" ) +// Config for the 
Engine. +type Config struct { + // VersionPostfix to display with the `VERSION()` UDF. + VersionPostfix string + // Auth used for authentication and authorization. + Auth auth.Auth +} + // Engine is a SQL engine. type Engine struct { Catalog *sql.Catalog Analyzer *analyzer.Analyzer + Auth auth.Auth +} + +var ( + // QueryCounter describes a metric that accumulates number of queries monotonically. + QueryCounter = discard.NewCounter() + + // QueryErrorCounter describes a metric that accumulates number of failed queries monotonically. + QueryErrorCounter = discard.NewCounter() + + // QueryHistogram describes a queries latency. + QueryHistogram = discard.NewHistogram() +) + +func observeQuery(ctx *sql.Context, query string) func(err error) { + logrus.WithField("query", query).Debug("executing query") + span, _ := ctx.Span("query", opentracing.Tag{Key: "query", Value: query}) + + t := time.Now() + return func(err error) { + if err != nil { + QueryErrorCounter.With("query", query, "error", err.Error()).Add(1) + } else { + QueryCounter.With("query", query).Add(1) + QueryHistogram.With("query", query, "duration", "seconds").Observe(time.Since(t).Seconds()) + } + + span.Finish() + } +} + +// New creates a new Engine with custom configuration. To create an Engine with +// the default settings use `NewDefault`. +func New(c *sql.Catalog, a *analyzer.Analyzer, cfg *Config) *Engine { + var versionPostfix string + if cfg != nil { + versionPostfix = cfg.VersionPostfix + } + + c.MustRegister( + sql.FunctionN{ + Name: "version", + Fn: function.NewVersion(versionPostfix), + }, + sql.Function0{ + Name: "database", + Fn: function.NewDatabase(c), + }) + c.MustRegister(function.Defaults...) + + // use auth.None if auth is not specified + var au auth.Auth + if cfg == nil || cfg.Auth == nil { + au = new(auth.None) + } else { + au = cfg.Auth + } + + return &Engine{c, a, au} } -// New creates a new Engine. -func New() *Engine { +// NewDefault creates a new default Engine. 
+func NewDefault() *Engine { c := sql.NewCatalog() - c.RegisterFunctions(function.Defaults) + a := analyzer.NewDefault(c) - a := analyzer.New(c) - return &Engine{c, a} + return New(c, a, nil) } -// Query executes a query without attaching to any context. +// Query executes a query. func (e *Engine) Query( ctx *sql.Context, query string, ) (sql.Schema, sql.RowIter, error) { - span, ctx := ctx.Span("query", opentracing.Tag{Key: "query", Value: query}) - defer span.Finish() + var ( + parsed, analyzed sql.Node + iter sql.RowIter + err error + ) + + finish := observeQuery(ctx, query) + defer finish(err) + + parsed, err = parse.Parse(ctx, query) + if err != nil { + return nil, nil, err + } + + var perm = auth.ReadPerm + var typ = sql.QueryProcess + switch parsed.(type) { + case *plan.CreateIndex: + typ = sql.CreateIndexProcess + perm = auth.ReadPerm | auth.WritePerm + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: + perm = auth.ReadPerm | auth.WritePerm + } + + err = e.Auth.Allowed(ctx, perm) + if err != nil { + return nil, nil, err + } + + ctx, err = e.Catalog.AddProcess(ctx, typ, query) + defer func() { + if err != nil && ctx != nil { + e.Catalog.Done(ctx.Pid()) + } + }() - parsed, err := parse.Parse(ctx, query) if err != nil { return nil, nil, err } - analyzed, err := e.Analyzer.Analyze(ctx, parsed) + analyzed, err = e.Analyzer.Analyze(ctx, parsed) if err != nil { return nil, nil, err } - iter, err := analyzed.RowIter(ctx) + iter, err = analyzed.RowIter(ctx) if err != nil { return nil, nil, err } @@ -49,8 +153,24 @@ func (e *Engine) Query( return analyzed.Schema(), iter, nil } +// Async returns true if the query is async. 
If there are any errors with the +// query it returns false +func (e *Engine) Async(ctx *sql.Context, query string) bool { + parsed, err := parse.Parse(ctx, query) + if err != nil { + return false + } + + asyncNode, ok := parsed.(sql.AsyncNode) + return ok && asyncNode.IsAsync() +} + // AddDatabase adds the given database to the catalog. func (e *Engine) AddDatabase(db sql.Database) { - e.Catalog.Databases = append(e.Catalog.Databases, db) - e.Analyzer.CurrentDatabase = db.Name() + e.Catalog.AddDatabase(db) +} + +// Init performs all the initialization requirements for the engine to work. +func (e *Engine) Init() error { + return e.Catalog.LoadIndexes(e.Catalog.AllDatabases()) } diff --git a/engine_pilosa_test.go b/engine_pilosa_test.go new file mode 100644 index 000000000..a6da4c5a9 --- /dev/null +++ b/engine_pilosa_test.go @@ -0,0 +1,209 @@ +// +build !windows + +package sqle_test + +import ( + "context" + "io/ioutil" + "os" + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/index/pilosa" + "github.com/src-d/go-mysql-server/test" + + "github.com/stretchr/testify/require" +) + +func TestIndexes(t *testing.T) { + e := newEngine(t) + + tmpDir, err := ioutil.TempDir(os.TempDir(), "pilosa-test") + require.NoError(t, err) + + require.NoError(t, os.MkdirAll(tmpDir, 0644)) + e.Catalog.RegisterIndexDriver(pilosa.NewDriver(tmpDir)) + + _, _, err = e.Query( + newCtx(), + "CREATE INDEX idx_i ON mytable USING pilosa (i) WITH (async = false)", + ) + require.NoError(t, err) + + _, _, err = e.Query( + newCtx(), + "CREATE INDEX idx_s ON mytable USING pilosa (s) WITH (async = false)", + ) + require.NoError(t, err) + + _, _, err = e.Query( + newCtx(), + "CREATE INDEX idx_is ON mytable USING pilosa (i, s) WITH (async = false)", + ) + require.NoError(t, err) + + defer func() { + done, err := e.Catalog.DeleteIndex("mydb", "idx_i", true) + require.NoError(t, err) + <-done + + done, err = e.Catalog.DeleteIndex("mydb", "idx_s", true) + 
require.NoError(t, err) + <-done + + done, err = e.Catalog.DeleteIndex("foo", "idx_is", true) + require.NoError(t, err) + <-done + }() + + testCases := []struct { + query string + expected []sql.Row + }{ + { + "SELECT * FROM mytable WHERE i = 2", + []sql.Row{ + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i > 1", + []sql.Row{ + {int64(3), "third row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i < 3", + []sql.Row{ + {int64(1), "first row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i <= 2", + []sql.Row{ + {int64(2), "second row"}, + {int64(1), "first row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i >= 2", + []sql.Row{ + {int64(2), "second row"}, + {int64(3), "third row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i = 2 AND s = 'second row'", + []sql.Row{ + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i = 2 AND s = 'third row'", + ([]sql.Row)(nil), + }, + { + "SELECT * FROM mytable WHERE i BETWEEN 1 AND 2", + []sql.Row{ + {int64(1), "first row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i = 1 OR i = 2", + []sql.Row{ + {int64(1), "first row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT * FROM mytable WHERE i = 1 AND i = 2", + ([]sql.Row)(nil), + }, + { + "SELECT i as mytable_i FROM mytable WHERE mytable_i = 2", + []sql.Row{ + {int64(2)}, + }, + }, + { + "SELECT i as mytable_i FROM mytable WHERE mytable_i > 1", + []sql.Row{ + {int64(3)}, + {int64(2)}, + }, + }, + { + "SELECT i as mytable_i, s as mytable_s FROM mytable WHERE mytable_i = 2 AND mytable_s = 'second row'", + []sql.Row{ + {int64(2), "second row"}, + }, + }, + { + "SELECT s, SUBSTRING(s, 1, 1) AS sub_s FROM mytable WHERE sub_s = 's'", + []sql.Row{ + {"second row", "s"}, + }, + }, + { + "SELECT count(i) AS mytable_i, SUBSTR(s, -3) AS mytable_s FROM mytable WHERE i > 0 AND mytable_s='row' GROUP BY mytable_s", + []sql.Row{ + {int64(3), "row"}, + }, + 
}, + { + "SELECT mytable_i FROM (SELECT i AS mytable_i FROM mytable) as t WHERE mytable_i > 1", + []sql.Row{ + {int64(2)}, + {int64(3)}, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.query, func(t *testing.T) { + require := require.New(t) + + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.TODO(), sql.WithTracer(tracer)) + + _, it, err := e.Query(ctx, tt.query) + require.NoError(err) + + rows, err := sql.RowIterToRows(it) + require.NoError(err) + + require.ElementsMatch(tt.expected, rows) + require.Equal("plan.ResolvedTable", tracer.Spans[len(tracer.Spans)-1]) + }) + } +} + +func TestCreateIndex(t *testing.T) { + require := require.New(t) + e := newEngine(t) + + tmpDir, err := ioutil.TempDir(os.TempDir(), "pilosa-test") + require.NoError(err) + + require.NoError(os.MkdirAll(tmpDir, 0644)) + e.Catalog.RegisterIndexDriver(pilosa.NewDriver(tmpDir)) + + _, iter, err := e.Query(newCtx(), "CREATE INDEX myidx ON mytable USING pilosa (i)") + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.Len(rows, 0) + + defer func() { + time.Sleep(1 * time.Second) + done, err := e.Catalog.DeleteIndex("foo", "myidx", true) + require.NoError(err) + <-done + + require.NoError(os.RemoveAll(tmpDir)) + }() +} diff --git a/engine_test.go b/engine_test.go index fa3c1f559..b022a3e6b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -3,21 +3,26 @@ package sqle_test import ( "context" "io" + "math" "strings" + "sync/atomic" "testing" + "time" - "gopkg.in/src-d/go-mysql-server.v0" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/parse" + "github.com/opentracing/opentracing-go" + + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/analyzer" + "github.com/src-d/go-mysql-server/sql/parse" + 
"github.com/src-d/go-mysql-server/sql/plan" + "github.com/src-d/go-mysql-server/test" - opentracing "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" "github.com/stretchr/testify/require" ) -const driverName = "engine_tests" - var queries = []struct { query string expected []sql.Row @@ -26,14 +31,90 @@ var queries = []struct { "SELECT i FROM mytable;", []sql.Row{{int64(1)}, {int64(2)}, {int64(3)}}, }, + { + "SELECT i + 1 FROM mytable;", + []sql.Row{{int64(2)}, {int64(3)}, {int64(4)}}, + }, + { + "SELECT -i FROM mytable;", + []sql.Row{{int64(-1)}, {int64(-2)}, {int64(-3)}}, + }, + { + "SELECT i FROM mytable where -i = -2;", + []sql.Row{{int64(2)}}, + }, { "SELECT i FROM mytable WHERE i = 2;", []sql.Row{{int64(2)}}, }, + { + "SELECT i FROM mytable WHERE i > 2;", + []sql.Row{{int64(3)}}, + }, + { + "SELECT i FROM mytable WHERE i < 2;", + []sql.Row{{int64(1)}}, + }, + { + "SELECT i FROM mytable WHERE i <> 2;", + []sql.Row{{int64(1)}, {int64(3)}}, + }, + { + "SELECT f32 FROM floattable WHERE f64 = 2.0;", + []sql.Row{{float32(2.0)}}, + }, + { + "SELECT f32 FROM floattable WHERE f64 < 2.0;", + []sql.Row{{float32(-1.0)}, {float32(-1.5)}, {float32(1.0)}, {float32(1.5)}}, + }, + { + "SELECT f32 FROM floattable WHERE f64 > 2.0;", + []sql.Row{{float32(2.5)}}, + }, + { + "SELECT f32 FROM floattable WHERE f64 <> 2.0;", + []sql.Row{{float32(-1.0)}, {float32(-1.5)}, {float32(1.0)}, {float32(1.5)}, {float32(2.5)}}, + }, + { + "SELECT f64 FROM floattable WHERE f32 = 2.0;", + []sql.Row{{float64(2.0)}}, + }, + { + "SELECT f64 FROM floattable WHERE f32 = -1.5;", + []sql.Row{{float64(-1.5)}}, + }, + { + "SELECT f64 FROM floattable WHERE -f32 = -2.0;", + []sql.Row{{float64(2.0)}}, + }, + { + "SELECT f64 FROM floattable WHERE f32 < 2.0;", + []sql.Row{{float64(-1.0)}, {float64(-1.5)}, {float64(1.0)}, {float64(1.5)}}, + }, + { + "SELECT f64 FROM floattable WHERE f32 > 2.0;", + []sql.Row{{float64(2.5)}}, + }, + { + "SELECT f64 FROM floattable WHERE f32 <> 
2.0;", + []sql.Row{{float64(-1.0)}, {float64(-1.5)}, {float64(1.0)}, {float64(1.5)}, {float64(2.5)}}, + }, + { + "SELECT f32 FROM floattable ORDER BY f64;", + []sql.Row{{float32(-1.5)}, {float32(-1.0)}, {float32(1.0)}, {float32(1.5)}, {float32(2.0)}, {float32(2.5)}}, + }, { "SELECT i FROM mytable ORDER BY i DESC;", []sql.Row{{int64(3)}, {int64(2)}, {int64(1)}}, }, + { + "SELECT i FROM mytable WHERE 'hello';", + []sql.Row{}, + }, + { + "SELECT i FROM mytable WHERE not 'hello';", + []sql.Row{{int64(1)}, {int64(2)}, {int64(3)}}, + }, { "SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC;", []sql.Row{{int64(1)}}, @@ -42,22 +123,96 @@ var queries = []struct { "SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC LIMIT 1;", []sql.Row{{int64(1)}}, }, + { + "SELECT i FROM mytable ORDER BY i LIMIT 1 OFFSET 1;", + []sql.Row{{int64(2)}}, + }, + { + "SELECT i FROM mytable ORDER BY i LIMIT 1,1;", + []sql.Row{{int64(2)}}, + }, + { + "SELECT i FROM mytable ORDER BY i LIMIT 3,1;", + nil, + }, + { + "SELECT i FROM mytable ORDER BY i LIMIT 2,100;", + []sql.Row{{int64(3)}}, + }, + { + "SELECT i FROM niltable WHERE b IS NULL", + []sql.Row{{int64(2)}, {nil}}, + }, + { + "SELECT i FROM niltable WHERE b IS NOT NULL", + []sql.Row{{int64(1)}, {nil}, {int64(4)}}, + }, + { + "SELECT i FROM niltable WHERE b", + []sql.Row{{int64(1)}, {int64(4)}}, + }, + { + "SELECT i FROM niltable WHERE NOT b", + []sql.Row{{nil}}, + }, + { + "SELECT i FROM niltable WHERE b IS TRUE", + []sql.Row{{int64(1)}, {int64(4)}}, + }, + { + "SELECT i FROM niltable WHERE b IS NOT TRUE", + []sql.Row{{int64(2)}, {nil}, {nil}}, + }, + { + "SELECT f FROM niltable WHERE b IS FALSE", + []sql.Row{{3.0}}, + }, + { + "SELECT i FROM niltable WHERE b IS NOT FALSE", + []sql.Row{{int64(1)}, {int64(2)}, {int64(4)}, {nil}}, + }, { "SELECT COUNT(*) FROM mytable;", - []sql.Row{{int32(3)}}, + []sql.Row{{int64(3)}}, }, { "SELECT COUNT(*) FROM mytable LIMIT 1;", - []sql.Row{{int32(3)}}, + []sql.Row{{int64(3)}}, }, { "SELECT 
COUNT(*) AS c FROM mytable;", - []sql.Row{{int32(3)}}, + []sql.Row{{int64(3)}}, }, { "SELECT substring(s, 2, 3) FROM mytable", []sql.Row{{"irs"}, {"eco"}, {"hir"}}, }, + { + `SELECT substring("foo", 2, 2)`, + []sql.Row{{"oo"}}, + }, + { + `SELECT SUBSTRING_INDEX('a.b.c.d.e.f', '.', 2)`, + []sql.Row{ + {"a.b"}, + }, + }, + { + `SELECT SUBSTRING_INDEX('a.b.c.d.e.f', '.', -2)`, + []sql.Row{ + {"e.f"}, + }, + }, + { + `SELECT SUBSTRING_INDEX(SUBSTRING_INDEX('source{d}', '{d}', 1), 'r', -1)`, + []sql.Row{ + {"ce"}, + }, + }, + { + `SELECT SUBSTRING_INDEX(mytable.s, "d", 1) as s FROM mytable INNER JOIN othertable ON (SUBSTRING_INDEX(mytable.s, "d", 1) = SUBSTRING_INDEX(othertable.s2, "d", 1)) GROUP BY 1 HAVING s = 'secon'`, + []sql.Row{{"secon"}}, + }, { "SELECT YEAR('2007-12-11') FROM mytable", []sql.Row{{int32(2007)}, {int32(2007)}, {int32(2007)}}, @@ -86,6 +241,38 @@ var queries = []struct { "SELECT DAYOFYEAR('2007-12-11 20:21:22') FROM mytable", []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, }, + { + "SELECT SECOND('2007-12-11T20:21:22Z') FROM mytable", + []sql.Row{{int32(22)}, {int32(22)}, {int32(22)}}, + }, + { + "SELECT DAYOFYEAR('2007-12-11') FROM mytable", + []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, + }, + { + "SELECT DAYOFYEAR('20071211') FROM mytable", + []sql.Row{{int32(345)}, {int32(345)}, {int32(345)}}, + }, + { + "SELECT YEARWEEK('0000-01-01')", + []sql.Row{{int32(1)}}, + }, + { + "SELECT YEARWEEK('9999-12-31')", + []sql.Row{{int32(999952)}}, + }, + { + "SELECT YEARWEEK('2008-02-20', 1)", + []sql.Row{{int32(200808)}}, + }, + { + "SELECT YEARWEEK('1987-01-01')", + []sql.Row{{int32(198652)}}, + }, + { + "SELECT YEARWEEK('1987-01-01', 20), YEARWEEK('1987-01-01', 1), YEARWEEK('1987-01-01', 2), YEARWEEK('1987-01-01', 3), YEARWEEK('1987-01-01', 4), YEARWEEK('1987-01-01', 5), YEARWEEK('1987-01-01', 6), YEARWEEK('1987-01-01', 7)", + []sql.Row{{int32(198653), int32(198701), int32(198652), int32(198701), int32(198653), int32(198652), int32(198653), 
int32(198652)}}, + }, { "SELECT i FROM mytable WHERE i BETWEEN 1 AND 2", []sql.Row{{int64(1)}, {int64(2)}}, @@ -94,6 +281,14 @@ var queries = []struct { "SELECT i FROM mytable WHERE i NOT BETWEEN 1 AND 2", []sql.Row{{int64(3)}}, }, + { + "SELECT substring(mytable.s, 1, 5) as s FROM mytable INNER JOIN othertable ON (substring(mytable.s, 1, 5) = SUBSTRING(othertable.s2, 1, 5)) GROUP BY 1", + []sql.Row{ + {"third"}, + {"secon"}, + {"first"}, + }, + }, { "SELECT i, i2, s2 FROM mytable INNER JOIN othertable ON i = i2", []sql.Row{ @@ -102,6 +297,34 @@ var queries = []struct { {int64(3), int64(3), "first"}, }, }, + { + "SELECT substring(s2, 1), substring(s2, 2), substring(s2, 3) FROM othertable ORDER BY i2", + []sql.Row{ + {"third", "hird", "ird"}, + {"second", "econd", "cond"}, + {"first", "irst", "rst"}, + }, + }, + { + `SELECT substring("first", 1), substring("second", 2), substring("third", 3)`, + []sql.Row{ + {"first", "econd", "ird"}, + }, + }, + { + "SELECT substring(s2, -1), substring(s2, -2), substring(s2, -3) FROM othertable ORDER BY i2", + []sql.Row{ + {"d", "rd", "ird"}, + {"d", "nd", "ond"}, + {"t", "st", "rst"}, + }, + }, + { + `SELECT substring("first", -1), substring("second", -2), substring("third", -3)`, + []sql.Row{ + {"t", "nd", "ird"}, + }, + }, { "SELECT s FROM mytable INNER JOIN othertable " + "ON substring(s2, 1, 2) != '' AND i = i2", @@ -118,9 +341,63 @@ var queries = []struct { ) t GROUP BY fi`, []sql.Row{ - {int32(1), "first row"}, - {int32(1), "second row"}, - {int32(1), "third row"}, + {int64(1), "first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, + }, + }, + { + `SELECT fi, COUNT(*) FROM ( + SELECT tbl.s AS fi + FROM mytable tbl + ) t + GROUP BY fi + ORDER BY COUNT(*) ASC`, + []sql.Row{ + {"first row", int64(1)}, + {"second row", int64(1)}, + {"third row", int64(1)}, + }, + }, + { + `SELECT COUNT(*), fi FROM ( + SELECT tbl.s AS fi + FROM mytable tbl + ) t + GROUP BY fi + ORDER BY COUNT(*) ASC`, + []sql.Row{ + {int64(1), 
"first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, + }, + }, + { + `SELECT COUNT(*) as cnt, fi FROM ( + SELECT tbl.s AS fi + FROM mytable tbl + ) t + GROUP BY 2`, + []sql.Row{ + {int64(1), "first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, + }, + }, + { + `SELECT COUNT(*) as cnt, s as fi FROM mytable GROUP BY fi`, + []sql.Row{ + {int64(1), "first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, + }, + }, + { + `SELECT COUNT(*) as cnt, s as fi FROM mytable GROUP BY 2`, + []sql.Row{ + {int64(1), "first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, }, }, { @@ -148,7 +425,7 @@ var queries = []struct { }, }, { - "SELECT text > 2 FROM tabletest", + "SELECT s > 2 FROM tabletest", []sql.Row{ {false}, {false}, @@ -156,21 +433,21 @@ var queries = []struct { }, }, { - "SELECT * FROM tabletest WHERE text > 0", + "SELECT * FROM tabletest WHERE s > 0", nil, }, { - "SELECT * FROM tabletest WHERE text = 0", + "SELECT * FROM tabletest WHERE s = 0", []sql.Row{ - {"a", int32(1)}, - {"b", int32(2)}, - {"c", int32(3)}, + {int64(1), "first row"}, + {int64(2), "second row"}, + {int64(3), "third row"}, }, }, { - "SELECT * FROM tabletest WHERE text = 'a'", + "SELECT * FROM tabletest WHERE s = 'first row'", []sql.Row{ - {"a", int32(1)}, + {int64(1), "first row"}, }, }, { @@ -196,192 +473,2505 @@ var queries = []struct { `SELECT i AS foo FROM mytable WHERE foo NOT IN (1, 2, 5)`, []sql.Row{{int64(3)}}, }, -} - -func TestQueries(t *testing.T) { - e := newEngine(t) - - for _, tt := range queries { - testQuery(t, e, tt.query, tt.expected) - } -} - -func TestOrderByColumns(t *testing.T) { - require := require.New(t) - e := newEngine(t) - - _, iter, err := e.Query(sql.NewEmptyContext(), "SELECT s, i FROM mytable ORDER BY 2 DESC") - require.NoError(err) - - rows, err := sql.RowIterToRows(iter) - require.NoError(err) - - require.Equal( + { + `SELECT * FROM tabletest, mytable mt INNER JOIN othertable ot ON mt.i = ot.i2`, []sql.Row{ - 
{"third row", int64(3)}, - {"second row", int64(2)}, - {"first row", int64(1)}, + {int64(1), "first row", int64(1), "first row", "third", int64(1)}, + {int64(1), "first row", int64(2), "second row", "second", int64(2)}, + {int64(1), "first row", int64(3), "third row", "first", int64(3)}, + {int64(2), "second row", int64(1), "first row", "third", int64(1)}, + {int64(2), "second row", int64(2), "second row", "second", int64(2)}, + {int64(2), "second row", int64(3), "third row", "first", int64(3)}, + {int64(3), "third row", int64(1), "first row", "third", int64(1)}, + {int64(3), "third row", int64(2), "second row", "second", int64(2)}, + {int64(3), "third row", int64(3), "third row", "first", int64(3)}, }, - rows, - ) -} - -func TestInsertInto(t *testing.T) { - e := newEngine(t) - testQuery(t, e, - "INSERT INTO mytable (s, i) VALUES ('x', 999);", - []sql.Row{{int64(1)}}, - ) - - testQuery(t, e, - "SELECT i FROM mytable WHERE s = 'x';", - []sql.Row{{int64(999)}}, - ) -} - -func TestAmbiguousColumnResolution(t *testing.T) { - require := require.New(t) - - table := mem.NewTable("foo", sql.Schema{ - {Name: "a", Type: sql.Int64, Source: "foo"}, - {Name: "b", Type: sql.Text, Source: "foo"}, - }) - require.Nil(table.Insert(sql.NewRow(int64(1), "foo"))) - require.Nil(table.Insert(sql.NewRow(int64(2), "bar"))) - require.Nil(table.Insert(sql.NewRow(int64(3), "baz"))) - - table2 := mem.NewTable("bar", sql.Schema{ - {Name: "b", Type: sql.Text, Source: "bar"}, - {Name: "c", Type: sql.Int64, Source: "bar"}, - }) - require.Nil(table2.Insert(sql.NewRow("qux", int64(3)))) - require.Nil(table2.Insert(sql.NewRow("mux", int64(2)))) - require.Nil(table2.Insert(sql.NewRow("pux", int64(1)))) - - db := mem.NewDatabase("mydb") - db.AddTable(table.Name(), table) - db.AddTable(table2.Name(), table2) - - e := sqle.New() - e.AddDatabase(db) - - q := `SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c` - ctx := sql.NewEmptyContext() - - _, rows, err := e.Query(ctx, q) - 
require.NoError(err) - - var rs [][]interface{} - for { - row, err := rows.Next() - if err == io.EOF { - break - } - require.NoError(err) - - rs = append(rs, row) - } - - expected := [][]interface{}{ - {int64(1), "pux", "foo"}, - {int64(2), "mux", "bar"}, - {int64(3), "qux", "baz"}, - } - - require.Equal(expected, rs) -} - -func TestDDL(t *testing.T) { - require := require.New(t) - - e := newEngine(t) - testQuery(t, e, - "CREATE TABLE t1(a INTEGER, b TEXT, c DATE,"+ - "d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL)", - []sql.Row(nil), - ) - - db, err := e.Catalog.Database("mydb") - require.NoError(err) - - testTable, ok := db.Tables()["t1"] - require.True(ok) - - s := sql.Schema{ - {Name: "a", Type: sql.Int32, Nullable: true, Source: "t1"}, - {Name: "b", Type: sql.Text, Nullable: true, Source: "t1"}, + }, + { + `SELECT split(s," ") FROM mytable`, + []sql.Row{ + sql.NewRow([]interface{}{"first", "row"}), + sql.NewRow([]interface{}{"second", "row"}), + sql.NewRow([]interface{}{"third", "row"}), + }, + }, + { + `SELECT split(s,"s") FROM mytable`, + []sql.Row{ + sql.NewRow([]interface{}{"fir", "t row"}), + sql.NewRow([]interface{}{"", "econd row"}), + sql.NewRow([]interface{}{"third row"}), + }, + }, + { + `SELECT SUM(i) FROM mytable`, + []sql.Row{{float64(6)}}, + }, + { + `SELECT * FROM mytable mt INNER JOIN othertable ot ON mt.i = ot.i2 AND mt.i > 2`, + []sql.Row{ + {int64(3), "third row", "first", int64(3)}, + }, + }, + { + `SELECT i as foo FROM mytable ORDER BY i DESC`, + []sql.Row{ + {int64(3)}, + {int64(2)}, + {int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY i DESC`, + []sql.Row{ + {int64(1), int64(3)}, + {int64(1), int64(2)}, + {int64(1), int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY 2 ORDER BY 2 DESC`, + []sql.Row{ + {int64(1), int64(3)}, + {int64(1), int64(2)}, + {int64(1), int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY foo DESC`, + []sql.Row{ + 
{int64(1), int64(3)}, + {int64(1), int64(2)}, + {int64(1), int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY 2 ORDER BY foo DESC`, + []sql.Row{ + {int64(1), int64(3)}, + {int64(1), int64(2)}, + {int64(1), int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as i FROM mytable GROUP BY 2`, + []sql.Row{ + {int64(1), int64(3)}, + {int64(1), int64(2)}, + {int64(1), int64(1)}, + }, + }, + { + `SELECT i as i FROM mytable GROUP BY 1`, + []sql.Row{ + {int64(3)}, + {int64(2)}, + {int64(1)}, + }, + }, + { + `SELECT CONCAT("a", "b", "c")`, + []sql.Row{ + {string("abc")}, + }, + }, + { + `SELECT COALESCE(NULL, NULL, NULL, 'example', NULL, 1234567890)`, + []sql.Row{ + {string("example")}, + }, + }, + { + `SELECT COALESCE(NULL, NULL, NULL, COALESCE(NULL, 1234567890))`, + []sql.Row{ + {int32(1234567890)}, + }, + }, + { + "SELECT concat(s, i) FROM mytable", + []sql.Row{ + {string("first row1")}, + {string("second row2")}, + {string("third row3")}, + }, + }, + { + "SELECT version()", + []sql.Row{ + {string("8.0.11")}, + }, + }, + { + "SELECT * FROM mytable WHERE 1 > 5", + []sql.Row{}, + }, + { + "SELECT SUM(i) + 1, i FROM mytable GROUP BY i ORDER BY i", + []sql.Row{ + {float64(2), int64(1)}, + {float64(3), int64(2)}, + {float64(4), int64(3)}, + }, + }, + { + "SELECT SUM(i), i FROM mytable GROUP BY i ORDER BY 1+SUM(i) ASC", + []sql.Row{ + {float64(1), int64(1)}, + {float64(2), int64(2)}, + {float64(3), int64(3)}, + }, + }, + { + "SELECT i, SUM(i) FROM mytable GROUP BY i ORDER BY SUM(i) DESC", + []sql.Row{ + {int64(3), float64(3)}, + {int64(2), float64(2)}, + {int64(1), float64(1)}, + }, + }, + { + `/*!40101 SET NAMES utf8 */`, + []sql.Row{}, + }, + { + `SHOW DATABASES`, + []sql.Row{{"mydb"}, {"foo"}}, + }, + { + `SHOW SCHEMAS`, + []sql.Row{{"mydb"}, {"foo"}}, + }, + { + `SELECT SCHEMA_NAME, DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA`, + []sql.Row{ + {"mydb", "utf8mb4", "utf8_bin"}, + {"foo", "utf8mb4", "utf8_bin"}, + }, 
+ }, + { + `SELECT s FROM mytable WHERE s LIKE '%d row'`, + []sql.Row{ + {"second row"}, + {"third row"}, + }, + }, + { + `SELECT SUBSTRING(s, -3, 3) as s FROM mytable WHERE s LIKE '%d row' GROUP BY 1`, + []sql.Row{ + {"row"}, + }, + }, + { + `SELECT s FROM mytable WHERE s NOT LIKE '%d row'`, + []sql.Row{ + {"first row"}, + }, + }, + { + `SHOW COLUMNS FROM mytable`, + []sql.Row{ + {"i", "INT64", "NO", "", "", ""}, + {"s", "TEXT", "NO", "", "", ""}, + }, + }, + { + `SHOW COLUMNS FROM mytable WHERE Field = 'i'`, + []sql.Row{ + {"i", "INT64", "NO", "", "", ""}, + }, + }, + { + `SHOW COLUMNS FROM mytable LIKE 'i'`, + []sql.Row{ + {"i", "INT64", "NO", "", "", ""}, + }, + }, + { + `SHOW FULL COLUMNS FROM mytable`, + []sql.Row{ + {"i", "INT64", nil, "NO", "", "", "", "", ""}, + {"s", "TEXT", "utf8_bin", "NO", "", "", "", "", ""}, + }, + }, + { + `SHOW TABLE STATUS FROM mydb`, + []sql.Row{ + {"mytable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"othertable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"tabletest", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"bigtable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"floattable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"niltable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"newlinetable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"typestable", "InnoDB", "10", "Fixed", int64(0), 
int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + }, + }, + { + `SHOW TABLE STATUS LIKE '%table'`, + []sql.Row{ + {"mytable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"othertable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"bigtable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"floattable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"niltable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"newlinetable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"typestable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + }, + }, + { + `SHOW TABLE STATUS WHERE Name = 'mytable'`, + []sql.Row{ + {"mytable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + }, + }, + { + `SELECT i FROM mytable NATURAL JOIN tabletest`, + []sql.Row{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + }, + }, + { + `SELECT * FROM foo.other_table`, + []sql.Row{ + {"a", int32(4)}, + {"b", int32(2)}, + {"c", int32(0)}, + }, + }, + { + `SELECT AVG(23.222000)`, + []sql.Row{ + {float64(23.222)}, + }, + }, + { + `SELECT DATABASE()`, + []sql.Row{ + {"mydb"}, + }, + }, + { + `SHOW VARIABLES`, + []sql.Row{ + {"auto_increment_increment", int64(1)}, + {"time_zone", time.Local.String()}, + {"system_time_zone", time.Local.String()}, + {"max_allowed_packet", 
math.MaxInt32}, + {"sql_mode", ""}, + {"gtid_mode", int32(0)}, + {"collation_database", "utf8_bin"}, + {"ndbinfo_version", ""}, + {"sql_select_limit", math.MaxInt32}, + {"transaction_isolation", "READ UNCOMMITTED"}, + {"version", ""}, + {"version_comment", ""}, + }, + }, + { + `SHOW VARIABLES LIKE 'gtid_mode`, + []sql.Row{ + {"gtid_mode", int32(0)}, + }, + }, + { + `SHOW VARIABLES LIKE 'gtid%`, + []sql.Row{ + {"gtid_mode", int32(0)}, + }, + }, + { + `SHOW GLOBAL VARIABLES LIKE '%mode`, + []sql.Row{ + {"sql_mode", ""}, + {"gtid_mode", int32(0)}, + }, + }, + { + `SELECT JSON_EXTRACT("foo", "$")`, + []sql.Row{{"foo"}}, + }, + { + `SELECT JSON_UNQUOTE('"foo"')`, + []sql.Row{{"foo"}}, + }, + { + `SELECT JSON_UNQUOTE('[1, 2, 3]')`, + []sql.Row{{"[1, 2, 3]"}}, + }, + { + `SELECT JSON_UNQUOTE('"\\t\\u0032"')`, + []sql.Row{{"\t2"}}, + }, + { + `SELECT JSON_UNQUOTE('"\t\\u0032"')`, + []sql.Row{{"\t2"}}, + }, + { + `SELECT CONNECTION_ID()`, + []sql.Row{{uint32(1)}}, + }, + { + ` + SELECT + LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA + FROM INFORMATION_SCHEMA.FILES + WHERE FILE_TYPE = 'UNDO LOG' + AND FILE_NAME IS NOT NULL + AND LOGFILE_GROUP_NAME IS NOT NULL + GROUP BY LOGFILE_GROUP_NAME, FILE_NAME, ENGINE, TOTAL_EXTENTS, INITIAL_SIZE + ORDER BY LOGFILE_GROUP_NAME + `, + []sql.Row{}, + }, + { + ` + SELECT DISTINCT + TABLESPACE_NAME, FILE_NAME, LOGFILE_GROUP_NAME, EXTENT_SIZE, INITIAL_SIZE, ENGINE + FROM INFORMATION_SCHEMA.FILES + WHERE FILE_TYPE = 'DATAFILE' + ORDER BY TABLESPACE_NAME, LOGFILE_GROUP_NAME + `, + []sql.Row{}, + }, + { + ` + SELECT + COLUMN_NAME, + JSON_EXTRACT(HISTOGRAM, '$."number-of-buckets-specified"') + FROM information_schema.COLUMN_STATISTICS + WHERE SCHEMA_NAME = 'mydb' + AND TABLE_NAME = 'mytable' + `, + []sql.Row{}, + }, + { + ` + SELECT TABLE_NAME FROM information_schema.TABLES + WHERE TABLE_SCHEMA='mydb' AND (TABLE_TYPE='BASE TABLE' OR TABLE_TYPE='VIEW') + `, + []sql.Row{ + {"mytable"}, + {"othertable"}, + {"tabletest"}, 
+ {"bigtable"}, + {"floattable"}, + {"niltable"}, + {"newlinetable"}, + {"typestable"}, + }, + }, + { + ` + SELECT COLUMN_NAME, DATA_TYPE FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA='mydb' AND TABLE_NAME='mytable' + `, + []sql.Row{ + {"s", "text"}, + {"i", "bigint"}, + }, + }, + { + ` + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME LIKE '%table' + GROUP BY COLUMN_NAME + `, + []sql.Row{ + {"s"}, + {"i"}, + {"s2"}, + {"i2"}, + {"t"}, + {"n"}, + {"f32"}, + {"f64"}, + {"b"}, + {"f"}, + {"id"}, + {"i8"}, + {"i16"}, + {"i32"}, + {"i64"}, + {"u8"}, + {"u16"}, + {"u32"}, + {"u64"}, + {"ti"}, + {"da"}, + {"te"}, + {"bo"}, + {"js"}, + {"bl"}, + }, + }, + { + ` + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME LIKE '%table' + GROUP BY 1 + `, + []sql.Row{ + {"s"}, + {"i"}, + {"s2"}, + {"i2"}, + {"t"}, + {"n"}, + {"f32"}, + {"f64"}, + {"b"}, + {"f"}, + {"id"}, + {"i8"}, + {"i16"}, + {"i32"}, + {"i64"}, + {"u8"}, + {"u16"}, + {"u32"}, + {"u64"}, + {"ti"}, + {"da"}, + {"te"}, + {"bo"}, + {"js"}, + {"bl"}, + }, + }, + { + ` + SELECT COLUMN_NAME as COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME LIKE '%table' + GROUP BY 1 + `, + []sql.Row{ + {"s"}, + {"i"}, + {"s2"}, + {"i2"}, + {"t"}, + {"n"}, + {"f32"}, + {"f64"}, + {"b"}, + {"f"}, + {"id"}, + {"i8"}, + {"i16"}, + {"i32"}, + {"i64"}, + {"u8"}, + {"u16"}, + {"u32"}, + {"u64"}, + {"ti"}, + {"da"}, + {"te"}, + {"bo"}, + {"js"}, + {"bl"}, + }, + }, + { + `SHOW CREATE DATABASE mydb`, + []sql.Row{{ + "mydb", + "CREATE DATABASE `mydb` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */", + }}, + }, + { + `SELECT -1`, + []sql.Row{{int8(-1)}}, + }, + { + ` + SHOW WARNINGS + `, + []sql.Row{}, + }, + { + `SHOW WARNINGS LIMIT 0`, + []sql.Row{}, + }, + { + `SET SESSION NET_READ_TIMEOUT= 700, SESSION NET_WRITE_TIMEOUT= 700`, + []sql.Row{}, + }, + { + `SHOW TABLE STATUS`, + []sql.Row{ + 
{"mytable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"othertable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"tabletest", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"bigtable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"floattable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"niltable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"newlinetable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"typestable", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + }, + }, + { + `SELECT NULL`, + []sql.Row{ + {nil}, + }, + }, + { + `SELECT nullif('abc', NULL)`, + []sql.Row{ + {"abc"}, + }, + }, + { + `SELECT nullif(NULL, NULL)`, + []sql.Row{ + {sql.Null}, + }, + }, + { + `SELECT nullif(NULL, 123)`, + []sql.Row{ + {nil}, + }, + }, + { + `SELECT nullif(123, 123)`, + []sql.Row{ + {sql.Null}, + }, + }, + { + `SELECT nullif(123, 321)`, + []sql.Row{ + {int8(123)}, + }, + }, + { + `SELECT ifnull(123, NULL)`, + []sql.Row{ + {int8(123)}, + }, + }, + { + `SELECT ifnull(NULL, NULL)`, + []sql.Row{ + {nil}, + }, + }, + { + `SELECT ifnull(NULL, 123)`, + []sql.Row{ + {int8(123)}, + }, + }, + { + `SELECT ifnull(123, 123)`, + []sql.Row{ + {int8(123)}, + }, + }, + { + `SELECT ifnull(123, 321)`, + []sql.Row{ + {int8(123)}, + }, + }, + { + `SELECT round(15728640/1024/1024)`, 
+ []sql.Row{ + {int64(15)}, + }, + }, + { + `SELECT round(15, 1)`, + []sql.Row{ + {int8(15)}, + }, + }, + { + `SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END FROM mytable`, + []sql.Row{ + {"one"}, + {"two"}, + {"other"}, + }, + }, + { + `SELECT CASE WHEN i > 2 THEN 'more than two' WHEN i < 2 THEN 'less than two' ELSE 'two' END FROM mytable`, + []sql.Row{ + {"less than two"}, + {"two"}, + {"more than two"}, + }, + }, + { + `SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' END FROM mytable`, + []sql.Row{ + {"one"}, + {"two"}, + {nil}, + }, + }, + { + "SHOW TABLES", + []sql.Row{ + {"mytable"}, + {"othertable"}, + {"tabletest"}, + {"bigtable"}, + {"floattable"}, + {"niltable"}, + {"newlinetable"}, + {"typestable"}, + }, + }, + { + "SHOW FULL TABLES", + []sql.Row{ + {"mytable", "BASE TABLE"}, + {"othertable", "BASE TABLE"}, + {"tabletest", "BASE TABLE"}, + {"bigtable", "BASE TABLE"}, + {"floattable", "BASE TABLE"}, + {"niltable", "BASE TABLE"}, + {"newlinetable", "BASE TABLE"}, + {"typestable", "BASE TABLE"}, + }, + }, + { + "SHOW TABLES FROM foo", + []sql.Row{ + {"other_table"}, + }, + }, + { + "SHOW TABLES LIKE '%table'", + []sql.Row{ + {"mytable"}, + {"othertable"}, + {"bigtable"}, + {"floattable"}, + {"niltable"}, + {"newlinetable"}, + {"typestable"}, + }, + }, + { + "SHOW TABLES WHERE `Table` = 'mytable'", + []sql.Row{ + {"mytable"}, + }, + }, + { + `SHOW COLLATION`, + []sql.Row{{"utf8_bin", "utf8mb4", int64(1), "Yes", "Yes", int64(1)}}, + }, + { + `SHOW COLLATION LIKE 'foo'`, + []sql.Row{}, + }, + { + `SHOW COLLATION LIKE 'utf8%'`, + []sql.Row{{"utf8_bin", "utf8mb4", int64(1), "Yes", "Yes", int64(1)}}, + }, + { + `SHOW COLLATION WHERE charset = 'foo'`, + []sql.Row{}, + }, + { + "SHOW COLLATION WHERE `Default` = 'Yes'", + []sql.Row{{"utf8_bin", "utf8mb4", int64(1), "Yes", "Yes", int64(1)}}, + }, + { + "ROLLBACK", + []sql.Row{}, + }, + { + "SELECT substring(s, 1, 1) FROM mytable ORDER BY substring(s, 1, 1)", + []sql.Row{{"f"}, {"s"}, {"t"}}, + 
}, + { + "SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)", + []sql.Row{{"f", int64(1)}, {"s", int64(1)}, {"t", int64(1)}}, + }, + { + "SELECT SLEEP(0.5)", + []sql.Row{{int(0)}}, + }, + { + "SELECT TO_BASE64('foo')", + []sql.Row{{string("Zm9v")}}, + }, + { + "SELECT FROM_BASE64('YmFy')", + []sql.Row{{string("bar")}}, + }, + { + "SELECT DATE_ADD('2018-05-02', INTERVAL 1 DAY)", + []sql.Row{{time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT DATE_SUB('2018-05-02', INTERVAL 1 DAY)", + []sql.Row{{time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT '2018-05-02' + INTERVAL 1 DAY", + []sql.Row{{time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC)}}, + }, + { + "SELECT '2018-05-02' - INTERVAL 1 DAY", + []sql.Row{{time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC)}}, + }, + { + `SELECT i AS i FROM mytable ORDER BY i`, + []sql.Row{{int64(1)}, {int64(2)}, {int64(3)}}, + }, + { + ` + SELECT + i, + foo + FROM ( + SELECT + i, + COUNT(s) AS foo + FROM mytable + GROUP BY i + ) AS q + ORDER BY foo DESC + `, + []sql.Row{ + {int64(1), int64(1)}, + {int64(2), int64(1)}, + {int64(3), int64(1)}, + }, + }, + { + "SELECT n, COUNT(n) FROM bigtable GROUP BY n HAVING COUNT(n) > 2", + []sql.Row{{int64(1), int64(3)}, {int64(2), int64(3)}}, + }, + { + "SELECT n, MAX(n) FROM bigtable GROUP BY n HAVING COUNT(n) > 2", + []sql.Row{{int64(1), int64(1)}, {int64(2), int64(2)}}, + }, + { + "SELECT substring(mytable.s, 1, 5) as s FROM mytable INNER JOIN othertable ON (substring(mytable.s, 1, 5) = SUBSTRING(othertable.s2, 1, 5)) GROUP BY 1 HAVING s = \"secon\"", + []sql.Row{{"secon"}}, + }, + { + ` + SELECT COLUMN_NAME as COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME LIKE '%table' + GROUP BY 1 HAVING SUBSTRING(COLUMN_NAME, 1, 1) = "s" + `, + []sql.Row{{"s"}, {"s2"}}, + }, + { + "SELECT s, i FROM mytable GROUP BY i ORDER BY SUBSTRING(s, 1, 1) DESC", + []sql.Row{ + {string("third row"), 
int64(3)}, + {string("second row"), int64(2)}, + {string("first row"), int64(1)}, + }, + }, + { + "SELECT s, i FROM mytable GROUP BY i HAVING count(*) > 0 ORDER BY SUBSTRING(s, 1, 1) DESC", + []sql.Row{ + {string("third row"), int64(3)}, + {string("second row"), int64(2)}, + {string("first row"), int64(1)}, + }, + }, + { + "SELECT CONVERT('9999-12-31 23:59:59', DATETIME)", + []sql.Row{{time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)}}, + }, + { + "SELECT CONVERT('10000-12-31 23:59:59', DATETIME)", + []sql.Row{{nil}}, + }, + { + "SELECT '9999-12-31 23:59:59' + INTERVAL 1 DAY", + []sql.Row{{nil}}, + }, + { + "SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)", + []sql.Row{{nil}}, + }, + { + `SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t WHERE t.date_col > '0000-01-01 00:00:00'`, + []sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}}, + }, + { + `SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`, + []sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}}, + }, + { + `SELECT i AS foo FROM mytable ORDER BY mytable.i`, + []sql.Row{{int64(1)}, {int64(2)}, {int64(3)}}, + }, + { + `SELECT JSON_EXTRACT('[1, 2, 3]', '$.[0]')`, + []sql.Row{{float64(1)}}, + }, + { + `SELECT ARRAY_LENGTH(JSON_EXTRACT('[1, 2, 3]', '$'))`, + []sql.Row{{int32(3)}}, + }, + { + `SELECT ARRAY_LENGTH(JSON_EXTRACT('[{"i":0}, {"i":1, "y":"yyy"}, {"i":2, "x":"xxx"}]', '$.i'))`, + []sql.Row{{int32(3)}}, + }, + { + `SELECT GREATEST(1, 2, 3, 4)`, + []sql.Row{{int64(4)}}, + }, + { + `SELECT GREATEST(1, 2, "3", 4)`, + []sql.Row{{float64(4)}}, + }, + { + `SELECT GREATEST(1, 2, "9", "foo999")`, + []sql.Row{{float64(9)}}, + }, + { + `SELECT GREATEST("aaa", "bbb", "ccc")`, + []sql.Row{{"ccc"}}, + }, + { + `SELECT GREATEST(i, s) FROM mytable`, + []sql.Row{{float64(1)}, {float64(2)}, {float64(3)}}, + }, + { + `SELECT LEAST(1, 2, 3, 4)`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT 
LEAST(1, 2, "3", 4)`, + []sql.Row{{float64(1)}}, + }, + { + `SELECT LEAST(1, 2, "9", "foo999")`, + []sql.Row{{float64(1)}}, + }, + { + `SELECT LEAST("aaa", "bbb", "ccc")`, + []sql.Row{{"aaa"}}, + }, + { + `SELECT LEAST(i, s) FROM mytable`, + []sql.Row{{float64(1)}, {float64(2)}, {float64(3)}}, + }, + { + "SELECT i, i2, s2 FROM mytable LEFT JOIN othertable ON i = i2 - 1", + []sql.Row{ + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, + {int64(3), nil, nil}, + }, + }, + { + "SELECT i, i2, s2 FROM mytable RIGHT JOIN othertable ON i = i2 - 1", + []sql.Row{ + {nil, int64(1), "third"}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, + }, + }, + { + "SELECT i, i2, s2 FROM mytable LEFT OUTER JOIN othertable ON i = i2 - 1", + []sql.Row{ + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, + {int64(3), nil, nil}, + }, + }, + { + "SELECT i, i2, s2 FROM mytable RIGHT OUTER JOIN othertable ON i = i2 - 1", + []sql.Row{ + {nil, int64(1), "third"}, + {int64(1), int64(2), "second"}, + {int64(2), int64(3), "first"}, + }, + }, + { + `SELECT CHAR_LENGTH('áé'), LENGTH('àè')`, + []sql.Row{{int32(2), int32(4)}}, + }, + { + "SELECT i, COUNT(i) AS `COUNT(i)` FROM (SELECT i FROM mytable) t GROUP BY i ORDER BY i, `COUNT(i)` DESC", + []sql.Row{{int64(1), int64(1)}, {int64(2), int64(1)}, {int64(3), int64(1)}}, + }, + { + "SELECT i FROM mytable WHERE NOT s ORDER BY 1 DESC", + []sql.Row{ + {int64(3)}, + {int64(2)}, + {int64(1)}, + }, + }, + { + "SELECT i FROM mytable WHERE NOT(NOT i) ORDER BY 1 DESC", + []sql.Row{ + {int64(3)}, + {int64(2)}, + {int64(1)}, + }, + }, + { + `SELECT NOW() - NOW()`, + []sql.Row{{int64(0)}}, + }, + { + `SELECT NOW() - (NOW() - INTERVAL 1 SECOND)`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, + []sql.Row{{"6789"}}, + }, + { + `SELECT CASE i WHEN 1 THEN i ELSE NULL END FROM mytable`, + []sql.Row{{int64(1)}, {nil}, {nil}}, + }, + { + `SELECT (NULL+1)`, + []sql.Row{{nil}}, 
+ }, + { + `SELECT ARRAY_LENGTH(null)`, + []sql.Row{{nil}}, + }, + { + `SELECT ARRAY_LENGTH("foo")`, + []sql.Row{{nil}}, + }, + { + `SELECT * FROM mytable WHERE NULL AND i = 3`, + []sql.Row{}, + }, + { + `SELECT 1 FROM mytable GROUP BY i HAVING i > 1`, + []sql.Row{{int8(1)}, {int8(1)}}, + }, + { + `SELECT avg(i) FROM mytable GROUP BY i HAVING avg(i) > 1`, + []sql.Row{{float64(2)}, {float64(3)}}, + }, + { + `SELECT s AS s, COUNT(*) AS count, AVG(i) AS ` + "`AVG(i)`" + ` + FROM ( + SELECT * FROM mytable + ) AS expr_qry + GROUP BY s + HAVING ((AVG(i) > 0)) + ORDER BY count DESC + LIMIT 10000`, + []sql.Row{ + {"first row", int64(1), float64(1)}, + {"second row", int64(1), float64(2)}, + {"third row", int64(1), float64(3)}, + }, + }, + { + `SELECT FIRST(i) FROM (SELECT i FROM mytable ORDER BY i) t`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT LAST(i) FROM (SELECT i FROM mytable ORDER BY i) t`, + []sql.Row{{int64(3)}}, + }, + { + `SELECT COUNT(DISTINCT t.i) FROM tabletest t, mytable t2`, + []sql.Row{{int64(3)}}, + }, + { + `SELECT CASE WHEN NULL THEN "yes" ELSE "no" END AS test`, + []sql.Row{{"no"}}, + }, + { + `SELECT + table_schema, + table_name, + CASE + WHEN table_type = 'BASE TABLE' THEN + CASE + WHEN table_schema = 'mysql' + OR table_schema = 'performance_schema' THEN 'SYSTEM TABLE' + ELSE 'TABLE' + END + WHEN table_type = 'TEMPORARY' THEN 'LOCAL_TEMPORARY' + ELSE table_type + END AS TABLE_TYPE + FROM information_schema.tables + WHERE table_schema = 'mydb' + AND table_name = 'mytable' + HAVING table_type IN ('TABLE', 'VIEW') + ORDER BY table_type, table_schema, table_name`, + []sql.Row{{"mydb", "mytable", "TABLE"}}, + }, + { + `SELECT REGEXP_MATCHES("bopbeepbop", "bop")`, + []sql.Row{{[]interface{}{"bop", "bop"}}}, + }, + { + `SELECT EXPLODE(REGEXP_MATCHES("bopbeepbop", "bop"))`, + []sql.Row{{"bop"}, {"bop"}}, + }, + { + `SELECT EXPLODE(REGEXP_MATCHES("helloworld", "bop"))`, + []sql.Row{}, + }, + { + `SELECT EXPLODE(REGEXP_MATCHES("", ""))`, + []sql.Row{{""}}, + }, 
+ { + `SELECT REGEXP_MATCHES(NULL, "")`, + []sql.Row{{nil}}, + }, + { + `SELECT REGEXP_MATCHES("", NULL)`, + []sql.Row{{nil}}, + }, + { + `SELECT REGEXP_MATCHES("", "", NULL)`, + []sql.Row{{nil}}, + }, + { + "SELECT * FROM newlinetable WHERE s LIKE '%text%'", + []sql.Row{ + {int64(1), "\nthere is some text in here"}, + {int64(2), "there is some\ntext in here"}, + {int64(3), "there is some text\nin here"}, + {int64(4), "there is some text in here\n"}, + {int64(5), "there is some text in here"}, + }, + }, + { + `SELECT i FROM mytable WHERE i = (SELECT 1)`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT i FROM mytable WHERE i IN (SELECT i FROM mytable)`, + []sql.Row{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + }, + }, + { + `SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM mytable ORDER BY i ASC LIMIT 2)`, + []sql.Row{ + {int64(3)}, + }, + }, + { + `SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`, + []sql.Row{{int64(1)}}, + }, + { + `SELECT DISTINCT n FROM bigtable ORDER BY t`, + []sql.Row{ + {int64(1)}, + {int64(9)}, + {int64(7)}, + {int64(3)}, + {int64(2)}, + {int64(8)}, + {int64(6)}, + {int64(5)}, + {int64(4)}, + }, + }, +} + +func TestQueries(t *testing.T) { + e := newEngine(t) + t.Run("sequential", func(t *testing.T) { + for _, tt := range queries { + testQuery(t, e, tt.query, tt.expected) + } + }) + + ep := newEngineWithParallelism(t, 2) + t.Run("parallel", func(t *testing.T) { + for _, tt := range queries { + testQuery(t, ep, tt.query, tt.expected) + } + }) +} + +func TestSessionSelectLimit(t *testing.T) { + ctx := newCtx() + ctx.Session.Set("sql_select_limit", sql.Int64, int64(1)) + + q := []struct { + query string + expected []sql.Row + }{ + { + "SELECT * FROM mytable ORDER BY i", + []sql.Row{{int64(1), "first row"}}, + }, + { + "SELECT * FROM mytable ORDER BY i LIMIT 2", + []sql.Row{ + {int64(1), "first row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT i FROM (SELECT i FROM mytable LIMIT 2) t ORDER BY i", + []sql.Row{{int64(1)}}, + }, + 
{ + "SELECT i FROM (SELECT i FROM mytable) t ORDER BY i LIMIT 2", + []sql.Row{{int64(1)}}, + }, + } + e := newEngine(t) + t.Run("sql_select_limit", func(t *testing.T) { + for _, tt := range q { + testQueryWithContext(ctx, t, e, tt.query, tt.expected) + } + }) +} + +func TestSessionDefaults(t *testing.T) { + ctx := newCtx() + ctx.Session.Set("auto_increment_increment", sql.Int64, 0) + ctx.Session.Set("max_allowed_packet", sql.Int64, 0) + ctx.Session.Set("sql_select_limit", sql.Int64, 0) + ctx.Session.Set("ndbinfo_version", sql.Text, "non default value") + + q := `SET @@auto_increment_increment=DEFAULT, + @@max_allowed_packet=DEFAULT, + @@sql_select_limit=DEFAULT, + @@ndbinfo_version=DEFAULT` + + e := newEngine(t) + + defaults := sql.DefaultSessionConfig() + t.Run(q, func(t *testing.T) { + require := require.New(t) + _, _, err := e.Query(ctx, q) + require.NoError(err) + + typ, val := ctx.Get("auto_increment_increment") + require.Equal(defaults["auto_increment_increment"].Typ, typ) + require.Equal(defaults["auto_increment_increment"].Value, val) + + typ, val = ctx.Get("max_allowed_packet") + require.Equal(defaults["max_allowed_packet"].Typ, typ) + require.Equal(defaults["max_allowed_packet"].Value, val) + + typ, val = ctx.Get("sql_select_limit") + require.Equal(defaults["sql_select_limit"].Typ, typ) + require.Equal(defaults["sql_select_limit"].Value, val) + + typ, val = ctx.Get("ndbinfo_version") + require.Equal(defaults["ndbinfo_version"].Typ, typ) + require.Equal(defaults["ndbinfo_version"].Value, val) + }) +} + +func TestWarnings(t *testing.T) { + ctx := newCtx() + ctx.Session.Warn(&sql.Warning{Code: 1}) + ctx.Session.Warn(&sql.Warning{Code: 2}) + ctx.Session.Warn(&sql.Warning{Code: 3}) + + var queries = []struct { + query string + expected []sql.Row + }{ + { + ` + SHOW WARNINGS + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 1 + `, + []sql.Row{ + {"", 3, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 1,2 + `, + 
[]sql.Row{ + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 0 + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 2,0 + `, + []sql.Row{ + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 10 + `, + []sql.Row{ + {"", 3, ""}, + {"", 2, ""}, + {"", 1, ""}, + }, + }, + { + ` + SHOW WARNINGS LIMIT 10,1 + `, + []sql.Row{}, + }, + } + + e := newEngine(t) + ep := newEngineWithParallelism(t, 2) + + t.Run("sequential", func(t *testing.T) { + for _, tt := range queries { + testQueryWithContext(ctx, t, e, tt.query, tt.expected) + } + }) + + t.Run("parallel", func(t *testing.T) { + for _, tt := range queries { + testQueryWithContext(ctx, t, ep, tt.query, tt.expected) + } + }) +} + +func TestClearWarnings(t *testing.T) { + require := require.New(t) + e := newEngine(t) + ctx := newCtx() + + _, iter, err := e.Query(ctx, "-- some empty query as a comment") + require.NoError(err) + err = iter.Close() + require.NoError(err) + + _, iter, err = e.Query(ctx, "-- some empty query as a comment") + require.NoError(err) + err = iter.Close() + require.NoError(err) + + _, iter, err = e.Query(ctx, "-- some empty query as a comment") + require.NoError(err) + err = iter.Close() + require.NoError(err) + + _, iter, err = e.Query(ctx, "SHOW WARNINGS") + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + err = iter.Close() + require.NoError(err) + require.Equal(3, len(rows)) + + _, iter, err = e.Query(ctx, "SHOW WARNINGS LIMIT 1") + require.NoError(err) + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + err = iter.Close() + require.NoError(err) + require.Equal(1, len(rows)) + + _, _, err = e.Query(ctx, "SELECT * FROM mytable LIMIT 1") + require.NoError(err) + _, err = sql.RowIterToRows(iter) + require.NoError(err) + err = iter.Close() + require.NoError(err) + + require.Equal(0, len(ctx.Session.Warnings())) +} + +func TestDescribe(t *testing.T) { + e := newEngine(t) + + ep := 
newEngineWithParallelism(t, 2) + + query := `DESCRIBE FORMAT=TREE SELECT * FROM mytable` + expectedSeq := []sql.Row{ + sql.NewRow("Table(mytable): Projected "), + sql.NewRow(" ├─ Column(i, INT64, nullable=false)"), + sql.NewRow(" └─ Column(s, TEXT, nullable=false)"), + } + + expectedParallel := []sql.Row{ + {"Exchange(parallelism=2)"}, + {" └─ Table(mytable): Projected "}, + {" ├─ Column(i, INT64, nullable=false)"}, + {" └─ Column(s, TEXT, nullable=false)"}, + } + + t.Run("sequential", func(t *testing.T) { + testQuery(t, e, query, expectedSeq) + }) + + t.Run("parallel", func(t *testing.T) { + testQuery(t, ep, query, expectedParallel) + }) +} + +func TestOrderByColumns(t *testing.T) { + require := require.New(t) + e := newEngine(t) + + _, iter, err := e.Query(newCtx(), "SELECT s, i FROM mytable ORDER BY 2 DESC") + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {"third row", int64(3)}, + {"second row", int64(2)}, + {"first row", int64(1)}, + }, + rows, + ) +} + +func TestInsertInto(t *testing.T) { + var insertions = []struct { + insertQuery string + expectedInsert []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "INSERT INTO mytable (s, i) VALUES ('x', 999);", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "INSERT INTO mytable SET s = 'x', i = 999;", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "INSERT INTO mytable VALUES (999, 'x');", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "INSERT INTO mytable SET i = 999, s = 'x';", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + `INSERT INTO typestable VALUES ( + 999, 127, 32767, 2147483647, 9223372036854775807, + 255, 65535, 4294967295, 18446744073709551615, + 
3.40282346638528859811704183484516925440e+38, 1.797693134862315708145274237317043567981e+308, + '2132-04-05 12:51:36', '2231-11-07', + 'random text', true, '{"key":"value"}', 'blobdata' + );`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), + float64(math.MaxFloat32), float64(math.MaxFloat64), + timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), + "random text", true, `{"key":"value"}`, "blobdata", + }}, + }, + { + `INSERT INTO typestable SET + id = 999, i8 = 127, i16 = 32767, i32 = 2147483647, i64 = 9223372036854775807, + u8 = 255, u16 = 65535, u32 = 4294967295, u64 = 18446744073709551615, + f32 = 3.40282346638528859811704183484516925440e+38, f64 = 1.797693134862315708145274237317043567981e+308, + ti = '2132-04-05 12:51:36', da = '2231-11-07', + te = 'random text', bo = true, js = '{"key":"value"}', bl = 'blobdata' + ;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), + float64(math.MaxFloat32), float64(math.MaxFloat64), + timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), + "random text", true, `{"key":"value"}`, "blobdata", + }}, + }, + { + `INSERT INTO typestable VALUES ( + 999, -128, -32768, -2147483648, -9223372036854775808, + 0, 0, 0, 0, + 1.401298464324817070923729583289916131280e-45, 4.940656458412465441765687928682213723651e-324, + '0010-04-05 12:51:36', '0101-11-07', + '', false, '', '' + );`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(-math.MaxInt8 
- 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), + float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), + timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), + "", false, ``, "", + }}, + }, + { + `INSERT INTO typestable SET + id = 999, i8 = -128, i16 = -32768, i32 = -2147483648, i64 = -9223372036854775808, + u8 = 0, u16 = 0, u32 = 0, u64 = 0, + f32 = 1.401298464324817070923729583289916131280e-45, f64 = 4.940656458412465441765687928682213723651e-324, + ti = '0010-04-05 12:51:36', da = '0101-11-07', + te = '', bo = false, js = '', bl = '' + ;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), + float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), + timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), + "", false, ``, "", + }}, + }, + { + `INSERT INTO typestable VALUES (999, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null);`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + `INSERT INTO typestable SET id=999, i8=null, i16=null, i32=null, i64=null, u8=null, u16=null, u32=null, u64=null, + f32=null, f64=null, ti=null, da=null, te=null, bo=null, js=null, bl=null;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + } + + for _, insertion := range insertions { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, 
insertion.insertQuery, insertion.expectedInsert) + testQueryWithContext(ctx, t, e, insertion.selectQuery, insertion.expectedSelect) + } +} + +func TestInsertIntoErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "too few values", + "INSERT INTO mytable (s, i) VALUES ('x');", + }, + { + "too many values one column", + "INSERT INTO mytable (s) VALUES ('x', 999);", + }, + { + "too many values two columns", + "INSERT INTO mytable (i, s) VALUES (999, 'x', 'y');", + }, + { + "too few values no columns specified", + "INSERT INTO mytable VALUES (999);", + }, + { + "too many values no columns specified", + "INSERT INTO mytable VALUES (999, 'x', 'y');", + }, + { + "non-existent column values", + "INSERT INTO mytable (i, s, z) VALUES (999, 'x', 999);", + }, + { + "non-existent column set", + "INSERT INTO mytable SET i = 999, s = 'x', z = 999;", + }, + { + "duplicate column", + "INSERT INTO mytable (i, s, s) VALUES (999, 'x', 'x');", + }, + { + "duplicate column set", + "INSERT INTO mytable SET i = 999, s = 'y', s = 'y';", + }, + { + "null given to non-nullable", + "INSERT INTO mytable (i, s) VALUES (null, 'y');", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + +func TestReplaceInto(t *testing.T) { + var insertions = []struct { + replaceQuery string + expectedReplace []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "REPLACE INTO mytable VALUES (1, 'first row');", + []sql.Row{{int64(2)}}, + "SELECT s FROM mytable WHERE i = 1;", + []sql.Row{{"first row"}}, + }, + { + "REPLACE INTO mytable SET i = 1, s = 'first row';", + []sql.Row{{int64(2)}}, + "SELECT s FROM mytable WHERE i = 1;", + []sql.Row{{"first row"}}, + }, + { + "REPLACE INTO mytable VALUES (1, 'new row same i');", + []sql.Row{{int64(1)}}, + "SELECT s FROM mytable WHERE i = 1;", + 
[]sql.Row{{"first row"}, {"new row same i"}}, + }, + { + "REPLACE INTO mytable (s, i) VALUES ('x', 999);", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "REPLACE INTO mytable SET s = 'x', i = 999;", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "REPLACE INTO mytable VALUES (999, 'x');", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + "REPLACE INTO mytable SET i = 999, s = 'x';", + []sql.Row{{int64(1)}}, + "SELECT i FROM mytable WHERE s = 'x';", + []sql.Row{{int64(999)}}, + }, + { + `REPLACE INTO typestable VALUES ( + 999, 127, 32767, 2147483647, 9223372036854775807, + 255, 65535, 4294967295, 18446744073709551615, + 3.40282346638528859811704183484516925440e+38, 1.797693134862315708145274237317043567981e+308, + '2132-04-05 12:51:36', '2231-11-07', + 'random text', true, '{"key":"value"}', 'blobdata' + );`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), + float64(math.MaxFloat32), float64(math.MaxFloat64), + timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), + "random text", true, `{"key":"value"}`, "blobdata", + }}, + }, + { + `REPLACE INTO typestable SET + id = 999, i8 = 127, i16 = 32767, i32 = 2147483647, i64 = 9223372036854775807, + u8 = 255, u16 = 65535, u32 = 4294967295, u64 = 18446744073709551615, + f32 = 3.40282346638528859811704183484516925440e+38, f64 = 1.797693134862315708145274237317043567981e+308, + ti = '2132-04-05 12:51:36', da = '2231-11-07', + te = 'random text', bo = true, js = '{"key":"value"}', bl = 'blobdata' + ;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + 
int64(999), int8(math.MaxInt8), int16(math.MaxInt16), int32(math.MaxInt32), int64(math.MaxInt64), + uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64), + float64(math.MaxFloat32), float64(math.MaxFloat64), + timeParse(sql.TimestampLayout, "2132-04-05 12:51:36"), timeParse(sql.DateLayout, "2231-11-07"), + "random text", true, `{"key":"value"}`, "blobdata", + }}, + }, + { + `REPLACE INTO typestable VALUES ( + 999, -128, -32768, -2147483648, -9223372036854775808, + 0, 0, 0, 0, + 1.401298464324817070923729583289916131280e-45, 4.940656458412465441765687928682213723651e-324, + '0010-04-05 12:51:36', '0101-11-07', + '', false, '', '' + );`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), + float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), + timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), + "", false, ``, "", + }}, + }, + { + `REPLACE INTO typestable SET + id = 999, i8 = -128, i16 = -32768, i32 = -2147483648, i64 = -9223372036854775808, + u8 = 0, u16 = 0, u32 = 0, u64 = 0, + f32 = 1.401298464324817070923729583289916131280e-45, f64 = 4.940656458412465441765687928682213723651e-324, + ti = '0010-04-05 12:51:36', da = '0101-11-07', + te = '', bo = false, js = '', bl = '' + ;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{ + int64(999), int8(-math.MaxInt8 - 1), int16(-math.MaxInt16 - 1), int32(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1), + uint8(0), uint16(0), uint32(0), uint64(0), + float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64), + timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"), + "", false, ``, "", + }}, + }, + { + `REPLACE INTO typestable VALUES 
(999, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null);`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + `REPLACE INTO typestable SET id=999, i8=null, i16=null, i32=null, i64=null, u8=null, u16=null, u32=null, u64=null, + f32=null, f64=null, ti=null, da=null, te=null, bo=null, js=null, bl=null;`, + []sql.Row{{int64(1)}}, + "SELECT * FROM typestable WHERE id = 999;", + []sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + } + + for _, insertion := range insertions { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, insertion.replaceQuery, insertion.expectedReplace) + testQueryWithContext(ctx, t, e, insertion.selectQuery, insertion.expectedSelect) + } +} + +func TestReplaceIntoErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "too few values", + "REPLACE INTO mytable (s, i) VALUES ('x');", + }, + { + "too many values one column", + "REPLACE INTO mytable (s) VALUES ('x', 999);", + }, + { + "too many values two columns", + "REPLACE INTO mytable (i, s) VALUES (999, 'x', 'y');", + }, + { + "too few values no columns specified", + "REPLACE INTO mytable VALUES (999);", + }, + { + "too many values no columns specified", + "REPLACE INTO mytable VALUES (999, 'x', 'y');", + }, + { + "non-existent column values", + "REPLACE INTO mytable (i, s, z) VALUES (999, 'x', 999);", + }, + { + "non-existent column set", + "REPLACE INTO mytable SET i = 999, s = 'x', z = 999;", + }, + { + "duplicate column values", + "REPLACE INTO mytable (i, s, s) VALUES (999, 'x', 'x');", + }, + { + "duplicate column set", + "REPLACE INTO mytable SET i = 999, s = 'y', s = 'y';", + }, + { + "null given to non-nullable values", + "INSERT INTO mytable (i, s) VALUES (null, 'y');", + }, + { + 
"null given to non-nullable set", + "INSERT INTO mytable SET i = null, s = 'y';", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + +func TestUpdate(t *testing.T) { + var updates = []struct { + updateQuery string + expectedUpdate []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i > 9999;", + []sql.Row{{int64(0), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' WHERE i <> 9999;", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM floattable WHERE i = 2;", + []sql.Row{{int64(2), float32(3.0), float64(4.5)}}, + }, + { + "UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT f32 FROM floattable WHERE i = 1;", + []sql.Row{{float32(4.0)}}, + }, + { + "UPDATE mytable SET s = 'first row' WHERE i = 1;", + []sql.Row{{int64(1), int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "UPDATE niltable SET b = NULL WHERE f IS NULL;", + []sql.Row{{int64(2), int64(1)}}, + "SELECT * FROM 
niltable WHERE f IS NULL;", + []sql.Row{{int64(4), nil, nil}, {nil, nil, nil}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;", + []sql.Row{{int64(2), int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + { + "UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;", + []sql.Row{{int64(1), int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}}, + }, + { + "UPDATE mytable SET s = 'updated';", + []sql.Row{{int64(3), int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}}, + }, + } + + for _, update := range updates { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate) + testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect) + } +} + +func TestUpdateErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "invalid table", + "UPDATE doesnotexist SET i = 0;", + }, + { + "invalid column set", + "UPDATE mytable SET z = 0;", + }, + { + "invalid column set value", + "UPDATE mytable SET i = z;", + }, + { + "invalid column where", + "UPDATE mytable SET s = 'hi' WHERE z = 1;", + }, + { + "invalid column order by", + "UPDATE mytable SET s = 'hi' ORDER BY z;", + }, + { + "negative limit", + "UPDATE mytable SET s = 'hi' LIMIT -1;", + }, + { + "negative offset", + "UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;", + }, + { + "set null on non-nullable", + "UPDATE mytable SET s = NULL;", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, 
err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + +const testNumPartitions = 5 + +func TestAmbiguousColumnResolution(t *testing.T) { + require := require.New(t) + + table := memory.NewPartitionedTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + {Name: "b", Type: sql.Text, Source: "foo"}, + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(1), "foo"), + sql.NewRow(int64(2), "bar"), + sql.NewRow(int64(3), "baz"), + ) + + table2 := memory.NewPartitionedTable("bar", sql.Schema{ + {Name: "b", Type: sql.Text, Source: "bar"}, + {Name: "c", Type: sql.Int64, Source: "bar"}, + }, testNumPartitions) + insertRows( + t, table2, + sql.NewRow("qux", int64(3)), + sql.NewRow("mux", int64(2)), + sql.NewRow("pux", int64(1)), + ) + + db := memory.NewDatabase("mydb") + db.AddTable("foo", table) + db.AddTable("bar", table2) + + e := sqle.NewDefault() + e.AddDatabase(db) + + q := `SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c` + ctx := newCtx() + + _, rows, err := e.Query(ctx, q) + require.NoError(err) + + var rs [][]interface{} + for { + row, err := rows.Next() + if err == io.EOF { + break + } + require.NoError(err) + + rs = append(rs, row) + } + + expected := [][]interface{}{ + {int64(1), "pux", "foo"}, + {int64(2), "mux", "bar"}, + {int64(3), "qux", "baz"}, + } + + require.Equal(expected, rs) +} + +func TestCreateTable(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + testQuery(t, e, + "CREATE TABLE t1(a INTEGER, b TEXT, c DATE, "+ + "d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL, "+ + "b1 BOOL, b2 BOOLEAN NOT NULL, g DATETIME, h CHAR(40))", + []sql.Row(nil), + ) + + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok := db.Tables()["t1"] + require.True(ok) + + s := sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: true, Source: "t1"}, + {Name: "b", Type: sql.Text, Nullable: true, Source: "t1"}, {Name: "c", Type: sql.Date, Nullable: 
true, Source: "t1"}, {Name: "d", Type: sql.Timestamp, Nullable: true, Source: "t1"}, {Name: "e", Type: sql.Text, Nullable: true, Source: "t1"}, {Name: "f", Type: sql.Blob, Source: "t1"}, + {Name: "b1", Type: sql.Uint8, Nullable: true, Source: "t1"}, + {Name: "b2", Type: sql.Uint8, Source: "t1"}, + {Name: "g", Type: sql.Datetime, Nullable: true, Source: "t1"}, + {Name: "h", Type: sql.Text, Nullable: true, Source: "t1"}, + } + + require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t2 (a INTEGER NOT NULL PRIMARY KEY, "+ + "b VARCHAR(10) NOT NULL)", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t2"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t2"}, + {Name: "b", Type: sql.Text, Nullable: false, Source: "t2"}, + } + + require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t3(a INTEGER NOT NULL,"+ + "b TEXT NOT NULL,"+ + "c bool, primary key (a,b))", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t3"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "b", Type: sql.Text, Nullable: false, PrimaryKey: true, Source: "t3"}, + {Name: "c", Type: sql.Uint8, Nullable: true, Source: "t3"}, } - require.Equal(s, testTable.Schema()) + require.Equal(s, testTable.Schema()) +} + +func TestDropTable(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + _, ok := db.Tables()["mytable"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE IF EXISTS mytable, not_exist", + []sql.Row(nil), + ) + + _, ok = db.Tables()["mytable"] + require.False(ok) + + _, ok = db.Tables()["othertable"] + require.True(ok) + _, ok = db.Tables()["tabletest"] + require.True(ok) + + testQuery(t, e, + "DROP TABLE 
IF EXISTS othertable, tabletest", + []sql.Row(nil), + ) + + _, ok = db.Tables()["othertable"] + require.False(ok) + _, ok = db.Tables()["tabletest"] + require.False(ok) + + _, _, err = e.Query(newCtx(), "DROP TABLE not_exist") + require.Error(err) +} + +func TestNaturalJoin(t *testing.T) { + require := require.New(t) + + t1 := memory.NewPartitionedTable("t1", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "t1"}, + {Name: "b", Type: sql.Text, Source: "t1"}, + {Name: "c", Type: sql.Text, Source: "t1"}, + }, testNumPartitions) + + insertRows( + t, t1, + sql.NewRow("a_1", "b_1", "c_1"), + sql.NewRow("a_2", "b_2", "c_2"), + sql.NewRow("a_3", "b_3", "c_3"), + ) + + t2 := memory.NewPartitionedTable("t2", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "t2"}, + {Name: "b", Type: sql.Text, Source: "t2"}, + {Name: "d", Type: sql.Text, Source: "t2"}, + }, testNumPartitions) + + insertRows( + t, t2, + sql.NewRow("a_1", "b_1", "d_1"), + sql.NewRow("a_2", "b_2", "d_2"), + sql.NewRow("a_3", "b_3", "d_3"), + ) + + db := memory.NewDatabase("mydb") + db.AddTable("t1", t1) + db.AddTable("t2", t2) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query(newCtx(), `SELECT * FROM t1 NATURAL JOIN t2`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {"a_1", "b_1", "c_1", "d_1"}, + {"a_2", "b_2", "c_2", "d_2"}, + {"a_3", "b_3", "c_3", "d_3"}, + }, + rows, + ) +} + +func TestNaturalJoinEqual(t *testing.T) { + require := require.New(t) + + t1 := memory.NewPartitionedTable("t1", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "t1"}, + {Name: "b", Type: sql.Text, Source: "t1"}, + {Name: "c", Type: sql.Text, Source: "t1"}, + }, testNumPartitions) + + insertRows( + t, t1, + sql.NewRow("a_1", "b_1", "c_1"), + sql.NewRow("a_2", "b_2", "c_2"), + sql.NewRow("a_3", "b_3", "c_3"), + ) + + t2 := memory.NewPartitionedTable("t2", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "t2"}, + {Name: "b", Type: 
sql.Text, Source: "t2"}, + {Name: "c", Type: sql.Text, Source: "t2"}, + }, testNumPartitions) + + insertRows( + t, t2, + sql.NewRow("a_1", "b_1", "c_1"), + sql.NewRow("a_2", "b_2", "c_2"), + sql.NewRow("a_3", "b_3", "c_3"), + ) + + db := memory.NewDatabase("mydb") + db.AddTable("t1", t1) + db.AddTable("t2", t2) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query(newCtx(), `SELECT * FROM t1 NATURAL JOIN t2`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {"a_1", "b_1", "c_1"}, + {"a_2", "b_2", "c_2"}, + {"a_3", "b_3", "c_3"}, + }, + rows, + ) +} + +func TestNaturalJoinDisjoint(t *testing.T) { + require := require.New(t) + + t1 := memory.NewPartitionedTable("t1", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "t1"}, + }, testNumPartitions) + + insertRows( + t, t1, + sql.NewRow("a1"), + sql.NewRow("a2"), + sql.NewRow("a3"), + ) + + t2 := memory.NewPartitionedTable("t2", sql.Schema{ + {Name: "b", Type: sql.Text, Source: "t2"}, + }, testNumPartitions) + insertRows( + t, t2, + sql.NewRow("b1"), + sql.NewRow("b2"), + sql.NewRow("b3"), + ) + + db := memory.NewDatabase("mydb") + db.AddTable("t1", t1) + db.AddTable("t2", t2) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query(newCtx(), `SELECT * FROM t1 NATURAL JOIN t2`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {"a1", "b1"}, + {"a1", "b2"}, + {"a1", "b3"}, + {"a2", "b1"}, + {"a2", "b2"}, + {"a2", "b3"}, + {"a3", "b1"}, + {"a3", "b2"}, + {"a3", "b3"}, + }, + rows, + ) +} + +func TestInnerNestedInNaturalJoins(t *testing.T) { + require := require.New(t) + + table1 := memory.NewPartitionedTable("table1", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "table1"}, + {Name: "f", Type: sql.Float64, Source: "table1"}, + {Name: "t", Type: sql.Text, Source: "table1"}, + }, testNumPartitions) + + insertRows( + t, table1, + 
sql.NewRow(int32(1), float64(2.1), "table1"), + sql.NewRow(int32(1), float64(2.1), "table1"), + sql.NewRow(int32(10), float64(2.1), "table1"), + ) + + table2 := memory.NewPartitionedTable("table2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "table2"}, + {Name: "f2", Type: sql.Float64, Source: "table2"}, + {Name: "t2", Type: sql.Text, Source: "table2"}, + }, testNumPartitions) + + insertRows( + t, table2, + sql.NewRow(int32(1), float64(2.2), "table2"), + sql.NewRow(int32(1), float64(2.2), "table2"), + sql.NewRow(int32(20), float64(2.2), "table2"), + ) + + table3 := memory.NewPartitionedTable("table3", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "table3"}, + {Name: "f2", Type: sql.Float64, Source: "table3"}, + {Name: "t3", Type: sql.Text, Source: "table3"}, + }, testNumPartitions) + + insertRows( + t, table3, + sql.NewRow(int32(1), float64(2.2), "table3"), + sql.NewRow(int32(2), float64(2.2), "table3"), + sql.NewRow(int32(30), float64(2.2), "table3"), + ) + + db := memory.NewDatabase("mydb") + db.AddTable("table1", table1) + db.AddTable("table2", table2) + db.AddTable("table3", table3) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query(newCtx(), `SELECT * FROM table1 INNER JOIN table2 ON table1.i = table2.i2 NATURAL JOIN table3`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal( + []sql.Row{ + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + {int32(1), float64(2.2), float64(2.1), "table1", int32(1), "table2", "table3"}, + }, + rows, + ) +} + +func testQuery(t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { + testQueryWithContext(newCtx(), t, e, q, expected) } -func testQuery(t *testing.T, e *sqle.Engine, q string, r []sql.Row) { +func testQueryWithContext(ctx 
*sql.Context, t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { + orderBy := strings.Contains(strings.ToUpper(q), " ORDER BY ") + t.Run(q, func(t *testing.T) { require := require.New(t) - session := sql.NewEmptyContext() - - _, rows, err := e.Query(session, q) + _, iter, err := e.Query(ctx, q) require.NoError(err) - var rs []sql.Row - for { - row, err := rows.Next() - if err == io.EOF { - break - } - require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) - rs = append(rs, row) + if orderBy { + require.Equal(expected, rows) + } else { + require.ElementsMatch(expected, rows) } - - require.ElementsMatch(r, rs) }) } func newEngine(t *testing.T) *sqle.Engine { - require := require.New(t) + return newEngineWithParallelism(t, 1) +} - table := mem.NewTable("mytable", sql.Schema{ +func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { + table := memory.NewPartitionedTable("mytable", sql.Schema{ {Name: "i", Type: sql.Int64, Source: "mytable"}, {Name: "s", Type: sql.Text, Source: "mytable"}, - }) - require.Nil(table.Insert(sql.NewRow(int64(1), "first row"))) - require.Nil(table.Insert(sql.NewRow(int64(2), "second row"))) - require.Nil(table.Insert(sql.NewRow(int64(3), "third row"))) + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(1), "first row"), + sql.NewRow(int64(2), "second row"), + sql.NewRow(int64(3), "third row"), + ) - table2 := mem.NewTable("othertable", sql.Schema{ + table2 := memory.NewPartitionedTable("othertable", sql.Schema{ {Name: "s2", Type: sql.Text, Source: "othertable"}, {Name: "i2", Type: sql.Int64, Source: "othertable"}, - }) - require.Nil(table2.Insert(sql.NewRow("first", int64(3)))) - require.Nil(table2.Insert(sql.NewRow("second", int64(2)))) - require.Nil(table2.Insert(sql.NewRow("third", int64(1)))) + }, testNumPartitions) + + insertRows( + t, table2, + sql.NewRow("first", int64(3)), + sql.NewRow("second", int64(2)), + sql.NewRow("third", int64(1)), + ) + + table3 := 
memory.NewPartitionedTable("tabletest", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "tabletest"}, + {Name: "s", Type: sql.Text, Source: "tabletest"}, + }, testNumPartitions) - table3 := mem.NewTable("tabletest", sql.Schema{ + insertRows( + t, table3, + sql.NewRow(int64(1), "first row"), + sql.NewRow(int64(2), "second row"), + sql.NewRow(int64(3), "third row"), + ) + + table4 := memory.NewPartitionedTable("other_table", sql.Schema{ {Name: "text", Type: sql.Text, Source: "tabletest"}, {Name: "number", Type: sql.Int32, Source: "tabletest"}, - }) - require.Nil(table3.Insert(sql.NewRow("a", int32(1)))) - require.Nil(table3.Insert(sql.NewRow("b", int32(2)))) - require.Nil(table3.Insert(sql.NewRow("c", int32(3)))) + }, testNumPartitions) + + insertRows( + t, table4, + sql.NewRow("a", int32(4)), + sql.NewRow("b", int32(2)), + sql.NewRow("c", int32(0)), + ) - db := mem.NewDatabase("mydb") - db.AddTable(table.Name(), table) - db.AddTable(table2.Name(), table2) - db.AddTable(table3.Name(), table3) + bigtable := memory.NewPartitionedTable("bigtable", sql.Schema{ + {Name: "t", Type: sql.Text, Source: "bigtable"}, + {Name: "n", Type: sql.Int64, Source: "bigtable"}, + }, testNumPartitions) - e := sqle.New() - e.AddDatabase(db) + insertRows( + t, bigtable, + sql.NewRow("a", int64(1)), + sql.NewRow("s", int64(2)), + sql.NewRow("f", int64(3)), + sql.NewRow("g", int64(1)), + sql.NewRow("h", int64(2)), + sql.NewRow("j", int64(3)), + sql.NewRow("k", int64(1)), + sql.NewRow("l", int64(2)), + sql.NewRow("ñ", int64(4)), + sql.NewRow("z", int64(5)), + sql.NewRow("x", int64(6)), + sql.NewRow("c", int64(7)), + sql.NewRow("v", int64(8)), + sql.NewRow("b", int64(9)), + ) + + floatTable := memory.NewPartitionedTable("floattable", sql.Schema{ + {Name: "i", Type: sql.Int64, Source: "floattable"}, + {Name: "f32", Type: sql.Float32, Source: "floattable"}, + {Name: "f64", Type: sql.Float64, Source: "floattable"}, + }, testNumPartitions) - return e + insertRows( + t, floatTable, + 
sql.NewRow(int64(1), float32(1.0), float64(1.0)), + sql.NewRow(int64(2), float32(1.5), float64(1.5)), + sql.NewRow(int64(3), float32(2.0), float64(2.0)), + sql.NewRow(int64(4), float32(2.5), float64(2.5)), + sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)), + sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)), + ) + + nilTable := memory.NewPartitionedTable("niltable", sql.Schema{ + {Name: "i", Type: sql.Int64, Source: "niltable", Nullable: true}, + {Name: "b", Type: sql.Boolean, Source: "niltable", Nullable: true}, + {Name: "f", Type: sql.Float64, Source: "niltable", Nullable: true}, + }, testNumPartitions) + + insertRows( + t, nilTable, + sql.NewRow(int64(1), true, float64(1.0)), + sql.NewRow(int64(2), nil, float64(2.0)), + sql.NewRow(nil, false, float64(3.0)), + sql.NewRow(int64(4), true, nil), + sql.NewRow(nil, nil, nil), + ) + + newlineTable := memory.NewPartitionedTable("newlinetable", sql.Schema{ + {Name: "i", Type: sql.Int64, Source: "newlinetable"}, + {Name: "s", Type: sql.Text, Source: "newlinetable"}, + }, testNumPartitions) + + insertRows( + t, newlineTable, + sql.NewRow(int64(1), "\nthere is some text in here"), + sql.NewRow(int64(2), "there is some\ntext in here"), + sql.NewRow(int64(3), "there is some text\nin here"), + sql.NewRow(int64(4), "there is some text in here\n"), + sql.NewRow(int64(5), "there is some text in here"), + ) + + typestable := memory.NewPartitionedTable("typestable", sql.Schema{ + {Name: "id", Type: sql.Int64, Source: "typestable"}, + {Name: "i8", Type: sql.Int8, Source: "typestable", Nullable: true}, + {Name: "i16", Type: sql.Int16, Source: "typestable", Nullable: true}, + {Name: "i32", Type: sql.Int32, Source: "typestable", Nullable: true}, + {Name: "i64", Type: sql.Int64, Source: "typestable", Nullable: true}, + {Name: "u8", Type: sql.Uint8, Source: "typestable", Nullable: true}, + {Name: "u16", Type: sql.Uint16, Source: "typestable", Nullable: true}, + {Name: "u32", Type: sql.Uint32, Source: "typestable", Nullable: true}, 
+ {Name: "u64", Type: sql.Uint64, Source: "typestable", Nullable: true}, + {Name: "f32", Type: sql.Float32, Source: "typestable", Nullable: true}, + {Name: "f64", Type: sql.Float64, Source: "typestable", Nullable: true}, + {Name: "ti", Type: sql.Timestamp, Source: "typestable", Nullable: true}, + {Name: "da", Type: sql.Date, Source: "typestable", Nullable: true}, + {Name: "te", Type: sql.Text, Source: "typestable", Nullable: true}, + {Name: "bo", Type: sql.Boolean, Source: "typestable", Nullable: true}, + {Name: "js", Type: sql.JSON, Source: "typestable", Nullable: true}, + {Name: "bl", Type: sql.Blob, Source: "typestable", Nullable: true}, + }, testNumPartitions) + + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + db.AddTable("othertable", table2) + db.AddTable("tabletest", table3) + db.AddTable("bigtable", bigtable) + db.AddTable("floattable", floatTable) + db.AddTable("niltable", nilTable) + db.AddTable("newlinetable", newlineTable) + db.AddTable("typestable", typestable) + + db2 := memory.NewDatabase("foo") + db2.AddTable("other_table", table4) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + catalog.AddDatabase(db2) + catalog.AddDatabase(sql.NewInformationSchemaDatabase(catalog)) + + var a *analyzer.Analyzer + if parallelism > 1 { + a = analyzer.NewBuilder(catalog).WithParallelism(parallelism).Build() + } else { + a = analyzer.NewDefault(catalog) + } + + return sqle.New(catalog, a, new(sqle.Config)) } -const expectedTree = `Offset(2) - └─ Limit(5) +const expectedTree = `Limit(5) + └─ Offset(2) └─ Project(t.foo, bar.baz) └─ Filter(foo > qux) └─ InnerJoin(foo = baz) @@ -392,7 +2982,7 @@ const expectedTree = `Offset(2) func TestPrintTree(t *testing.T) { require := require.New(t) - node, err := parse.Parse(sql.NewEmptyContext(), ` + node, err := parse.Parse(newCtx(), ` SELECT t.foo, bar.baz FROM tbl t INNER JOIN bar @@ -404,11 +2994,97 @@ func TestPrintTree(t *testing.T) { require.Equal(expectedTree, node.String()) } +// see: 
https://github.com/src-d/go-mysql-server/issues/197 +func TestStarPanic197(t *testing.T) { + require := require.New(t) + e := newEngine(t) + + ctx := newCtx() + _, iter, err := e.Query(ctx, `SELECT * FROM mytable GROUP BY i, s`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Len(rows, 3) +} + +func TestInvalidRegexp(t *testing.T) { + require := require.New(t) + e := newEngine(t) + + ctx := newCtx() + _, iter, err := e.Query(ctx, `SELECT * FROM mytable WHERE s REGEXP("*main.go")`) + require.NoError(err) + + _, err = sql.RowIterToRows(iter) + require.Error(err) +} + +func TestOrderByGroupBy(t *testing.T) { + require := require.New(t) + + table := memory.NewPartitionedTable("members", sql.Schema{ + {Name: "id", Type: sql.Int64, Source: "members"}, + {Name: "team", Type: sql.Text, Source: "members"}, + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(3), "red"), + sql.NewRow(int64(4), "red"), + sql.NewRow(int64(5), "orange"), + sql.NewRow(int64(6), "orange"), + sql.NewRow(int64(7), "orange"), + sql.NewRow(int64(8), "purple"), + ) + + db := memory.NewDatabase("db") + db.AddTable("members", table) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query( + newCtx(), + "SELECT team, COUNT(*) FROM members GROUP BY team ORDER BY 2", + ) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"purple", int64(1)}, + {"red", int64(2)}, + {"orange", int64(3)}, + } + + require.Equal(expected, rows) + + _, iter, err = e.Query( + newCtx(), + "SELECT team, COUNT(*) FROM members GROUP BY 1 ORDER BY 2", + ) + require.NoError(err) + + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal(expected, rows) + + _, _, err = e.Query( + newCtx(), + "SELECT team, COUNT(*) FROM members GROUP BY team ORDER BY columndoesnotexist", + ) + require.Error(err) +} + func TestTracing(t *testing.T) { require := require.New(t) e := 
newEngine(t) - tracer := new(memTracer) + tracer := new(test.MemTracer) ctx := sql.NewContext(context.TODO(), sql.WithTracer(tracer)) @@ -423,17 +3099,13 @@ func TestTracing(t *testing.T) { require.Len(rows, 1) require.NoError(err) - spans := tracer.spans - + spans := tracer.Spans var expectedSpans = []string{ "plan.Limit", + "plan.Sort", "plan.Distinct", "plan.Project", - "plan.Sort", - "plan.Filter", - "expression.Equals", - "expression.Equals", - "expression.Equals", + "plan.ResolvedTable", } var spanOperations []string @@ -450,40 +3122,452 @@ func TestTracing(t *testing.T) { require.Equal(expectedSpans, spanOperations) } -type memTracer struct { - spans []string +func TestReadOnly(t *testing.T) { + require := require.New(t) + + table := memory.NewPartitionedTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int64, Source: "mytable"}, + {Name: "s", Type: sql.Text, Source: "mytable"}, + }, testNumPartitions) + + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + au := auth.NewNativeSingle("user", "pass", auth.ReadPerm) + cfg := &sqle.Config{Auth: au} + a := analyzer.NewBuilder(catalog).Build() + e := sqle.New(catalog, a, cfg) + + _, _, err := e.Query(newCtx(), `SELECT i FROM mytable`) + require.NoError(err) + + _, _, err = e.Query(newCtx(), `CREATE INDEX foo ON mytable USING pilosa (i, s)`) + require.Error(err) + require.True(auth.ErrNotAuthorized.Is(err)) + + _, _, err = e.Query(newCtx(), `DROP INDEX foo ON mytable`) + require.Error(err) + require.True(auth.ErrNotAuthorized.Is(err)) + + _, _, err = e.Query(newCtx(), `INSERT INTO mytable (i, s) VALUES(42, 'yolo')`) + require.Error(err) + require.True(auth.ErrNotAuthorized.Is(err)) +} + +func TestSessionVariables(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + + session := sql.NewBaseSession() + ctx := sql.NewContext(context.Background(), sql.WithSession(session), sql.WithPid(1)) + + _, _, err := e.Query(ctx, `set 
autocommit=1, sql_mode = concat(@@sql_mode,',STRICT_TRANS_TABLES')`) + require.NoError(err) + + ctx = sql.NewContext(context.Background(), sql.WithSession(session), sql.WithPid(2)) + + _, iter, err := e.Query(ctx, `SELECT @@autocommit, @@session.sql_mode`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{{int8(1), ",STRICT_TRANS_TABLES"}}, rows) +} + +func TestSessionVariablesONOFF(t *testing.T) { + require := require.New(t) + + e := newEngine(t) + + session := sql.NewBaseSession() + ctx := sql.NewContext(context.Background(), sql.WithSession(session), sql.WithPid(1)) + + _, _, err := e.Query(ctx, `set autocommit=ON, sql_mode = OFF, autoformat="true"`) + require.NoError(err) + + ctx = sql.NewContext(context.Background(), sql.WithSession(session), sql.WithPid(2)) + + _, iter, err := e.Query(ctx, `SELECT @@autocommit, @@session.sql_mode, @@autoformat`) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{{int64(1), int64(0), true}}, rows) +} + +func TestNestedAliases(t *testing.T) { + require := require.New(t) + + _, _, err := newEngine(t).Query(newCtx(), ` + SELECT SUBSTRING(s, 1, 10) AS sub_s, SUBSTRING(sub_s, 2, 3) as sub_sub_s + FROM mytable`) + require.Error(err) + require.True(analyzer.ErrMisusedAlias.Is(err)) +} + +func TestUse(t *testing.T) { + require := require.New(t) + e := newEngine(t) + + _, _, err := e.Query(newCtx(), "USE bar") + require.Error(err) + + _, iter, err := e.Query(newCtx(), "USE foo") + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.Len(rows, 0) + + require.Equal("foo", e.Catalog.CurrentDatabase()) +} + +func TestLocks(t *testing.T) { + require := require.New(t) + + t1 := newLockableTable(memory.NewTable("t1", nil)) + t2 := newLockableTable(memory.NewTable("t2", nil)) + t3 := memory.NewTable("t3", nil) + catalog := sql.NewCatalog() + db := memory.NewDatabase("db") + 
db.AddTable("t1", t1) + db.AddTable("t2", t2) + db.AddTable("t3", t3) + catalog.AddDatabase(db) + + analyzer := analyzer.NewDefault(catalog) + engine := sqle.New(catalog, analyzer, new(sqle.Config)) + + _, iter, err := engine.Query(newCtx(), "LOCK TABLES t1 READ, t2 WRITE, t3 READ") + require.NoError(err) + + _, err = sql.RowIterToRows(iter) + require.NoError(err) + + _, iter, err = engine.Query(newCtx(), "UNLOCK TABLES") + require.NoError(err) + + _, err = sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal(1, t1.readLocks) + require.Equal(0, t1.writeLocks) + require.Equal(1, t1.unlocks) + require.Equal(0, t2.readLocks) + require.Equal(1, t2.writeLocks) + require.Equal(1, t2.unlocks) +} + +func TestDescribeNoPruneColumns(t *testing.T) { + require := require.New(t) + ctx := newCtx() + e := newEngine(t) + query := `DESCRIBE FORMAT=TREE SELECT SUBSTRING(s, 1, 1) as foo, s, i FROM mytable WHERE foo = 'f'` + parsed, err := parse.Parse(ctx, query) + require.NoError(err) + result, err := e.Analyzer.Analyze(ctx, parsed) + require.NoError(err) + + qp, ok := result.(*plan.QueryProcess) + require.True(ok) + + d, ok := qp.Child.(*plan.DescribeQuery) + require.True(ok) + + p, ok := d.Child.(*plan.Project) + require.True(ok) + + require.Len(p.Schema(), 3) +} + +func TestDeleteFrom(t *testing.T) { + var deletions = []struct { + deleteQuery string + expectedDelete []sql.Row + selectQuery string + expectedSelect []sql.Row + }{ + { + "DELETE FROM mytable;", + []sql.Row{{int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{}, + }, + { + "DELETE FROM mytable WHERE i = 2;", + []sql.Row{{int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(3), "third row"}}, + }, + { + "DELETE FROM mytable WHERE i < 3;", + []sql.Row{{int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(3), "third row"}}, + }, + { + "DELETE FROM mytable WHERE i > 1;", + []sql.Row{{int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}}, + }, + { + 
"DELETE FROM mytable WHERE i <= 2;", + []sql.Row{{int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(3), "third row"}}, + }, + { + "DELETE FROM mytable WHERE i >= 2;", + []sql.Row{{int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}}, + }, + { + "DELETE FROM mytable WHERE s = 'first row';", + []sql.Row{{int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "DELETE FROM mytable WHERE s <> 'dne';", + []sql.Row{{int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{}, + }, + { + "DELETE FROM mytable WHERE s LIKE '%row';", + []sql.Row{{int64(3)}}, + "SELECT * FROM mytable;", + []sql.Row{}, + }, + { + "DELETE FROM mytable WHERE s = 'dne';", + []sql.Row{{int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "DELETE FROM mytable WHERE i = 'invalid';", + []sql.Row{{int64(0)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}}, + }, + { + "DELETE FROM mytable ORDER BY i ASC LIMIT 2;", + []sql.Row{{int64(2)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(3), "third row"}}, + }, + { + "DELETE FROM mytable ORDER BY i DESC LIMIT 1;", + []sql.Row{{int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(2), "second row"}}, + }, + { + "DELETE FROM mytable ORDER BY i DESC LIMIT 1 OFFSET 1;", + []sql.Row{{int64(1)}}, + "SELECT * FROM mytable;", + []sql.Row{{int64(1), "first row"}, {int64(3), "third row"}}, + }, + } + + for _, deletion := range deletions { + e := newEngine(t) + ctx := newCtx() + testQueryWithContext(ctx, t, e, deletion.deleteQuery, deletion.expectedDelete) + testQueryWithContext(ctx, t, e, deletion.selectQuery, deletion.expectedSelect) + } +} + +func TestDeleteFromErrors(t *testing.T) { + var expectedFailures = []struct { + name string + query string + }{ + { + "invalid table", + "DELETE FROM 
invalidtable WHERE x < 1;", + }, + { + "invalid column", + "DELETE FROM mytable WHERE z = 'dne';", + }, + { + "negative limit", + "DELETE FROM mytable LIMIT -1;", + }, + { + "negative offset", + "DELETE FROM mytable LIMIT 1 OFFSET -1;", + }, + { + "missing keyword from", + "DELETE mytable WHERE id = 1;", + }, + } + + for _, expectedFailure := range expectedFailures { + t.Run(expectedFailure.name, func(t *testing.T) { + _, _, err := newEngine(t).Query(newCtx(), expectedFailure.query) + require.Error(t, err) + }) + } +} + +type mockSpan struct { + opentracing.Span + finished bool +} + +func (m *mockSpan) Finish() { + m.finished = true +} + +func TestRootSpanFinish(t *testing.T) { + e := newEngine(t) + fakeSpan := &mockSpan{Span: opentracing.NoopTracer{}.StartSpan("")} + ctx := sql.NewContext( + context.Background(), + sql.WithRootSpan(fakeSpan), + ) + + _, iter, err := e.Query(ctx, "SELECT 1") + require.NoError(t, err) + + _, err = sql.RowIterToRows(iter) + require.NoError(t, err) + + require.True(t, fakeSpan.finished) +} + +var generatorQueries = []struct { + query string + expected []sql.Row +}{ + { + `SELECT a, EXPLODE(b), c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t`, + []sql.Row{ + {int64(1), "a", "first"}, + {int64(1), "b", "first"}, + {int64(2), "c", "second"}, + {int64(2), "d", "second"}, + {int64(3), "e", "third"}, + {int64(3), "f", "third"}, + }, + }, + { + `SELECT EXPLODE(SPLIT(c, "")) FROM t LIMIT 5`, + []sql.Row{ + {"f"}, + {"i"}, + {"r"}, + {"s"}, + {"t"}, + }, + }, + { + `SELECT a, EXPLODE(b) AS x, c FROM t WHERE x = 'e'`, + []sql.Row{ + {int64(3), "e", "third"}, + }, + }, +} + +func TestGenerators(t *testing.T) { + table := memory.NewPartitionedTable("t", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t"}, + {Name: "b", Type: sql.Array(sql.Text), 
Source: "t"}, + {Name: "c", Type: sql.Text, Source: "t"}, + }, testNumPartitions) + + insertRows( + t, table, + sql.NewRow(int64(1), []interface{}{"a", "b"}, "first"), + sql.NewRow(int64(2), []interface{}{"c", "d"}, "second"), + sql.NewRow(int64(3), []interface{}{"e", "f"}, "third"), + ) + + db := memory.NewDatabase("db") + db.AddTable("t", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config)) + + for _, q := range generatorQueries { + testQuery(t, e, q.query, q.expected) + } +} + +func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) { + t.Helper() + + for _, r := range rows { + require.NoError(t, table.Insert(newCtx(), r)) + } +} + +var pid uint64 + +func newCtx() *sql.Context { + session := sql.NewSession("address", "client", "user", 1) + return sql.NewContext( + context.Background(), + sql.WithPid(atomic.AddUint64(&pid, 1)), + sql.WithSession(session), + ) } -type memSpan struct { - opName string +type lockableTable struct { + sql.Table + readLocks int + writeLocks int + unlocks int } -func (t *memTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span { - t.spans = append(t.spans, operationName) - return &memSpan{operationName} +func newLockableTable(t sql.Table) *lockableTable { + return &lockableTable{Table: t} } -func (t *memTracer) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error { - panic("not implemented") +func timeParse(layout string, value string) time.Time { + t, err := time.Parse(layout, value) + if err != nil { + panic(err) + } + return t } -func (t *memTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) { - panic("not implemented") +var _ sql.Lockable = (*lockableTable)(nil) + +func (l *lockableTable) Lock(ctx *sql.Context, write bool) error { + if write { + l.writeLocks++ + } else { + l.readLocks++ + } + return nil } -func (m memSpan) Context() 
opentracing.SpanContext { return m } -func (m memSpan) SetBaggageItem(key, val string) opentracing.Span { return m } -func (m memSpan) BaggageItem(key string) string { return "" } -func (m memSpan) SetTag(key string, value interface{}) opentracing.Span { return m } -func (m memSpan) LogFields(fields ...log.Field) {} -func (m memSpan) LogKV(keyVals ...interface{}) {} -func (m memSpan) Finish() {} -func (m memSpan) FinishWithOptions(opts opentracing.FinishOptions) {} -func (m memSpan) SetOperationName(operationName string) opentracing.Span { - return &memSpan{operationName} +func (l *lockableTable) Unlock(ctx *sql.Context, id uint32) error { + l.unlocks++ + return nil } -func (m memSpan) Tracer() opentracing.Tracer { return &memTracer{} } -func (m memSpan) LogEvent(event string) {} -func (m memSpan) LogEventWithPayload(event string, payload interface{}) {} -func (m memSpan) Log(data opentracing.LogData) {} -func (m memSpan) ForeachBaggageItem(handler func(k, v string) bool) {} diff --git a/example_test.go b/example_test.go index b272911ea..eddcae535 100644 --- a/example_test.go +++ b/example_test.go @@ -4,13 +4,13 @@ import ( "fmt" "io" - "gopkg.in/src-d/go-mysql-server.v0" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" ) func Example() { - e := sqle.New() + e := sqle.NewDefault() ctx := sql.NewEmptyContext() // Create a test memory database and register it to the default engine. 
@@ -45,16 +45,24 @@ func checkIfError(err error) { } func createTestDatabase() sql.Database { - db := mem.NewDatabase("test") - table := mem.NewTable("mytable", sql.Schema{ + db := memory.NewDatabase("test") + table := memory.NewTable("mytable", sql.Schema{ {Name: "name", Type: sql.Text, Source: "mytable"}, {Name: "email", Type: sql.Text, Source: "mytable"}, }) db.AddTable("mytable", table) - table.Insert(sql.NewRow("John Doe", "john@doe.com")) - table.Insert(sql.NewRow("John Doe", "johnalt@doe.com")) - table.Insert(sql.NewRow("Jane Doe", "jane@doe.com")) - table.Insert(sql.NewRow("Evil Bob", "evilbob@gmail.com")) + ctx := sql.NewEmptyContext() + + rows := []sql.Row{ + sql.NewRow("John Doe", "john@doe.com"), + sql.NewRow("John Doe", "johnalt@doe.com"), + sql.NewRow("Jane Doe", "jane@doe.com"), + sql.NewRow("Evil Bob", "evilbob@gmail.com"), + } + + for _, row := range rows { + table.Insert(ctx, row) + } return db } diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..e8aab26b7 --- /dev/null +++ b/go.mod @@ -0,0 +1,28 @@ +module github.com/src-d/go-mysql-server + +require ( + github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect + github.com/VividCortex/gohistogram v1.0.0 // indirect + github.com/go-kit/kit v0.8.0 + github.com/go-ole/go-ole v1.2.4 // indirect + github.com/go-sql-driver/mysql v1.4.1 + github.com/gogo/protobuf v1.2.1 // indirect + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b + github.com/golang/protobuf v1.3.0 // indirect + github.com/hashicorp/golang-lru v0.5.3 + github.com/mitchellh/hashstructure v1.0.0 + github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852 + github.com/opentracing/opentracing-go v1.0.2 + github.com/pilosa/pilosa v1.3.0 + github.com/sanity-io/litter v1.1.0 + github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect + github.com/sirupsen/logrus v1.3.0 + github.com/spf13/cast v1.3.0 + github.com/src-d/go-oniguruma v1.0.0 + github.com/stretchr/testify v1.3.0 + 
go.etcd.io/bbolt v1.3.2 + google.golang.org/grpc v1.19.0 // indirect + gopkg.in/src-d/go-errors.v1 v1.0.0 + gopkg.in/yaml.v2 v2.2.2 + vitess.io/vitess v3.0.0-rc.3.0.20190602171040-12bfde34629c+incompatible +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..c2918275d --- /dev/null +++ b/go.sum @@ -0,0 +1,179 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/CAFxX/gcnotifier v0.0.0-20190112062741-224a280d589d h1:n0G4ckjMEj7bWuGYUX0i8YlBeBBJuZ+HEHvHfyBDZtI= +github.com/CAFxX/gcnotifier v0.0.0-20190112062741-224a280d589d/go.mod h1:Rn2zM2MnHze07LwkneP48TWt6UiZhzQTwCvw6djVGfE= +github.com/DataDog/datadog-go v0.0.0-20180822151419-281ae9f2d895 h1:dmc/C8bpE5VkQn65PNbbyACDC8xw8Hpp/NEurdPmQDQ= +github.com/DataDog/datadog-go v0.0.0-20180822151419-281ae9f2d895/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUWq3EgK3CesDbo8upS2Vm9/P3FtgI+Jk= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= 
+github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= +github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= +github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= +github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= 
+github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk= +github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/gorilla/handlers v1.3.0 h1:tsg9qP3mjt1h4Roxp+M1paRjrVBfPSOpBuVclh6YluI= +github.com/gorilla/handlers v1.3.0/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= +github.com/gorilla/mux v1.7.0 h1:tOSd0UKHQd6urX6ApfOn4XdBMY6Sh1MfxV3kmaazO+U= +github.com/gorilla/mux v1.7.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-immutable-radix v1.0.0 h1:AKDB1HM5PWEA7i4nhcpwOrO2byshxBjXVn/J/3+z5/0= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack v0.5.3 h1:zKjpN5BK/P5lMYrLmBHdBULWbJ0XpYR+7NGzqkZzoD4= +github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-multierror v1.0.0 
h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-sockaddr v1.0.0 h1:GeH6tui99pF4NJgfnhp+L6+FfobzVW3Ah46sLo0ICXs= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= +github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/memberlist v0.1.3 h1:EmmoJme1matNzb+hMpDuR/0sbJSUisxyqBGG676r31M= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/miekg/dns v1.0.14 h1:9jZdLNd/P4+SfEJ0TNyxYpsK8N4GtfylBLqtbYN1sbA= +github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mitchellh/hashstructure v1.0.0 h1:ZkRJX1CyOoTkar7p/mLS5TZU4nJ1Rn/F8u9dGS02Q3Y= +github.com/mitchellh/hashstructure v1.0.0/go.mod h1:QjSHrPWS+BGUVBYkbTZWEnOh3G1DutKwClXU/ABz6AQ= 
+github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852 h1:Yl0tPBa8QPjGmesFh1D0rDy+q1Twx6FyU7VWHi8wZbI= +github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852/go.mod h1:eqOVx5Vwu4gd2mmMZvVZsgIqNSaW3xxRThUJ0k/TPk4= +github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= +github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c h1:Lgl0gzECD8GnQ5QCWA8o6BtfL6mDH5rQgM4/fX3avOs= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pilosa/pilosa v1.3.0 h1:P27JB4tIqAN4Yc2Fw7wS5neD7JNkFKRUmwfyV87JMwQ= +github.com/pilosa/pilosa v1.3.0/go.mod h1:97yLL9mpUqOj9naKu5XA/b/U6JLe3JGGUlc2HOTDw+A= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20190321074620-2f0d2b0e0001 h1:YDeskXpkNDhPdWN3REluVa46HQOVuVkjkd2sWnrABNQ= +github.com/remyoudompheng/bigfft v0.0.0-20190321074620-2f0d2b0e0001/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/sanity-io/litter v1.1.0 h1:BllcKWa3VbZmOZbDCoszYLk7zCsKHz5Beossi8SUcTc= +github.com/sanity-io/litter v1.1.0/go.mod h1:CJ0VCw2q4qKU7LaQr3n7UOSHzgEMgcGco7N/SkZQPjw= +github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= +github.com/satori/go.uuid v1.2.0/go.mod 
h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shirou/gopsutil v2.18.12+incompatible h1:1eaJvGomDnH74/5cF4CTmTbLHAriGFsTZppLXDX93OM= +github.com/shirou/gopsutil v2.18.12+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= +github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= +github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= +github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= +github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= +github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/viper v1.3.1/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= +github.com/src-d/go-oniguruma v1.0.0 h1:JDk5PUAjreGsGAKLsoDLNmrsaryjJ5RqT3h+Si6aw/E= +github.com/src-d/go-oniguruma v1.0.0/go.mod h1:chVbff8kcVtmrhxtZ3yBVLLquXbzCS6DrxQaAK/CeqM= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx 
v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/uber/jaeger-client-go v2.15.0+incompatible h1:NP3qsSqNxh8VYr956ur1N/1C1PjvOJnJykCzcD5QHbk= +github.com/uber/jaeger-client-go v2.15.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= +github.com/uber/jaeger-lib v1.5.0 h1:OHbgr8l656Ub3Fw5k9SWnBfIEwvoHQ+W2y+Aa9D1Uyo= +github.com/uber/jaeger-lib v1.5.0/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= +github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +go.etcd.io/bbolt v1.3.2 h1:Z/90sZLPOeCy2PwprqkFa25PdkusRzaj9P8zm/KNyvk= +go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3 h1:KYQXGkl6vs02hK7pK4eIbw0NpNPedieTSTEiJ//bwGs= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net 
v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519 h1:x6rhz8Y9CjbgQkccRGmELH6K+LJj7tOoh3XWeC1yaQM= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5 h1:x6r4Jo0KNzOOzYd8lbcRsqjuqEASK6ob3auvWYM4/8U= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a h1:1n5lsVfiQW3yfsRGu98756EH1YthsFqr/5mxHduZW2A= +golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +google.golang.org/appengine v1.1.0 h1:igQkv0AAhEIvTEpD5LIpAfav2eeVO9HBTjvKHVJPRSs= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod 
h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b h1:lohp5blsw53GBXtLyLNaTXPXS9pJ1tiTw61ZHUoE9Qw= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/grpc v1.19.0 h1:cfg4PD8YEdSFnm7qLV4++93WcmhH2nIUhMjhdCvl3j8= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/src-d/go-errors.v1 v1.0.0 h1:cooGdZnCjYbeS1zb1s6pVAAimTdKceRrpn7aKOnNIfc= +gopkg.in/src-d/go-errors.v1 v1.0.0/go.mod h1:q1cBlomlw2FnDBDNGlnh6X0jPihy+QxZfMMNxPCbdYg= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +modernc.org/mathutil v1.0.0 h1:93vKjrJopTPrtTNpZ8XIovER7iCIH1QU7wNbOQXC60I= +modernc.org/mathutil v1.0.0/go.mod h1:wU0vUrJsVWBZ4P6e7xtFJEhFSNsfRLJ8H458uRjg03k= +modernc.org/strutil v1.0.0 h1:XVFtQwFVwc02Wk+0L/Z/zDDXO81r5Lhe6iMKmGX3KhE= +modernc.org/strutil v1.0.0/go.mod h1:lstksw84oURvj9y3tn8lGvRxyRC1S2+g5uuIzNfIOBs= +vitess.io/vitess v3.0.0-rc.3.0.20190602171040-12bfde34629c+incompatible h1:GWnLrAdetgJM0Co5bwwczO49iFZBSInpyGAT77BP9Y0= +vitess.io/vitess v3.0.0-rc.3.0.20190602171040-12bfde34629c+incompatible/go.mod h1:h4qvkyNYTOC0xI+vcidSWoka0gQAZc9ZPHbkHo48gP0= diff --git a/internal/regex/regex.go b/internal/regex/regex.go new file mode 100644 index 000000000..33b92e1ff --- /dev/null +++ b/internal/regex/regex.go @@ -0,0 +1,100 @@ +package regex + +import ( + "github.com/go-kit/kit/metrics/discard" + errors "gopkg.in/src-d/go-errors.v1" +) + +var ( + // ErrRegexAlreadyRegistered is returned 
when there is a previously + // registered regex engine with the same name. + ErrRegexAlreadyRegistered = errors.NewKind("Regex engine already registered: %s") + // ErrRegexNameEmpty is returned when the name is "". + ErrRegexNameEmpty = errors.NewKind("Regex engine name cannot be empty") + // ErrRegexNotFound is returned when the regex engine is not registered. + ErrRegexNotFound = errors.NewKind("Regex engine not found: %s") + + registry map[string]Constructor + defaultEngine string +) + +// Matcher interface is used to compare regexes with strings. +type Matcher interface { + // Match returns true if the text matches the regular expression. + Match(text string) bool +} + +// Disposer interface is used to release resources. +// The interface should be implemented by all Go bindings for native C libraries. +type Disposer interface { + Dispose() +} + +// Constructor creates a new Matcher. +type Constructor func(re string) (Matcher, Disposer, error) + +var ( + // CompileHistogram describes a regexp compile time. + CompileHistogram = discard.NewHistogram() + + // MatchHistogram describes a regexp match time. + MatchHistogram = discard.NewHistogram() +) + +// Register adds a new regex engine to the registry. +func Register(name string, c Constructor) error { + if registry == nil { + registry = make(map[string]Constructor) + } + + if name == "" { + return ErrRegexNameEmpty.New() + } + + _, ok := registry[name] + if ok { + return ErrRegexAlreadyRegistered.New(name) + } + + registry[name] = c + + return nil +} + +// Engines returns the list of regex engine names. +func Engines() []string { + var names []string + + for n := range registry { + names = append(names, n) + } + + return names +} + +// New creates a new Matcher with the specified regex engine. +func New(name, re string) (Matcher, Disposer, error) { + n, ok := registry[name] + if !ok { + return nil, nil, ErrRegexNotFound.New(name) + } + + return n(re) +} + +// Default returns the default regex engine. 
+func Default() string { + if defaultEngine != "" { + return defaultEngine + } + if _, ok := registry["go"]; ok { + return "go" + } + + return "oniguruma" +} + +// SetDefault sets the regex engine returned by Default. +func SetDefault(name string) { + defaultEngine = name +} diff --git a/internal/regex/regex_go.go b/internal/regex/regex_go.go new file mode 100644 index 000000000..e4c1be7c5 --- /dev/null +++ b/internal/regex/regex_go.go @@ -0,0 +1,44 @@ +package regex + +import ( + "regexp" + "time" +) + +// Go holds go regex engine Matcher. +type Go struct { + reg *regexp.Regexp +} + +// Match implements Matcher interface. +func (r *Go) Match(s string) bool { + t := time.Now() + defer MatchHistogram.With("string", s, "duration", "seconds").Observe(time.Since(t).Seconds()) + + return r.reg.MatchString(s) +} + +// Dispose implements Disposer interface. +func (*Go) Dispose() {} + +// NewGo creates a new Matcher using go regex engine. +func NewGo(re string) (Matcher, Disposer, error) { + t := time.Now() + reg, err := regexp.Compile(re) + if err != nil { + return nil, nil, err + } + CompileHistogram.With("regex", re, "duration", "seconds").Observe(time.Since(t).Seconds()) + + r := Go{ + reg: reg, + } + return &r, &r, nil +} + +func init() { + err := Register("go", NewGo) + if err != nil { + panic(err.Error()) + } +} diff --git a/internal/regex/regex_oniguruma.go b/internal/regex/regex_oniguruma.go new file mode 100644 index 000000000..39330d417 --- /dev/null +++ b/internal/regex/regex_oniguruma.go @@ -0,0 +1,50 @@ +// +build oniguruma + +package regex + +import ( + "time" + + rubex "github.com/src-d/go-oniguruma" +) + +// Oniguruma holds a rubex regular expression Matcher. +type Oniguruma struct { + reg *rubex.Regexp +} + +// Match implements Matcher interface. 
+func (r *Oniguruma) Match(s string) bool { + t := time.Now() + defer MatchHistogram.With("string", s, "duration", "seconds").Observe(time.Since(t).Seconds()) + + return r.reg.MatchString(s) +} + +// Dispose implements Disposer interface. +// The function releases resources for oniguruma's precompiled regex +func (r *Oniguruma) Dispose() { + r.reg.Free() +} + +// NewOniguruma creates a new Matcher using oniguruma engine. +func NewOniguruma(re string) (Matcher, Disposer, error) { + t := time.Now() + reg, err := rubex.Compile(re) + if err != nil { + return nil, nil, err + } + CompileHistogram.With("regex", re, "duration", "seconds").Observe(time.Since(t).Seconds()) + + r := Oniguruma{ + reg: reg, + } + return &r, &r, nil +} + +func init() { + err := Register("oniguruma", NewOniguruma) + if err != nil { + panic(err.Error()) + } +} diff --git a/internal/regex/regex_test.go b/internal/regex/regex_test.go new file mode 100644 index 000000000..118a19c85 --- /dev/null +++ b/internal/regex/regex_test.go @@ -0,0 +1,108 @@ +package regex + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func dummy(s string) (Matcher, Disposer, error) { return nil, nil, nil } + +func getDefault() string { + for _, n := range Engines() { + if n == "oniguruma" { + return n + } + } + + return "go" +} + +func TestRegistration(t *testing.T) { + require := require.New(t) + + engines := Engines() + require.NotNil(engines) + number := len(engines) + + defaultEngine := getDefault() + require.Equal(defaultEngine, Default()) + + err := Register("", dummy) + require.Equal(true, ErrRegexNameEmpty.Is(err)) + engines = Engines() + require.Len(engines, number) + + err = Register("nil", dummy) + require.NoError(err) + require.Len(Engines(), number+1) + + matcher, disposer, err := New("nil", "") + require.NoError(err) + require.Nil(matcher) + require.Nil(disposer) +} + +func TestDefault(t *testing.T) { + require := require.New(t) + + def := getDefault() + require.Equal(def, Default()) + + 
SetDefault("default") + require.Equal("default", Default()) + + SetDefault("") + require.Equal(def, Default()) +} + +func TestMatcher(t *testing.T) { + for _, name := range Engines() { + if name == "nil" { + continue + } + + t.Run(name, func(t *testing.T) { + m, d, err := New(name, "a{3}") + require.NoError(t, err) + + require.Equal(t, true, m.Match("ooaaaoo")) + require.Equal(t, false, m.Match("ooaaoo")) + + d.Dispose() + }) + } +} + +func TestMatcherMultiPatterns(t *testing.T) { + const ( + email = `[\w\.+-]+@[\w\.-]+\.[\w\.-]+` + url = `[\w]+://[^/\s?#]+[^\s?#]+(?:\?[^\s#]*)?(?:#[^\s]*)?` + ip = `(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9])\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9])` + + data = `mysql://root@255.255.255.255:3306` + ) + + for _, name := range Engines() { + if name == "nil" { + continue + } + + t.Run(name, func(t *testing.T) { + m, d, err := New(name, email) + require.NoError(t, err) + require.Equal(t, true, m.Match(data)) + d.Dispose() + + m, d, err = New(name, url) + require.NoError(t, err) + require.Equal(t, true, m.Match(data)) + d.Dispose() + + m, d, err = New(name, ip) + require.NoError(t, err) + require.Equal(t, true, m.Match(data)) + d.Dispose() + }) + } +} diff --git a/internal/similartext/similartext.go b/internal/similartext/similartext.go new file mode 100644 index 000000000..a7beeb09d --- /dev/null +++ b/internal/similartext/similartext.go @@ -0,0 +1,107 @@ +package similartext + +import ( + "fmt" + "reflect" + "strings" +) + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// DistanceForStrings returns the edit distance between source and target. +// It has a runtime proportional to len(source) * len(target) and memory use +// proportional to len(target). 
+// Taken (simplified, for strings and with default options) from: +// https://github.com/texttheater/golang-levenshtein +func distanceForStrings(source, target string) int { + height := len(source) + 1 + width := len(target) + 1 + matrix := make([][]int, 2) + + for i := 0; i < 2; i++ { + matrix[i] = make([]int, width) + matrix[i][0] = i + } + for j := 1; j < width; j++ { + matrix[0][j] = j + } + + for i := 1; i < height; i++ { + cur := matrix[i%2] + prev := matrix[(i-1)%2] + cur[0] = i + for j := 1; j < width; j++ { + delCost := prev[j] + 1 + matchSubCost := prev[j-1] + if source[i-1] != target[j-1] { + matchSubCost += 2 + } + insCost := cur[j-1] + 1 + cur[j] = min(delCost, min(matchSubCost, insCost)) + } + } + return matrix[(height-1)%2][width-1] +} + +// DistanceSkipped is the maximum Levenshtein distance from which +// we won't consider a string similar at all and thus will be ignored. +var DistanceSkipped = 3 + +// Find returns a string with suggestions for name(s) in `names` +// similar to the string `src` until a max distance of `DistanceSkipped`. +func Find(names []string, src string) string { + if len(src) == 0 { + return "" + } + + minDistance := -1 + matchMap := make(map[int][]string) + + for _, name := range names { + dist := distanceForStrings(name, src) + if dist >= DistanceSkipped { + continue + } + + if minDistance == -1 || dist < minDistance { + minDistance = dist + } + + matchMap[dist] = append(matchMap[dist], name) + } + + if len(matchMap) == 0 { + return "" + } + + return fmt.Sprintf(", maybe you mean %s?", + strings.Join(matchMap[minDistance], " or ")) +} + +// FindFromMap does the same as Find but taking a map instead +// of a string array as first argument. 
+func FindFromMap(names interface{}, src string) string { + rnames := reflect.ValueOf(names) + if rnames.Kind() != reflect.Map { + panic("Implementation error: non map used as first argument " + + "to FindFromMap") + } + + t := rnames.Type() + if t.Key().Kind() != reflect.String { + panic("Implementation error: non string key for map used as " + + "first argument to FindFromMap") + } + + var namesList []string + for _, kv := range rnames.MapKeys() { + namesList = append(namesList, kv.String()) + } + + return Find(namesList, src) +} diff --git a/internal/similartext/similartext_test.go b/internal/similartext/similartext_test.go new file mode 100644 index 000000000..bfacc6802 --- /dev/null +++ b/internal/similartext/similartext_test.go @@ -0,0 +1,52 @@ +package similartext + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFind(t *testing.T) { + require := require.New(t) + + var names []string + res := Find(names, "") + require.Empty(res) + + names = []string{"foo", "bar", "aka", "ake"} + res = Find(names, "baz") + require.Equal(", maybe you mean bar?", res) + + res = Find(names, "") + require.Empty(res) + + res = Find(names, "foo") + require.Equal(", maybe you mean foo?", res) + + res = Find(names, "willBeTooDifferent") + require.Empty(res) + + res = Find(names, "aki") + require.Equal(", maybe you mean aka or ake?", res) +} + +func TestFindFromMap(t *testing.T) { + require := require.New(t) + + var names map[string]int + res := FindFromMap(names, "") + require.Empty(res) + + names = map[string]int{ + "foo": 1, + "bar": 2, + } + res = FindFromMap(names, "baz") + require.Equal(", maybe you mean bar?", res) + + res = FindFromMap(names, "") + require.Empty(res) + + res = FindFromMap(names, "foo") + require.Equal(", maybe you mean foo?", res) +} diff --git a/internal/sockstate/netstat.go b/internal/sockstate/netstat.go new file mode 100644 index 000000000..500aa3879 --- /dev/null +++ b/internal/sockstate/netstat.go @@ -0,0 +1,86 @@ +package 
sockstate + +import ( + "fmt" + "net" + + "gopkg.in/src-d/go-errors.v1" +) + +// OS independent part of the netstat_[OS].go modules +// Taken (simplified, privatized and with utility functions added) from: +// https://github.com/cakturk/go-netstat + +// skState type represents socket connection state +type skState uint8 + +func (s skState) String() string { + return skStates[s] +} + +// ErrSocketCheckNotImplemented will be returned for OS where the socket checks is not implemented yet +var ErrSocketCheckNotImplemented = errors.NewKind("socket checking not implemented for this OS") + +// Socket states +const ( + Established skState = 0x01 + SynSent skState = 0x02 + SynRecv skState = 0x03 + FinWait1 skState = 0x04 + FinWait2 skState = 0x05 + TimeWait skState = 0x06 + Close skState = 0x07 + CloseWait skState = 0x08 + LastAck skState = 0x09 + Listen skState = 0x0a + Closing skState = 0x0b +) + +var skStates = [...]string{ + "UNKNOWN", + "ESTABLISHED", + "SYN_SENT", + "SYN_RECV", + "FIN_WAIT1", + "FIN_WAIT2", + "TIME_WAIT", + "", // CLOSE + "CLOSE_WAIT", + "LAST_ACK", + "LISTEN", + "CLOSING", +} + +// sockAddr represents an ip:port pair +type sockAddr struct { + IP net.IP + Port uint16 +} + +func (s *sockAddr) String() string { + return fmt.Sprintf("%v:%d", s.IP, s.Port) +} + +// sockTabEntry type represents each line of the /proc/net/tcp +type sockTabEntry struct { + Ino string + LocalAddr *sockAddr + RemoteAddr *sockAddr + State skState + UID uint32 + Process *process +} + +// process holds the PID and process name to which each socket belongs +type process struct { + pid int + name string +} + +func (p *process) String() string { + return fmt.Sprintf("%d/%s", p.pid, p.name) +} + +// AcceptFn is used to filter socket entries. The value returned indicates +// whether the element is to be appended to the socket list. 
+type AcceptFn func(*sockTabEntry) bool diff --git a/internal/sockstate/netstat_darwin.go b/internal/sockstate/netstat_darwin.go new file mode 100644 index 000000000..f68ffd989 --- /dev/null +++ b/internal/sockstate/netstat_darwin.go @@ -0,0 +1,21 @@ +// +build darwin + +package sockstate + +import ( + "net" + + "github.com/sirupsen/logrus" +) + +// tcpSocks returns a slice of active TCP sockets containing only those +// elements that satisfy the accept function +func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) { + // (juanjux) TODO: not implemented + logrus.Warn("Connection checking not implemented for Darwin") + return nil, ErrSocketCheckNotImplemented.New() +} + +func GetConnInode(c *net.TCPConn) (n uint64, err error) { + return 0, ErrSocketCheckNotImplemented.New() +} diff --git a/internal/sockstate/netstat_linux.go b/internal/sockstate/netstat_linux.go new file mode 100644 index 000000000..a7eb6ff62 --- /dev/null +++ b/internal/sockstate/netstat_linux.go @@ -0,0 +1,260 @@ +// +build linux + +package sockstate + +// Taken (simplified and with utility functions added) from https://github.com/cakturk/go-netstat + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "path" + "strconv" + "strings" +) + +const ( + pathTCP4Tab = "/proc/net/tcp" + pathTCP6Tab = "/proc/net/tcp6" + ipv4StrLen = 8 + ipv6StrLen = 32 +) + +type procFd struct { + base string + pid int + sktab []sockTabEntry + p *process +} + +const sockPrefix = "socket:[" + +func getProcName(s []byte) string { + i := bytes.Index(s, []byte("(")) + if i < 0 { + return "" + } + j := bytes.LastIndex(s, []byte(")")) + if i < 0 { + return "" + } + if i > j { + return "" + } + return string(s[i+1 : j]) +} + +func (p *procFd) iterFdDir() { + // link name is of the form socket:[5860846] + fddir := path.Join(p.base, "/fd") + fi, err := ioutil.ReadDir(fddir) + if err != nil { + return + } + var buf [128]byte + + for _, file := range fi { + fd := path.Join(fddir, 
file.Name()) + lname, err := os.Readlink(fd) + if err != nil { + continue + } + + for i := range p.sktab { + sk := &p.sktab[i] + ss := sockPrefix + sk.Ino + "]" + if ss != lname { + continue + } + if p.p == nil { + stat, err := os.Open(path.Join(p.base, "stat")) + if err != nil { + return + } + n, err := stat.Read(buf[:]) + _ = stat.Close() + if err != nil { + return + } + z := bytes.SplitN(buf[:n], []byte(" "), 3) + name := getProcName(z[1]) + p.p = &process{p.pid, name} + } + sk.Process = p.p + } + } +} + +func extractProcInfo(sktab []sockTabEntry) { + const basedir = "/proc" + fi, err := ioutil.ReadDir(basedir) + if err != nil { + return + } + + for _, file := range fi { + if !file.IsDir() { + continue + } + pid, err := strconv.Atoi(file.Name()) + if err != nil { + continue + } + base := path.Join(basedir, file.Name()) + proc := procFd{base: base, pid: pid, sktab: sktab} + proc.iterFdDir() + } +} + +func parseIPv4(s string) (net.IP, error) { + v, err := strconv.ParseUint(s, 16, 32) + if err != nil { + return nil, err + } + ip := make(net.IP, net.IPv4len) + binary.LittleEndian.PutUint32(ip, uint32(v)) + return ip, nil +} + +func parseIPv6(s string) (net.IP, error) { + ip := make(net.IP, net.IPv6len) + const grpLen = 4 + i, j := 0, 4 + for len(s) != 0 { + grp := s[0:8] + u, err := strconv.ParseUint(grp, 16, 32) + binary.LittleEndian.PutUint32(ip[i:j], uint32(u)) + if err != nil { + return nil, err + } + i, j = i+grpLen, j+grpLen + s = s[8:] + } + return ip, nil +} + +func parseAddr(s string) (*sockAddr, error) { + fields := strings.Split(s, ":") + if len(fields) < 2 { + return nil, fmt.Errorf("sockstate: not enough fields: %v", s) + } + var ip net.IP + var err error + switch len(fields[0]) { + case ipv4StrLen: + ip, err = parseIPv4(fields[0]) + case ipv6StrLen: + ip, err = parseIPv6(fields[0]) + default: + log.Fatal("Badly formatted connection address:", s) + } + if err != nil { + return nil, err + } + v, err := strconv.ParseUint(fields[1], 16, 16) + if err != nil 
{ + return nil, err + } + return &sockAddr{IP: ip, Port: uint16(v)}, nil +} + +func parseSocktab(r io.Reader, accept AcceptFn) ([]sockTabEntry, error) { + br := bufio.NewScanner(r) + tab := make([]sockTabEntry, 0, 4) + + // Discard title + br.Scan() + + for br.Scan() { + var e sockTabEntry + line := br.Text() + // Skip comments + if i := strings.Index(line, "#"); i >= 0 { + line = line[:i] + } + fields := strings.Fields(line) + if len(fields) < 12 { + return nil, fmt.Errorf("sockstate: not enough fields: %v, %v", len(fields), fields) + } + addr, err := parseAddr(fields[1]) + if err != nil { + return nil, err + } + e.LocalAddr = addr + addr, err = parseAddr(fields[2]) + if err != nil { + return nil, err + } + e.RemoteAddr = addr + u, err := strconv.ParseUint(fields[3], 16, 8) + if err != nil { + return nil, err + } + e.State = skState(u) + u, err = strconv.ParseUint(fields[7], 10, 32) + if err != nil { + return nil, err + } + e.UID = uint32(u) + e.Ino = fields[9] + if accept(&e) { + tab = append(tab, e) + } + } + return tab, br.Err() +} + +// tcpSocks returns a slice of active TCP sockets containing only those +// elements that satisfy the accept function +func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) { + paths := [2]string{pathTCP4Tab, pathTCP6Tab} + var allTabs []sockTabEntry + for _, p := range paths { + f, err := os.Open(p) + defer func() { + _ = f.Close() + }() + if err != nil { + return nil, err + } + + t, err := parseSocktab(f, accept) + if err != nil { + return nil, err + } + allTabs = append(allTabs, t...) 
+ + } + extractProcInfo(allTabs) + return allTabs, nil +} + +// GetConnInode returns the Linux inode number of a TCP connection +func GetConnInode(c *net.TCPConn) (n uint64, err error) { + f, err := c.File() + if err != nil { + return + } + + socketStr := fmt.Sprintf("/proc/%d/fd/%d", os.Getpid(), f.Fd()) + socketLnk, err := os.Readlink(socketStr) + if err != nil { + return + } + + if strings.HasPrefix(socketLnk, sockPrefix) { + _, err = fmt.Sscanf(socketLnk, sockPrefix+"%d]", &n) + if err != nil { + return + } + } else { + err = ErrNoSocketLink.New() + } + return +} diff --git a/internal/sockstate/netstat_windows.go b/internal/sockstate/netstat_windows.go new file mode 100644 index 000000000..1f8d98fc9 --- /dev/null +++ b/internal/sockstate/netstat_windows.go @@ -0,0 +1,21 @@ +// +build windows + +package sockstate + +import ( + "net" + + "github.com/sirupsen/logrus" +) + +// tcpSocks returns a slice of active TCP sockets containing only those +// elements that satisfy the accept function +func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) { + // (juanjux) TODO: not implemented + logrus.Warn("Connection checking not implemented for Windows") + return nil, ErrSocketCheckNotImplemented.New() +} + +func GetConnInode(c *net.TCPConn) (n uint64, err error) { + return 0, ErrSocketCheckNotImplemented.New() +} diff --git a/internal/sockstate/sockstate.go b/internal/sockstate/sockstate.go new file mode 100644 index 000000000..f47b05a34 --- /dev/null +++ b/internal/sockstate/sockstate.go @@ -0,0 +1,60 @@ +package sockstate + +import ( + "gopkg.in/src-d/go-errors.v1" + "strconv" +) + +type SockState uint8 + +const ( + Broken = iota + Other + Error +) + +var ErrNoSocketLink = errors.NewKind("couldn't resolve file descriptor link to socket") + +// ErrMultipleSocketsForInode is returned when more than one socket is found for an inode +var ErrMultipleSocketsForInode = errors.NewKind("more than one socket found for inode") + +func GetInodeSockState(port int, inode uint64) 
(SockState, error) { + socks, err := tcpSocks(func(s *sockTabEntry) bool { + if s.LocalAddr.Port != uint16(port) { + return false + } + + si, err := strconv.ParseUint(s.Ino, 10, 64) + if err != nil { + return false + } + return inode == si + }) + if err != nil { + return Error, err + } + + switch len(socks) { + case 0: + return Broken, nil + case 1: + switch socks[0].State { + case CloseWait: + fallthrough + case TimeWait: + fallthrough + case FinWait1: + fallthrough + case FinWait2: + fallthrough + case Close: + fallthrough + case Closing: + return Broken, nil + default: + return Other, nil + } + default: // more than one sock for inode, impossible? + return Error, ErrMultipleSocketsForInode.New() + } +} diff --git a/log.go b/log.go new file mode 100644 index 000000000..346b545f7 --- /dev/null +++ b/log.go @@ -0,0 +1,74 @@ +package sqle + +import ( + "github.com/golang/glog" + "github.com/sirupsen/logrus" + vtlog "vitess.io/vitess/go/vt/log" +) + +func init() { + // V quickly checks if the logging verbosity meets a threshold. + vtlog.V = func(level glog.Level) glog.Verbose { + lvl := logrus.GetLevel() + switch int32(level) { + case 0: + return glog.Verbose(lvl == logrus.InfoLevel) + case 1: + return glog.Verbose(lvl == logrus.WarnLevel) + case 2: + return glog.Verbose(lvl == logrus.ErrorLevel) + case 3: + return glog.Verbose(lvl == logrus.FatalLevel) + default: + return glog.Verbose(false) + } + } + + // Flush ensures any pending I/O is written. + vtlog.Flush = func() {} + + // Info formats arguments like fmt.Print. + vtlog.Info = logrus.Info + // Infof formats arguments like fmt.Printf. + vtlog.Infof = logrus.Infof + // InfoDepth formats arguments like fmt.Print and uses depth to choose which call frame to log. + vtlog.InfoDepth = func(_ int, args ...interface{}) { + logrus.Info(args...) + } + + // Warning formats arguments like fmt.Print. + vtlog.Warning = logrus.Warning + // Warningf formats arguments like fmt.Printf. 
+ vtlog.Warningf = logrus.Warningf + // WarningDepth formats arguments like fmt.Print and uses depth to choose which call frame to log. + vtlog.WarningDepth = func(depth int, args ...interface{}) { + logrus.Warning(args...) + } + + // Error formats arguments like fmt.Print. + vtlog.Error = logrus.Error + // Errorf formats arguments like fmt.Printf. + vtlog.Errorf = logrus.Errorf + // ErrorDepth formats arguments like fmt.Print and uses depth to choose which call frame to log. + vtlog.ErrorDepth = func(_ int, args ...interface{}) { + logrus.Error(args...) + } + + // Exit formats arguments like fmt.Print. + vtlog.Exit = logrus.Panic + // Exitf formats arguments like fmt.Printf. + vtlog.Exitf = logrus.Panicf + // ExitDepth formats arguments like fmt.Print and uses depth to choose which call frame to log. + vtlog.ExitDepth = func(_ int, args ...interface{}) { + logrus.Panic(args...) + } + + // Fatal formats arguments like fmt.Print. + vtlog.Fatal = logrus.Fatal + // Fatalf formats arguments like fmt.Printf + vtlog.Fatalf = logrus.Fatalf + // FatalDepth formats arguments like fmt.Print and uses depth to choose which call frame to log. + vtlog.FatalDepth = func(_ int, args ...interface{}) { + logrus.Fatal(args...) + } +} diff --git a/mem/table.go b/mem/table.go deleted file mode 100644 index 9cafbda21..000000000 --- a/mem/table.go +++ /dev/null @@ -1,90 +0,0 @@ -package mem - -import ( - "fmt" - - "gopkg.in/src-d/go-mysql-server.v0/sql" -) - -// Table represents an in-memory database table. -type Table struct { - name string - schema sql.Schema - data []sql.Row -} - -// NewTable creates a new Table with the given name and schema. -func NewTable(name string, schema sql.Schema) *Table { - return &Table{ - name: name, - schema: schema, - } -} - -// Resolved implements the Resolvable interface. -func (Table) Resolved() bool { - return true -} - -// Name returns the table name. -func (t *Table) Name() string { - return t.name -} - -// Schema implements the Node interface. 
-func (t *Table) Schema() sql.Schema { - return t.schema -} - -// Children implements the Node interface. -func (t *Table) Children() []sql.Node { - return nil -} - -// RowIter implements the Node interface. -func (t *Table) RowIter(ctx *sql.Context) (sql.RowIter, error) { - return sql.RowsToRowIter(t.data...), nil -} - -// TransformUp implements the Transformer interface. -func (t *Table) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(t) -} - -// TransformExpressionsUp implements the Transformer interface. -func (t *Table) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return t, nil -} - -// Insert a new row into the table. -func (t *Table) Insert(row sql.Row) error { - if len(row) != len(t.schema) { - return sql.ErrUnexpectedRowLength.New(len(t.schema), len(row)) - } - - for idx, value := range row { - c := t.schema[idx] - if !c.Check(value) { - return sql.ErrInvalidType.New(value) - } - } - - t.data = append(t.data, row.Copy()) - return nil -} - -func (t Table) String() string { - p := sql.NewTreePrinter() - _ = p.WriteNode("Table(%s)", t.name) - var schema = make([]string, len(t.schema)) - for i, col := range t.schema { - schema[i] = fmt.Sprintf( - "Column(%s, %s, nullable=%v)", - col.Name, - col.Type.Type().String(), - col.Nullable, - ) - } - _ = p.WriteChildren(schema...) 
- return p.String() -} diff --git a/mem/table_test.go b/mem/table_test.go deleted file mode 100644 index a679df017..000000000 --- a/mem/table_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package mem - -import ( - "testing" - - "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" -) - -func TestTable_Name(t *testing.T) { - require := require.New(t) - s := sql.Schema{ - {"col1", sql.Text, nil, true, ""}, - } - table := NewTable("test", s) - require.Equal("test", table.Name()) -} - -const expectedString = `Table(foo) - ├─ Column(col1, TEXT, nullable=true) - └─ Column(col2, INT64, nullable=false) -` - -func TestTableString(t *testing.T) { - require := require.New(t) - table := NewTable("foo", sql.Schema{ - {"col1", sql.Text, nil, true, ""}, - {"col2", sql.Int64, nil, false, ""}, - }) - require.Equal(expectedString, table.String()) -} - -func TestTable_Insert_RowIter(t *testing.T) { - require := require.New(t) - ctx := sql.NewEmptyContext() - - s := sql.Schema{ - {"col1", sql.Text, nil, true, ""}, - } - - table := NewTable("test", s) - - rows, err := sql.NodeToRows(ctx, table) - require.Nil(err) - require.Len(rows, 0) - - err = table.Insert(sql.NewRow("foo")) - rows, err = sql.NodeToRows(ctx, table) - require.Nil(err) - require.Len(rows, 1) - require.Nil(s.CheckRow(rows[0])) - - err = table.Insert(sql.NewRow("bar")) - rows, err = sql.NodeToRows(ctx, table) - require.Nil(err) - require.Len(rows, 2) - require.Nil(s.CheckRow(rows[0])) - require.Nil(s.CheckRow(rows[1])) -} diff --git a/mem/database.go b/memory/database.go similarity index 63% rename from mem/database.go rename to memory/database.go index b65402daf..133c4f5de 100644 --- a/mem/database.go +++ b/memory/database.go @@ -1,7 +1,7 @@ -package mem +package memory import ( - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Database is an in-memory database. 
@@ -33,14 +33,25 @@ func (d *Database) AddTable(name string, t sql.Table) { d.tables[name] = t } -// Create creates a table with the given name and schema -func (d *Database) Create(name string, schema sql.Schema) error { +// CreateTable creates a table with the given name and schema +func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.Schema) error { _, ok := d.tables[name] if ok { return sql.ErrTableAlreadyExists.New(name) } d.tables[name] = NewTable(name, schema) + return nil +} + +// DropTable drops the table with the given name +func (d *Database) DropTable(ctx *sql.Context, name string) error { + _, ok := d.tables[name] + if !ok { + return sql.ErrTableNotFound.New(name) + } + delete(d.tables, name) return nil } + diff --git a/mem/database_test.go b/memory/database_test.go similarity index 73% rename from mem/database_test.go rename to memory/database_test.go index a24a57c94..c6a2f9182 100644 --- a/mem/database_test.go +++ b/memory/database_test.go @@ -1,10 +1,10 @@ -package mem +package memory import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestDatabase_Name(t *testing.T) { @@ -19,8 +19,7 @@ func TestDatabase_AddTable(t *testing.T) { tables := db.Tables() require.Equal(0, len(tables)) - var altDb sql.Alterable = db - err := altDb.Create("test_table", sql.Schema{}) + err := db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.NoError(err) tables = db.Tables() @@ -29,6 +28,6 @@ func TestDatabase_AddTable(t *testing.T) { require.True(ok) require.NotNil(tt) - err = altDb.Create("test_table", sql.Schema{}) + err = db.CreateTable(sql.NewEmptyContext(), "test_table", nil) require.Error(err) } diff --git a/memory/table.go b/memory/table.go new file mode 100644 index 000000000..088956316 --- /dev/null +++ b/memory/table.go @@ -0,0 +1,555 @@ +package memory + +import ( + "bytes" + "encoding/gob" + "fmt" + "io" + "strconv" + + 
"github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" +) + +// Table represents an in-memory database table. +type Table struct { + name string + schema sql.Schema + partitions map[string][]sql.Row + keys [][]byte + + insert int + + filters []sql.Expression + projection []string + columns []int + lookup sql.IndexLookup +} + +var _ sql.Table = (*Table)(nil) +var _ sql.Inserter = (*Table)(nil) +var _ sql.FilteredTable = (*Table)(nil) +var _ sql.ProjectedTable = (*Table)(nil) +var _ sql.IndexableTable = (*Table)(nil) + +// NewTable creates a new Table with the given name and schema. +func NewTable(name string, schema sql.Schema) *Table { + return NewPartitionedTable(name, schema, 0) +} + +// NewPartitionedTable creates a new Table with the given name, schema and number of partitions. +func NewPartitionedTable(name string, schema sql.Schema, numPartitions int) *Table { + var keys [][]byte + var partitions = map[string][]sql.Row{} + + if numPartitions < 1 { + numPartitions = 1 + } + + for i := 0; i < numPartitions; i++ { + key := strconv.Itoa(i) + keys = append(keys, []byte(key)) + partitions[key] = []sql.Row{} + } + + return &Table{ + name: name, + schema: schema, + partitions: partitions, + keys: keys, + } +} + +// Name implements the sql.Table interface. +func (t *Table) Name() string { + return t.name +} + +// Schema implements the sql.Table interface. +func (t *Table) Schema() sql.Schema { + return t.schema +} + +// Partitions implements the sql.Table interface. +func (t *Table) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { + var keys [][]byte + for _, k := range t.keys { + if rows, ok := t.partitions[string(k)]; ok && len(rows) > 0 { + keys = append(keys, k) + } + } + return &partitionIter{keys: keys}, nil +} + +// PartitionCount implements the sql.PartitionCounter interface. 
+func (t *Table) PartitionCount(ctx *sql.Context) (int64, error) { + return int64(len(t.partitions)), nil +} + +// PartitionRows implements the sql.PartitionRows interface. +func (t *Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { + rows, ok := t.partitions[string(partition.Key())] + if !ok { + return nil, fmt.Errorf( + "partition not found: %q", partition.Key(), + ) + } + + var values sql.IndexValueIter + if t.lookup != nil { + var err error + values, err = t.lookup.Values(partition) + if err != nil { + return nil, err + } + } + + return &tableIter{ + rows: rows, + columns: t.columns, + filters: t.filters, + indexValues: values, + }, nil +} + +type partition struct { + key []byte +} + +func (p *partition) Key() []byte { return p.key } + +type partitionIter struct { + keys [][]byte + pos int +} + +func (p *partitionIter) Next() (sql.Partition, error) { + if p.pos >= len(p.keys) { + return nil, io.EOF + } + + key := p.keys[p.pos] + p.pos++ + return &partition{key}, nil +} + +func (p *partitionIter) Close() error { return nil } + +type tableIter struct { + columns []int + filters []sql.Expression + + rows []sql.Row + indexValues sql.IndexValueIter + pos int +} + +var _ sql.RowIter = (*tableIter)(nil) + +func (i *tableIter) Next() (sql.Row, error) { + row, err := i.getRow() + if err != nil { + return nil, err + } + + for _, f := range i.filters { + result, err := f.Eval(sql.NewEmptyContext(), row) + if err != nil { + return nil, err + } + + if result != true { + return i.Next() + } + } + + return projectOnRow(i.columns, row), nil +} + +func (i *tableIter) Close() error { + if i.indexValues == nil { + return nil + } + + return i.indexValues.Close() +} + +func (i *tableIter) getRow() (sql.Row, error) { + if i.indexValues != nil { + return i.getFromIndex() + } + + if i.pos >= len(i.rows) { + return nil, io.EOF + } + + row := i.rows[i.pos] + i.pos++ + return row, nil +} + +func projectOnRow(columns []int, row sql.Row) sql.Row { + if 
len(columns) < 1 { + return row + } + + projected := make([]interface{}, len(columns)) + for i, selected := range columns { + projected[i] = row[selected] + } + + return projected +} + +func (i *tableIter) getFromIndex() (sql.Row, error) { + data, err := i.indexValues.Next() + if err != nil { + return nil, err + } + + value, err := decodeIndexValue(data) + if err != nil { + return nil, err + } + + return i.rows[value.Pos], nil +} + +type indexValue struct { + Key string + Pos int +} + +func decodeIndexValue(data []byte) (*indexValue, error) { + dec := gob.NewDecoder(bytes.NewReader(data)) + var value indexValue + if err := dec.Decode(&value); err != nil { + return nil, err + } + + return &value, nil +} + +func encodeIndexValue(value *indexValue) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(value); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// Insert a new row into the table. +func (t *Table) Insert(ctx *sql.Context, row sql.Row) error { + if err := checkRow(t.schema, row); err != nil { + return err + } + + key := string(t.keys[t.insert]) + t.insert++ + if t.insert == len(t.keys) { + t.insert = 0 + } + + t.partitions[key] = append(t.partitions[key], row) + return nil +} + +// Delete the given row from the table. +func (t *Table) Delete(ctx *sql.Context, row sql.Row) error { + if err := checkRow(t.schema, row); err != nil { + return err + } + + matches := false + for partitionIndex, partition := range t.partitions { + for partitionRowIndex, partitionRow := range partition { + matches = true + for rIndex, val := range row { + if val != partitionRow[rIndex] { + matches = false + break + } + } + if matches { + t.partitions[partitionIndex] = append(partition[:partitionRowIndex], partition[partitionRowIndex+1:]...) 
+ break + } + } + if matches { + break + } + } + + if !matches { + return sql.ErrDeleteRowNotFound + } + + return nil +} + +func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error { + if err := checkRow(t.schema, oldRow); err != nil { + return err + } + if err := checkRow(t.schema, newRow); err != nil { + return err + } + + matches := false + for partitionIndex, partition := range t.partitions { + for partitionRowIndex, partitionRow := range partition { + matches = true + for rIndex, val := range oldRow { + if val != partitionRow[rIndex] { + matches = false + break + } + } + if matches { + t.partitions[partitionIndex][partitionRowIndex] = newRow + break + } + } + if matches { + break + } + } + + return nil +} + +func checkRow(schema sql.Schema, row sql.Row) error { + if len(row) != len(schema) { + return sql.ErrUnexpectedRowLength.New(len(schema), len(row)) + } + + for i, value := range row { + c := schema[i] + if !c.Check(value) { + return sql.ErrInvalidType.New(value) + } + } + + return nil +} + +// String implements the sql.Table inteface. +func (t *Table) String() string { + p := sql.NewTreePrinter() + + kind := "" + if len(t.columns) > 0 { + kind += "Projected " + } + + if len(t.filters) > 0 { + kind += "Filtered " + } + + if t.lookup != nil { + kind += "Indexed" + } + + if kind != "" { + kind = ": " + kind + } + + _ = p.WriteNode("Table(%s)%s", t.name, kind) + var schema = make([]string, len(t.Schema())) + for i, col := range t.Schema() { + schema[i] = fmt.Sprintf( + "Column(%s, %s, nullable=%v)", + col.Name, + col.Type.Type().String(), + col.Nullable, + ) + } + _ = p.WriteChildren(schema...) + return p.String() +} + +// HandledFilters implements the sql.FilteredTable interface. 
+func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression { + var handled []sql.Expression + for _, f := range filters { + var hasOtherFields bool + expression.Inspect(f, func(e sql.Expression) bool { + if e, ok := e.(*expression.GetField); ok { + if e.Table() != t.name || !t.schema.Contains(e.Name(), t.name) { + hasOtherFields = true + return false + } + } + return true + }) + + if !hasOtherFields { + handled = append(handled, f) + } + } + + return handled +} + +// WithFilters implements the sql.FilteredTable interface. +func (t *Table) WithFilters(filters []sql.Expression) sql.Table { + if len(filters) == 0 { + return t + } + + nt := *t + nt.filters = filters + return &nt +} + +// WithProjection implements the sql.ProjectedTable interface. +func (t *Table) WithProjection(colNames []string) sql.Table { + if len(colNames) == 0 { + return t + } + + nt := *t + columns, schema, _ := nt.newColumnIndexesAndSchema(colNames) + nt.columns = columns + nt.projection = colNames + nt.schema = schema + + return &nt +} + +func (t *Table) newColumnIndexesAndSchema(colNames []string) ([]int, sql.Schema, error) { + var columns []int + var schema []*sql.Column + + for _, name := range colNames { + i := t.schema.IndexOf(name, t.name) + if i == -1 { + return nil, nil, errColumnNotFound.New(name) + } + + if len(t.columns) == 0 { + // if the table hasn't been projected before + // match against the origianl schema + columns = append(columns, i) + } else { + // get indexes for the new projections from + // the orginal indexes. + columns = append(columns, t.columns[i]) + } + + schema = append(schema, t.schema[i]) + } + + return columns, schema, nil +} + +// WithIndexLookup implements the sql.IndexableTable interface. +func (t *Table) WithIndexLookup(lookup sql.IndexLookup) sql.Table { + if lookup == nil { + return t + } + + nt := *t + nt.lookup = lookup + + return &nt +} + +// IndexKeyValues implements the sql.IndexableTable interface. 
+func (t *Table) IndexKeyValues( + ctx *sql.Context, + colNames []string, +) (sql.PartitionIndexKeyValueIter, error) { + iter, err := t.Partitions(ctx) + if err != nil { + return nil, err + } + + columns, _, err := t.newColumnIndexesAndSchema(colNames) + if err != nil { + return nil, err + } + + return &partitionIndexKeyValueIter{ + table: t, + iter: iter, + columns: columns, + ctx: ctx, + }, nil +} + +// Projection implements the sql.ProjectedTable interface. +func (t *Table) Projection() []string { + return t.projection +} + +// Filters implements the sql.FilteredTable interface. +func (t *Table) Filters() []sql.Expression { + return t.filters +} + +// IndexLookup implements the sql.IndexableTable interface. +func (t *Table) IndexLookup() sql.IndexLookup { + return t.lookup +} + +type partitionIndexKeyValueIter struct { + table *Table + iter sql.PartitionIter + columns []int + ctx *sql.Context +} + +func (i *partitionIndexKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { + p, err := i.iter.Next() + if err != nil { + return nil, nil, err + } + + iter, err := i.table.PartitionRows(i.ctx, p) + if err != nil { + return nil, nil, err + } + + return p, &indexKeyValueIter{ + key: string(p.Key()), + iter: iter, + columns: i.columns, + }, nil +} + +func (i *partitionIndexKeyValueIter) Close() error { + return i.iter.Close() +} + +var errColumnNotFound = errors.NewKind("could not find column %s") + +type indexKeyValueIter struct { + key string + iter sql.RowIter + columns []int + pos int +} + +func (i *indexKeyValueIter) Next() ([]interface{}, []byte, error) { + row, err := i.iter.Next() + if err != nil { + return nil, nil, err + } + + value := &indexValue{Key: i.key, Pos: i.pos} + data, err := encodeIndexValue(value) + if err != nil { + return nil, nil, err + } + + i.pos++ + return projectOnRow(i.columns, row), data, nil +} + +func (i *indexKeyValueIter) Close() error { + return i.iter.Close() +} diff --git a/memory/table_test.go b/memory/table_test.go 
new file mode 100644 index 000000000..224ea5645 --- /dev/null +++ b/memory/table_test.go @@ -0,0 +1,403 @@ +package memory + +import ( + "fmt" + "io" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestTablePartitionsCount(t *testing.T) { + require := require.New(t) + table := NewPartitionedTable("foo", nil, 5) + count, err := table.PartitionCount(sql.NewEmptyContext()) + require.NoError(err) + require.Equal(int64(5), count) +} + +func TestTableName(t *testing.T) { + require := require.New(t) + s := sql.Schema{ + {Name: "col1", Type: sql.Text, Nullable: true}, + } + + table := NewTable("test", s) + require.Equal("test", table.name) +} + +const expectedString = `Table(foo) + ├─ Column(col1, TEXT, nullable=true) + └─ Column(col2, INT64, nullable=false) +` + +func TestTableString(t *testing.T) { + require := require.New(t) + table := NewTable("foo", sql.Schema{ + {Name: "col1", Type: sql.Text, Nullable: true}, + {Name: "col2", Type: sql.Int64, Nullable: false}, + }) + require.Equal(expectedString, table.String()) +} + +type indexKeyValue struct { + key sql.Row + value *indexValue +} + +type dummyLookup struct { + values map[string][]*indexValue +} + +func (dummyLookup) Indexes() []string { return nil } + +func (i *dummyLookup) Values(partition sql.Partition) (sql.IndexValueIter, error) { + key := string(partition.Key()) + values, ok := i.values[key] + if !ok { + return nil, fmt.Errorf("wrong partition key %q", key) + } + + return &dummyLookupIter{values: values}, nil +} + +type dummyLookupIter struct { + values []*indexValue + pos int +} + +var _ sql.IndexValueIter = (*dummyLookupIter)(nil) + +func (i *dummyLookupIter) Next() ([]byte, error) { + if i.pos >= len(i.values) { + return nil, io.EOF + } + + value := i.values[i.pos] + i.pos++ + return encodeIndexValue(value) +} + +func (i *dummyLookupIter) Close() error { return nil } + +var tests = []struct { + name 
string + schema sql.Schema + numPartitions int + rows []sql.Row + + filters []sql.Expression + expectedFiltered []sql.Row + + columns []string + expectedSchema sql.Schema + expectedProjected []sql.Row + + expectedFiltersAndProjections []sql.Row + + indexColumns []string + expectedKeyValues []*indexKeyValue + + lookup *dummyLookup + partition *partition + expectedIndexed []sql.Row +}{ + { + name: "test", + schema: sql.Schema{ + &sql.Column{Name: "col1", Source: "test", Type: sql.Text, Nullable: false, Default: ""}, + &sql.Column{Name: "col2", Source: "test", Type: sql.Int32, Nullable: false, Default: int32(0)}, + &sql.Column{Name: "col3", Source: "test", Type: sql.Int64, Nullable: false, Default: int64(0)}, + }, + numPartitions: 2, + rows: []sql.Row{ + sql.NewRow("a", int32(10), int64(100)), + sql.NewRow("b", int32(10), int64(100)), + sql.NewRow("c", int32(20), int64(100)), + sql.NewRow("d", int32(20), int64(200)), + sql.NewRow("e", int32(10), int64(200)), + sql.NewRow("f", int32(20), int64(200)), + }, + filters: []sql.Expression{ + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int32, "test", "col2", false), + expression.NewLiteral(int32(10), sql.Int32), + ), + }, + expectedFiltered: []sql.Row{ + sql.NewRow("a", int32(10), int64(100)), + sql.NewRow("b", int32(10), int64(100)), + sql.NewRow("e", int32(10), int64(200)), + }, + columns: []string{"col3", "col1"}, + expectedSchema: sql.Schema{ + &sql.Column{Name: "col3", Source: "test", Type: sql.Int64, Nullable: false, Default: int64(0)}, + &sql.Column{Name: "col1", Source: "test", Type: sql.Text, Nullable: false, Default: ""}, + }, + expectedProjected: []sql.Row{ + sql.NewRow(int64(100), "a"), + sql.NewRow(int64(100), "b"), + sql.NewRow(int64(100), "c"), + sql.NewRow(int64(200), "d"), + sql.NewRow(int64(200), "e"), + sql.NewRow(int64(200), "f"), + }, + expectedFiltersAndProjections: []sql.Row{ + sql.NewRow(int64(100), "a"), + sql.NewRow(int64(100), "b"), + sql.NewRow(int64(200), "e"), + }, + 
indexColumns: []string{"col1", "col3"}, + expectedKeyValues: []*indexKeyValue{ + {sql.NewRow("a", int64(100)), &indexValue{Key: "0", Pos: 0}}, + {sql.NewRow("c", int64(100)), &indexValue{Key: "0", Pos: 1}}, + {sql.NewRow("e", int64(200)), &indexValue{Key: "0", Pos: 2}}, + {sql.NewRow("b", int64(100)), &indexValue{Key: "1", Pos: 0}}, + {sql.NewRow("d", int64(200)), &indexValue{Key: "1", Pos: 1}}, + {sql.NewRow("f", int64(200)), &indexValue{Key: "1", Pos: 2}}, + }, + lookup: &dummyLookup{ + values: map[string][]*indexValue{ + "0": { + {Key: "0", Pos: 0}, + {Key: "0", Pos: 1}, + {Key: "0", Pos: 2}, + }, + "1": { + {Key: "1", Pos: 0}, + {Key: "1", Pos: 1}, + {Key: "1", Pos: 2}, + }, + }, + }, + partition: &partition{key: []byte("0")}, + expectedIndexed: []sql.Row{ + sql.NewRow(int64(100), "a"), + sql.NewRow(int64(200), "e"), + }, + }, +} + +func TestTable(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, test.numPartitions) + for _, row := range test.rows { + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + pIter, err := table.Partitions(sql.NewEmptyContext()) + require.NoError(err) + + for i := 0; i < test.numPartitions; i++ { + var p sql.Partition + p, err = pIter.Next() + require.NoError(err) + + var iter sql.RowIter + iter, err = table.PartitionRows(sql.NewEmptyContext(), p) + require.NoError(err) + + var rows []sql.Row + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + + expected := table.partitions[string(p.Key())] + require.Len(rows, len(expected)) + + for i, row := range rows { + require.ElementsMatch(expected[i], row) + } + } + + _, err = pIter.Next() + require.EqualError(err, io.EOF.Error()) + + }) + } +} + +func TestFiltered(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, 
test.numPartitions) + for _, row := range test.rows { + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + filtered := table.WithFilters(test.filters) + + filteredRows := testFlatRows(t, filtered) + require.Len(filteredRows, len(test.expectedFiltered)) + for _, row := range filteredRows { + require.Contains(test.expectedFiltered, row) + } + + }) + } +} + +func TestProjected(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, test.numPartitions) + for _, row := range test.rows { + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + projected := table.WithProjection(test.columns) + require.ElementsMatch(projected.Schema(), test.expectedSchema) + + projectedRows := testFlatRows(t, projected) + require.Len(projectedRows, len(test.expectedProjected)) + for _, row := range projectedRows { + require.Contains(test.expectedProjected, row) + } + }) + } +} + +func TestFilterAndProject(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, test.numPartitions) + for _, row := range test.rows { + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + filtered := table.WithFilters(test.filters) + projected := filtered.(*Table).WithProjection(test.columns) + require.ElementsMatch(projected.Schema(), test.expectedSchema) + + rows := testFlatRows(t, projected) + require.Len(rows, len(test.expectedFiltersAndProjections)) + for _, row := range rows { + require.Contains(test.expectedFiltersAndProjections, row) + } + }) + } +} + +func TestIndexed(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, test.numPartitions) + for _, row := range test.rows { + 
require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + filtered := table.WithFilters(test.filters) + projected := filtered.(*Table).WithProjection(test.columns) + indexed := projected.(*Table).WithIndexLookup(test.lookup) + + require.ElementsMatch(indexed.Schema(), test.expectedSchema) + + iter, err := indexed.PartitionRows(sql.NewEmptyContext(), test.partition) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Len(rows, len(test.expectedIndexed)) + for _, row := range rows { + require.Contains(test.expectedIndexed, row) + } + }) + } +} + +func testFlatRows(t *testing.T, table sql.Table) []sql.Row { + var require = require.New(t) + + pIter, err := table.Partitions(sql.NewEmptyContext()) + require.NoError(err) + flatRows := []sql.Row{} + for { + p, err := pIter.Next() + if err != nil { + if err == io.EOF { + break + } + + require.NoError(err) + } + + iter, err := table.PartitionRows(sql.NewEmptyContext(), p) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + flatRows = append(flatRows, rows...) 
+ + } + + return flatRows +} + +func TestTableIndexKeyValueIter(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var require = require.New(t) + + table := NewPartitionedTable(test.name, test.schema, test.numPartitions) + for _, row := range test.rows { + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + + pIter, err := table.IndexKeyValues( + sql.NewEmptyContext(), + []string{test.schema[0].Name, test.schema[2].Name}, + ) + require.NoError(err) + + var iter sql.IndexKeyValueIter + idxKVs := []*indexKeyValue{} + for { + if iter == nil { + _, iter, err = pIter.Next() + if err != nil { + if err == io.EOF { + iter = nil + break + } + + require.NoError(err) + } + } + + row, data, err := iter.Next() + if err != nil { + if err == io.EOF { + iter = nil + continue + } + + require.NoError(err) + } + + value, err := decodeIndexValue(data) + require.NoError(err) + + idxKVs = append(idxKVs, &indexKeyValue{key: row, value: value}) + } + + require.Len(idxKVs, len(test.expectedKeyValues)) + for i, e := range test.expectedKeyValues { + require.Equal(e, idxKVs[i]) + } + }) + } +} diff --git a/server/context.go b/server/context.go index 2f982d7aa..6ee2ee508 100644 --- a/server/context.go +++ b/server/context.go @@ -4,86 +4,104 @@ import ( "context" "sync" - opentracing "github.com/opentracing/opentracing-go" - uuid "github.com/satori/go.uuid" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-vitess.v0/mysql" + "github.com/opentracing/opentracing-go" + "github.com/src-d/go-mysql-server/sql" + "vitess.io/vitess/go/mysql" ) -// SessionBuilder creates sessions given a context and a MySQL connection. -type SessionBuilder func(*mysql.Conn) sql.Session +// SessionBuilder creates sessions given a MySQL connection and a server address. +type SessionBuilder func(conn *mysql.Conn, addr string) sql.Session // DoneFunc is a function that must be executed when the session is used and // it can be disposed. 
type DoneFunc func() // DefaultSessionBuilder is a SessionBuilder that returns a base session. -func DefaultSessionBuilder(_ *mysql.Conn) sql.Session { - return sql.NewBaseSession() +func DefaultSessionBuilder(c *mysql.Conn, addr string) sql.Session { + client := c.RemoteAddr().String() + return sql.NewSession(addr, client, c.User, c.ConnectionID) } // SessionManager is in charge of creating new sessions for the given // connections and keep track of which sessions are in each connection, so -// they can be cancelled is the connection is closed. +// they can be cancelled if the connection is closed. type SessionManager struct { - tracer opentracing.Tracer - mu *sync.Mutex - builder SessionBuilder - sessions map[uint32]sql.Session - sessionContexts map[uint32][]uuid.UUID - contexts map[uuid.UUID]context.CancelFunc + addr string + tracer opentracing.Tracer + memory *sql.MemoryManager + mu *sync.Mutex + builder SessionBuilder + sessions map[uint32]sql.Session + pid uint64 } -// NewSessionManager creates a SessionManager with the given ContextBuilder. -func NewSessionManager(builder SessionBuilder, tracer opentracing.Tracer) *SessionManager { +// NewSessionManager creates a SessionManager with the given SessionBuilder. +func NewSessionManager( + builder SessionBuilder, + tracer opentracing.Tracer, + memory *sql.MemoryManager, + addr string, +) *SessionManager { return &SessionManager{ - tracer: tracer, - mu: new(sync.Mutex), - builder: builder, - sessions: make(map[uint32]sql.Session), - sessionContexts: make(map[uint32][]uuid.UUID), - contexts: make(map[uuid.UUID]context.CancelFunc), + addr: addr, + tracer: tracer, + memory: memory, + mu: new(sync.Mutex), + builder: builder, + sessions: make(map[uint32]sql.Session), } } -// NewSession creates a Session for the given connection. 
+func (s *SessionManager) nextPid() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + s.pid++ + return s.pid +} + +// NewSession creates a Session for the given connection and saves it to +// session pool. func (s *SessionManager) NewSession(conn *mysql.Conn) { s.mu.Lock() - s.sessions[conn.ConnectionID] = s.builder(conn) + s.sessions[conn.ConnectionID] = s.builder(conn, s.addr) s.mu.Unlock() } -// NewContext creates a new context for the session at the given conn. -func (s *SessionManager) NewContext(conn *mysql.Conn) (*sql.Context, DoneFunc, error) { - ctx, cancel := context.WithCancel(context.Background()) +func (s *SessionManager) session(conn *mysql.Conn) sql.Session { s.mu.Lock() - sess := s.sessions[conn.ConnectionID] - s.mu.Unlock() - context := sql.NewContext(ctx, sql.WithSession(sess), sql.WithTracer(s.tracer)) - id, err := uuid.NewV4() - if err != nil { - cancel() - return nil, nil, err - } + defer s.mu.Unlock() + return s.sessions[conn.ConnectionID] +} + +// NewContext creates a new context for the session at the given conn. +func (s *SessionManager) NewContext(conn *mysql.Conn) *sql.Context { + return s.NewContextWithQuery(conn, "") +} +// NewContextWithQuery creates a new context for the session at the given conn. 
+func (s *SessionManager) NewContextWithQuery( + conn *mysql.Conn, + query string, +) *sql.Context { s.mu.Lock() - s.sessionContexts[conn.ConnectionID] = append(s.sessionContexts[conn.ConnectionID], id) - s.contexts[id] = cancel + sess, ok := s.sessions[conn.ConnectionID] + if !ok { + sess = s.builder(conn, s.addr) + s.sessions[conn.ConnectionID] = sess + } s.mu.Unlock() - return context, func() { - s.mu.Lock() - defer s.mu.Unlock() + context := sql.NewContext( + context.Background(), + sql.WithSession(sess), + sql.WithTracer(s.tracer), + sql.WithPid(s.nextPid()), + sql.WithQuery(query), + sql.WithMemoryManager(s.memory), + sql.WithRootSpan(s.tracer.StartSpan("query")), + ) - delete(s.contexts, id) - ids := s.sessionContexts[conn.ConnectionID] - for i, sessID := range ids { - if sessID == id { - s.sessionContexts[conn.ConnectionID] = append(ids[:i], ids[i+1:]...) - break - } - } - }, nil + return context } // CloseConn closes the connection in the session manager and all its @@ -91,10 +109,5 @@ func (s *SessionManager) NewContext(conn *mysql.Conn) (*sql.Context, DoneFunc, e func (s *SessionManager) CloseConn(conn *mysql.Conn) { s.mu.Lock() defer s.mu.Unlock() - - for _, id := range s.sessionContexts[conn.ConnectionID] { - s.contexts[id]() - delete(s.contexts, id) - } - delete(s.sessionContexts, conn.ConnectionID) + delete(s.sessions, conn.ConnectionID) } diff --git a/server/handler.go b/server/handler.go index c6ef0c187..ed71a5f18 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1,55 +1,90 @@ package server import ( + "context" "io" + "net" "regexp" "strconv" "strings" "sync" + "time" - errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0" - "gopkg.in/src-d/go-mysql-server.v0/sql" + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/auth" + "github.com/src-d/go-mysql-server/internal/sockstate" + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" "github.com/sirupsen/logrus" - 
"gopkg.in/src-d/go-vitess.v0/mysql" - "gopkg.in/src-d/go-vitess.v0/sqltypes" - "gopkg.in/src-d/go-vitess.v0/vt/proto/query" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) var regKillCmd = regexp.MustCompile(`^kill (?:(query|connection) )?(\d+)$`) -var errConnectionNotFound = errors.NewKind("Connection not found: %c") +var errConnectionNotFound = errors.NewKind("connection not found: %c") + +// ErrRowTimeout will be returned if the wait for the row is longer than the connection timeout +var ErrRowTimeout = errors.NewKind("row read wait bigger than connection timeout") + +// ErrConnectionWasClosed will be returned if we try to use a previously closed connection +var ErrConnectionWasClosed = errors.NewKind("connection was closed") // TODO parametrize const rowsBatch = 100 +const tcpCheckerSleepTime = 1 + +type conntainer struct { + MysqlConn *mysql.Conn + NetConn net.Conn +} // Handler is a connection handler for a SQLe engine. type Handler struct { - mu sync.Mutex - e *sqle.Engine - sm *SessionManager - c map[uint32]*mysql.Conn + mu sync.Mutex + e *sqle.Engine + sm *SessionManager + c map[uint32]conntainer + readTimeout time.Duration + lc []*net.Conn } // NewHandler creates a new Handler given a SQLe engine. -func NewHandler(e *sqle.Engine, sm *SessionManager) *Handler { +func NewHandler(e *sqle.Engine, sm *SessionManager, rt time.Duration) *Handler { return &Handler{ - e: e, - sm: sm, - c: make(map[uint32]*mysql.Conn), + e: e, + sm: sm, + c: make(map[uint32]conntainer), + readTimeout: rt, } } +// AddNetConnection is used to add the net.Conn to the Handler when available (usually on the +// Listener.Accept() method) +func (h *Handler) AddNetConnection(c *net.Conn) { + h.lc = append(h.lc, c) +} + // NewConnection reports that a new connection has been established. 
func (h *Handler) NewConnection(c *mysql.Conn) { h.mu.Lock() if _, ok := h.c[c.ConnectionID]; !ok { - h.c[c.ConnectionID] = c + // Retrieve the latest net.Conn stored by Listener.Accept(), if called, and remove it + var netConn net.Conn + if len(h.lc) > 0 { + netConn = *h.lc[len(h.lc)-1] + h.lc = h.lc[:len(h.lc)-1] + } else { + logrus.Debug("Could not find TCP socket connection after Accept(), " + + "connection checker won't run") + } + h.c[c.ConnectionID] = conntainer{c, netConn} } + h.mu.Unlock() - h.sm.NewSession(c) logrus.Infof("NewConnection: client %v", c.ConnectionID) } @@ -61,6 +96,13 @@ func (h *Handler) ConnectionClosed(c *mysql.Conn) { delete(h.c, c.ConnectionID) h.mu.Unlock() + // If connection was closed, kill only its associated queries. + h.e.Catalog.ProcessList.KillOnlyQueries(c.ConnectionID) + + if err := h.e.Catalog.UnlockTables(nil, c.ConnectionID); err != nil { + logrus.Errorf("unable to unlock tables on session close: %s", err) + } + logrus.Infof("ConnectionClosed: client %v", c.ConnectionID) } @@ -69,30 +111,131 @@ func (h *Handler) ComQuery( c *mysql.Conn, query string, callback func(*sqltypes.Result) error, -) error { - ctx, done, err := h.sm.NewContext(c) - if err != nil { - return err - } +) (err error) { + ctx := h.sm.NewContextWithQuery(c, query) + + if !h.e.Async(ctx, query) { + newCtx, cancel := context.WithCancel(ctx) + ctx = ctx.WithContext(newCtx) - defer done() + defer cancel() + } - handled, err := h.handleKill(query) + handled, err := h.handleKill(c, query) if err != nil { return err } if handled { - return nil + return callback(&sqltypes.Result{}) } + start := time.Now() schema, rows, err := h.e.Query(ctx, query) + defer func() { + if q, ok := h.e.Auth.(*auth.Audit); ok { + q.Query(ctx, time.Since(start), err) + } + }() if err != nil { return err } + nc, ok := h.c[c.ConnectionID] + if !ok { + return ErrConnectionWasClosed.New() + } + var r *sqltypes.Result var proccesedAtLeastOneBatch bool + + // Reads rows from the row reading 
goroutine + rowChan := make(chan sql.Row) + // To send errors from the two goroutines to the main one + errChan := make(chan error) + // To close the goroutines + quit := make(chan struct{}) + + // Default waitTime is one minute if there is not timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + // This goroutine will be select{}ed giving a chance to Vitess to call the + // handler.CloseConnection callback and enforcing the timeout if configured + go func() { + for { + select { + case <-quit: + return + default: + row, err := rows.Next() + if err != nil { + errChan <- err + return + } + rowChan <- row + } + } + }() + + // This second goroutine will check the socket + // and try to determine if the socket is in CLOSE_WAIT state + // (because the remote client closed the connection). 
+ go func() { + tcpConn, ok := nc.NetConn.(*net.TCPConn) + if !ok { + logrus.Debug("Connection checker exiting, connection isn't TCP") + return + } + + inode, err := sockstate.GetConnInode(tcpConn) + if err != nil || inode == 0 { + if sockstate.ErrSocketCheckNotImplemented.Is(err) { + logrus.Warn("Connection checker exiting, not supported in this OS") + } else { + errChan <- err + } + return + } + + t, ok := nc.NetConn.LocalAddr().(*net.TCPAddr) + if !ok { + logrus.Warn("Connection checker exiting, could not get local port") + return + } + + for { + select { + case <-quit: + return + default: + } + + st, err := sockstate.GetInodeSockState(t.Port, inode) + switch st { + case sockstate.Broken: + errChan <- ErrConnectionWasClosed.New() + return + case sockstate.Error: + errChan <- err + return + default: // Established + // (juanjux) this check is not free, each iteration takes about 9 milliseconds to run on my machine + // thus the small wait between checks + time.Sleep(tcpCheckerSleepTime * time.Second) + } + } + }() + +rowLoop: for { if r == nil { r = &sqltypes.Result{Fields: schemaToFields(schema)} @@ -100,26 +243,44 @@ func (h *Handler) ComQuery( if r.RowsAffected == rowsBatch { if err := callback(r); err != nil { + close(quit) return err } r = nil proccesedAtLeastOneBatch = true - continue } - row, err := rows.Next() - if err != nil { + select { + case err = <-errChan: if err == io.EOF { - break + break rowLoop } - + close(quit) return err + case row := <-rowChan: + outputRow, err := rowToSQL(schema, row) + if err != nil { + close(quit) + return err + } + + r.Rows = append(r.Rows, outputRow) + r.RowsAffected++ + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + close(quit) + return ErrRowTimeout.New() + } } + timer.Reset(waitTime) + } + close(quit) - r.Rows = append(r.Rows, rowToSQL(schema, row)) - r.RowsAffected++ + if err := rows.Close(); err != nil { + return err } // Even if r.RowsAffected = 0, 
the callback must be @@ -133,60 +294,90 @@ func (h *Handler) ComQuery( return callback(r) } -func (h *Handler) handleKill(query string) (bool, error) { +// WarningCount is called at the end of each query to obtain +// the value to be returned to the client in the EOF packet. +// Note that this will be called either in the context of the +// ComQuery callback if the result does not contain any fields, +// or after the last ComQuery call completes. +func (h *Handler) WarningCount(c *mysql.Conn) uint16 { + if sess := h.sm.session(c); sess != nil { + return sess.WarningCount() + } + + return 0 +} + +func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) { q := strings.ToLower(query) s := regKillCmd.FindStringSubmatch(q) if s == nil { return false, nil } - id, err := strconv.Atoi(s[2]) + id, err := strconv.ParseUint(s[2], 10, 32) if err != nil { return false, err } - logrus.Infof("handleKill: id %v", id) - - h.mu.Lock() - c, ok := h.c[uint32(id)] - h.mu.Unlock() - if !ok { - return false, errConnectionNotFound.New(id) - } - - h.sm.CloseConn(c) - // KILL CONNECTION and KILL should close the connection. KILL QUERY only // cancels the query. // - // https://dev.mysql.com/doc/refman/5.7/en/kill.html - + // https://dev.mysql.com/doc/refman/8.0/en/kill.html + // + // KILL [CONNECTION | QUERY] processlist_id + // - KILL QUERY terminates the statement the connection is currently executing, + // but leaves the connection itself intact. + + // - KILL CONNECTION is the same as KILL with no modifier: + // It terminates the connection associated with the given processlist_id, + // after terminating any statement the connection is executing. 
+ connID := uint32(id) + h.e.Catalog.Kill(connID) if s[1] != "query" { - c.Close() + logrus.Infof("kill connection: id %d", connID) h.mu.Lock() - delete(h.c, uint32(id)) + c, ok := h.c[connID] + if ok { + delete(h.c, connID) + } h.mu.Unlock() + if !ok { + return false, errConnectionNotFound.New(connID) + } + + h.sm.CloseConn(c.MysqlConn) + c.MysqlConn.Close() } return true, nil } -func rowToSQL(s sql.Schema, row sql.Row) []sqltypes.Value { +func rowToSQL(s sql.Schema, row sql.Row) ([]sqltypes.Value, error) { o := make([]sqltypes.Value, len(row)) + var err error for i, v := range row { - o[i] = s[i].Type.SQL(v) + o[i], err = s[i].Type.SQL(v) + if err != nil { + return nil, err + } } - return o + return o, nil } func schemaToFields(s sql.Schema) []*query.Field { fields := make([]*query.Field, len(s)) for i, c := range s { + var charset uint32 = mysql.CharacterSetUtf8 + if c.Type == sql.Blob { + charset = mysql.CharacterSetBinary + } + fields[i] = &query.Field{ - Name: c.Name, - Type: c.Type.Type(), + Name: c.Name, + Type: c.Type.Type(), + Charset: charset, } } diff --git a/server/handler_linux_test.go b/server/handler_linux_test.go new file mode 100644 index 000000000..9dfc6eb31 --- /dev/null +++ b/server/handler_linux_test.go @@ -0,0 +1,52 @@ +package server + +import ( + "fmt" + "net" + "testing" + + "github.com/opentracing/opentracing-go" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" +) + +func TestBrokenConnection(t *testing.T) { + require := require.New(t) + e := setupMemDB(require) + + port, err := getFreePort() + require.NoError(err) + + ready := make(chan struct{}) + go brokenTestServer(t, ready, port) + <-ready + conn, err := net.Dial("tcp", "localhost:"+port) + require.NoError(err) + + h := NewHandler( + e, + NewSessionManager( + testSessionBuilder, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo", + ), + 0, + ) + h.AddNetConnection(&conn) + c := newConn(1) + 
h.NewConnection(c) + + // (juanjux) Note that this is a little fuzzy because sometimes sockets take one or two seconds + // to go into TIME_WAIT but 4 seconds hopefully is enough + wait := tcpCheckerSleepTime * 2 + if wait < 4 { + wait = 4 + } + q := fmt.Sprintf("SELECT SLEEP(%d)", wait) + err = h.ComQuery(c, q, func(res *sqltypes.Result) error { + return nil + }) + require.EqualError(err, "connection was closed") +} diff --git a/server/handler_test.go b/server/handler_test.go index 517fe0cff..fb074cafa 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -1,42 +1,41 @@ package server import ( + "fmt" + "net" "testing" + "time" - "gopkg.in/src-d/go-mysql-server.v0" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-vitess.v0/mysql" - "gopkg.in/src-d/go-vitess.v0/sqltypes" + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/sql" - opentracing "github.com/opentracing/opentracing-go" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" + + "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/require" ) -func setupMemDB(require *require.Assertions) *sqle.Engine { - e := sqle.New() - db := mem.NewDatabase("test") - e.AddDatabase(db) - - tableTest := mem.NewTable("test", sql.Schema{{Name: "c1", Type: sql.Int32, Source: "test"}}) - - for i := 0; i < 1010; i++ { - require.NoError(tableTest.Insert(sql.NewRow(int32(i)))) - } - - db.AddTable("test", tableTest) - - return e -} - func TestHandlerOutput(t *testing.T) { + e := setupMemDB(require.New(t)) dummyConn := &mysql.Conn{ConnectionID: 1} - handler := NewHandler(e, NewSessionManager(DefaultSessionBuilder, opentracing.NoopTracer{})) + handler := NewHandler( + e, + NewSessionManager( + testSessionBuilder, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo", + ), + 0, + ) + handler.NewConnection(dummyConn) - type exptectedValues struct { + type 
expectedValues struct { callsToCallback int - lenLastBacth int + lenLastBatch int lastRowsAffected uint64 } @@ -45,16 +44,16 @@ func TestHandlerOutput(t *testing.T) { handler *Handler conn *mysql.Conn query string - expected exptectedValues + expected expectedValues }{ { name: "select all without limit", handler: handler, conn: dummyConn, query: "SELECT * FROM test", - expected: exptectedValues{ + expected: expectedValues{ callsToCallback: 11, - lenLastBacth: 10, + lenLastBatch: 10, lastRowsAffected: uint64(10), }, }, @@ -63,9 +62,9 @@ func TestHandlerOutput(t *testing.T) { handler: handler, conn: dummyConn, query: "SELECT * FROM test limit 100", - expected: exptectedValues{ + expected: expectedValues{ callsToCallback: 1, - lenLastBacth: 100, + lenLastBatch: 100, lastRowsAffected: uint64(100), }, }, @@ -74,9 +73,9 @@ func TestHandlerOutput(t *testing.T) { handler: handler, conn: dummyConn, query: "SELECT * FROM test limit 60", - expected: exptectedValues{ + expected: expectedValues{ callsToCallback: 1, - lenLastBacth: 60, + lenLastBatch: 60, lastRowsAffected: uint64(60), }, }, @@ -85,9 +84,9 @@ func TestHandlerOutput(t *testing.T) { handler: handler, conn: dummyConn, query: "SELECT * FROM test limit 200", - expected: exptectedValues{ + expected: expectedValues{ callsToCallback: 2, - lenLastBacth: 100, + lenLastBatch: 100, lastRowsAffected: uint64(100), }, }, @@ -96,9 +95,9 @@ func TestHandlerOutput(t *testing.T) { handler: handler, conn: dummyConn, query: "SELECT * FROM test limit 530", - expected: exptectedValues{ + expected: expectedValues{ callsToCallback: 6, - lenLastBacth: 30, + lenLastBatch: 30, lastRowsAffected: uint64(30), }, }, @@ -107,51 +106,54 @@ func TestHandlerOutput(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var callsToCallback int - var lenLastBacth int + var lenLastBatch int var lastRowsAffected uint64 err := handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result) error { callsToCallback++ - 
lenLastBacth = len(res.Rows) + lenLastBatch = len(res.Rows) lastRowsAffected = res.RowsAffected return nil }) require.NoError(t, err) require.Equal(t, test.expected.callsToCallback, callsToCallback) - require.Equal(t, test.expected.lenLastBacth, lenLastBacth) + require.Equal(t, test.expected.lenLastBatch, lenLastBatch) require.Equal(t, test.expected.lastRowsAffected, lastRowsAffected) }) } } -func newConn(id uint32) *mysql.Conn { - return &mysql.Conn{ - ConnectionID: id, - } -} - func TestHandlerKill(t *testing.T) { require := require.New(t) e := setupMemDB(require) - handler := NewHandler(e, - NewSessionManager(func(conn *mysql.Conn) sql.Session { - return sql.NewBaseSession() - }, opentracing.NoopTracer{})) + handler := NewHandler( + e, + NewSessionManager( + func(conn *mysql.Conn, addr string) sql.Session { + return sql.NewSession(addr, "", "", conn.ConnectionID) + }, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo", + ), + 0, + ) require.Len(handler.c, 0) + var dummyNetConn net.Conn conn1 := newConn(1) - + conntainer1 := conntainer{conn1, dummyNetConn} handler.NewConnection(conn1) - require.Len(handler.c, 1) - c, ok := handler.c[1] - require.True(ok) - require.Equal(conn1, c) - conn2 := newConn(2) + conntainer2 := conntainer{conn2, dummyNetConn} + handler.NewConnection(conn2) + + require.Len(handler.sm.sessions, 0) + require.Len(handler.c, 2) err := handler.ComQuery(conn2, "KILL QUERY 1", func(res *sqltypes.Result) error { return nil @@ -159,11 +161,133 @@ func TestHandlerKill(t *testing.T) { require.NoError(err) + require.Len(handler.sm.sessions, 1) + require.Len(handler.c, 2) + require.Equal(conntainer1, handler.c[1]) + require.Equal(conntainer2, handler.c[2]) + assertNoConnProcesses(t, e, conn2.ConnectionID) + + ctx1 := handler.sm.NewContextWithQuery(conn1, "SELECT 1") + ctx1, err = handler.e.Catalog.AddProcess(ctx1, sql.QueryProcess, "SELECT 1") + require.NoError(err) + + err = handler.ComQuery(conn2, "KILL "+fmt.Sprint(ctx1.ID()), func(res 
*sqltypes.Result) error { + return nil + }) + require.NoError(err) + + require.Len(handler.sm.sessions, 1) require.Len(handler.c, 1) - c, ok = handler.c[1] - require.True(ok) - require.Equal(conn1, c) + _, ok := handler.c[1] + require.False(ok) + assertNoConnProcesses(t, e, conn1.ConnectionID) +} + +func assertNoConnProcesses(t *testing.T, e *sqle.Engine, conn uint32) { + t.Helper() - // Cannot test KILL CONNECTION as the connection can not be mocked. Calling - // mysql.Conn.Close panics. + for _, p := range e.Catalog.Processes() { + if p.Connection == conn { + t.Errorf("expecting no processes with connection id %d", conn) + } + } +} + +func TestSchemaToFields(t *testing.T) { + require := require.New(t) + + schema := sql.Schema{ + {Name: "foo", Type: sql.Blob}, + {Name: "bar", Type: sql.Text}, + {Name: "baz", Type: sql.Int64}, + } + + expected := []*query.Field{ + {Name: "foo", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary}, + {Name: "bar", Type: query.Type_TEXT, Charset: mysql.CharacterSetUtf8}, + {Name: "baz", Type: query.Type_INT64, Charset: mysql.CharacterSetUtf8}, + } + + fields := schemaToFields(schema) + require.Equal(expected, fields) +} + +func TestHandlerTimeout(t *testing.T) { + require := require.New(t) + + e := setupMemDB(require) + e2 := setupMemDB(require) + + timeOutHandler := NewHandler( + e, NewSessionManager(testSessionBuilder, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo"), + 1*time.Second) + + noTimeOutHandler := NewHandler( + e2, NewSessionManager(testSessionBuilder, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo"), + 0) + require.Equal(1*time.Second, timeOutHandler.readTimeout) + require.Equal(0*time.Second, noTimeOutHandler.readTimeout) + + connTimeout := newConn(1) + timeOutHandler.NewConnection(connTimeout) + + connNoTimeout := newConn(2) + noTimeOutHandler.NewConnection(connNoTimeout) + + err := timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result) error { + 
return nil + }) + require.EqualError(err, "row read wait bigger than connection timeout") + + err = timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(0.5)", func(res *sqltypes.Result) error { + return nil + }) + require.NoError(err) + + err = noTimeOutHandler.ComQuery(connNoTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result) error { + return nil + }) + require.NoError(err) +} + +func TestOkClosedConnection(t *testing.T) { + require := require.New(t) + e := setupMemDB(require) + port, err := getFreePort() + require.NoError(err) + + ready := make(chan struct{}) + go okTestServer(t, ready, port) + <-ready + conn, err := net.Dial("tcp", "localhost:"+port) + require.NoError(err) + defer func() { + _ = conn.Close() + }() + + h := NewHandler( + e, + NewSessionManager( + testSessionBuilder, + opentracing.NoopTracer{}, + sql.NewMemoryManager(nil), + "foo", + ), + 0, + ) + h.AddNetConnection(&conn) + c := newConn(1) + h.NewConnection(c) + + q := fmt.Sprintf("SELECT SLEEP(%d)", tcpCheckerSleepTime*4) + err = h.ComQuery(c, q, func(res *sqltypes.Result) error { + return nil + }) + require.NoError(err) } diff --git a/server/handler_test_common.go b/server/handler_test_common.go new file mode 100644 index 000000000..2abecff30 --- /dev/null +++ b/server/handler_test_common.go @@ -0,0 +1,108 @@ +package server + +import ( + "io/ioutil" + "net" + "reflect" + "strconv" + "testing" + "unsafe" + + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql" +) + +func setupMemDB(require *require.Assertions) *sqle.Engine { + e := sqle.NewDefault() + db := memory.NewDatabase("test") + e.AddDatabase(db) + + tableTest := memory.NewTable("test", sql.Schema{{Name: "c1", Type: sql.Int32, Source: "test"}}) + + for i := 0; i < 1010; i++ { + require.NoError(tableTest.Insert( + sql.NewEmptyContext(), + sql.NewRow(int32(i)), + )) + } + + db.AddTable("test", 
tableTest) + + return e +} + +func getFreePort() (string, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return "", err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return "", err + } + defer l.Close() + return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil +} + +func testServer(t *testing.T, ready chan struct{}, port string, breakConn bool) { + l, err := net.Listen("tcp", ":"+port) + defer func() { + _ = l.Close() + }() + if err != nil { + t.Fatal(err) + } + close(ready) + conn, err := l.Accept() + if err != nil { + return + } + + if !breakConn { + defer func() { + _ = conn.Close() + }() + + _, err = ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + } // else: dirty return without closing or reading to force the socket into TIME_WAIT +} +func okTestServer(t *testing.T, ready chan struct{}, port string) { + testServer(t, ready, port, false) +} +func brokenTestServer(t *testing.T, ready chan struct{}, port string) { + testServer(t, ready, port, true) +} + +// This session builder is used as dummy mysql Conn is not complete and +// causes panic when accessing remote address. 
+func testSessionBuilder(c *mysql.Conn, addr string) sql.Session { + const client = "127.0.0.1:34567" + return sql.NewSession(addr, client, c.User, c.ConnectionID) +} + +type mockConn struct { + net.Conn +} + +func (c *mockConn) Close() error { return nil } + +func newConn(id uint32) *mysql.Conn { + conn := &mysql.Conn{ + ConnectionID: id, + } + + // Set conn so it does not panic when we close it + val := reflect.ValueOf(conn).Elem() + field := val.FieldByName("conn") + field = reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() + field.Set(reflect.ValueOf(new(mockConn))) + + return conn +} diff --git a/server/listener.go b/server/listener.go new file mode 100644 index 000000000..f1bc018f1 --- /dev/null +++ b/server/listener.go @@ -0,0 +1,29 @@ +package server + +import ( + "net" +) + +type Listener struct { + net.Listener + h *Handler +} + +// NewListener creates a new Listener. +func NewListener(protocol, address string, handler *Handler) (*Listener, error) { + l, err := net.Listen(protocol, address) + if err != nil { + return nil, err + } + return &Listener{l, handler}, nil +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + l.h.AddNetConnection(&conn) + return conn, err +} diff --git a/server/server.go b/server/server.go index 7a4ed5101..15fdb45c1 100644 --- a/server/server.go +++ b/server/server.go @@ -1,15 +1,19 @@ package server import ( - opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0" + "time" - "gopkg.in/src-d/go-vitess.v0/mysql" + "github.com/opentracing/opentracing-go" + sqle "github.com/src-d/go-mysql-server" + "github.com/src-d/go-mysql-server/auth" + + "vitess.io/vitess/go/mysql" ) // Server is a MySQL server for SQLe engines. type Server struct { Listener *mysql.Listener + h *Handler } // Config for the mysql server. @@ -19,10 +23,13 @@ type Config struct { // Address of the server. 
Address string // Auth of the server. - Auth mysql.AuthServer + Auth auth.Auth // Tracer to use in the server. By default, a noop tracer will be used if // no tracer is provided. Tracer opentracing.Tracer + + ConnReadTimeout time.Duration + ConnWriteTimeout time.Duration } // NewDefaultServer creates a Server with the default session builder. @@ -40,13 +47,31 @@ func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder) (*Server, error) { tracer = opentracing.NoopTracer{} } - handler := NewHandler(e, NewSessionManager(sb, tracer)) - l, err := mysql.NewListener(cfg.Protocol, cfg.Address, cfg.Auth, handler) + if cfg.ConnReadTimeout < 0 { + cfg.ConnReadTimeout = 0 + } + + if cfg.ConnWriteTimeout < 0 { + cfg.ConnWriteTimeout = 0 + } + + handler := NewHandler(e, + NewSessionManager( + sb, tracer, + e.Catalog.MemoryManager, + cfg.Address), + cfg.ConnReadTimeout) + a := cfg.Auth.Mysql() + l, err := NewListener(cfg.Protocol, cfg.Address, handler) + if err != nil { + return nil, err + } + vtListnr, err := mysql.NewFromListener(l, a, handler, cfg.ConnReadTimeout, cfg.ConnWriteTimeout) if err != nil { return nil, err } - return &Server{Listener: l}, nil + return &Server{Listener: vtListnr, h: handler}, nil } // Start starts accepting connections on the server. 
diff --git a/sql/analyzer/aggregations.go b/sql/analyzer/aggregations.go new file mode 100644 index 000000000..9afa5b49f --- /dev/null +++ b/sql/analyzer/aggregations.go @@ -0,0 +1,119 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func reorderAggregations(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("reorder_aggregations") + defer span.Finish() + + if !n.Resolved() { + return n, nil + } + + a.Log("reorder aggregations, node of type: %T", n) + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.GroupBy: + if !hasHiddenAggregations(n.Aggregate...) { + return n, nil + } + + a.Log("fixing aggregations of node of type: %T", n) + + return fixAggregations(n.Aggregate, n.Grouping, n.Child) + default: + return n, nil + } + }) +} + +func fixAggregations(projection, grouping []sql.Expression, child sql.Node) (sql.Node, error) { + var aggregate = make([]sql.Expression, 0, len(projection)) + var newProjection = make([]sql.Expression, len(projection)) + + for i, p := range projection { + var transformed bool + e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) { + agg, ok := e.(sql.Aggregation) + if !ok { + return e, nil + } + + transformed = true + aggregate = append(aggregate, agg) + return expression.NewGetField( + len(aggregate)-1, agg.Type(), agg.String(), agg.IsNullable(), + ), nil + }) + if err != nil { + return nil, err + } + + if !transformed { + aggregate = append(aggregate, e) + name, source := getNameAndSource(e) + newProjection[i] = expression.NewGetFieldWithTable( + len(aggregate)-1, e.Type(), source, name, e.IsNullable(), + ) + } else { + newProjection[i] = e + } + } + + return plan.NewProject( + newProjection, + plan.NewGroupBy(aggregate, grouping, child), + ), nil +} + +func getNameAndSource(e sql.Expression) 
(name, source string) { + if n, ok := e.(sql.Nameable); ok { + name = n.Name() + } else { + name = e.String() + } + + if t, ok := e.(sql.Tableable); ok { + source = t.Table() + } + + return +} + +// hasHiddenAggregations reports whether any of the given expressions has a +// hidden aggregation. That is, an aggregation that is not at the root of the +// expression. +func hasHiddenAggregations(exprs ...sql.Expression) bool { + for _, e := range exprs { + if containsHiddenAggregation(e) { + return true + } + } + return false +} + +func containsHiddenAggregation(e sql.Expression) bool { + _, ok := e.(sql.Aggregation) + if ok { + return false + } + + return containsAggregation(e) +} + +func containsAggregation(e sql.Expression) bool { + var hasAgg bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := e.(sql.Aggregation); ok { + hasAgg = true + return false + } + return true + }) + return hasAgg +} diff --git a/sql/analyzer/aggregations_test.go b/sql/analyzer/aggregations_test.go new file mode 100644 index 000000000..b4333becd --- /dev/null +++ b/sql/analyzer/aggregations_test.go @@ -0,0 +1,128 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestReorderAggregations(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + {Name: "b", Type: sql.Int64, Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }) + rule := getRule("reorder_aggregations") + + node := plan.NewGroupBy( + []sql.Expression{ + expression.NewArithmetic( + aggregation.NewSum( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + expression.NewLiteral(int64(1), 
sql.Int64), + "+", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewArithmetic( + expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewLiteral(int64(1), sql.Int64), + "+", + ), + }, + plan.NewGroupBy( + []sql.Expression{ + aggregation.NewSum( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + require.Equal(expected, result) +} + +func TestReorderAggregationsMultiple(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + {Name: "b", Type: sql.Int64, Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }) + rule := getRule("reorder_aggregations") + + node := plan.NewGroupBy( + []sql.Expression{ + expression.NewArithmetic( + aggregation.NewSum( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + aggregation.NewCount( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + "/", + ), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "b", false), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "b", false), + }, + plan.NewResolvedTable(table), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewArithmetic( + expression.NewGetField(0, sql.Float64, "SUM(foo.a)", false), + expression.NewGetField(1, sql.Int64, "COUNT(foo.a)", false), + "/", + ), + expression.NewGetFieldWithTable(2, sql.Int64, "foo", "b", false), + }, + plan.NewGroupBy( + []sql.Expression{ + aggregation.NewSum( 
+ expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + aggregation.NewCount( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + ), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "b", false), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "b", false), + }, + plan.NewResolvedTable(table), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + require.Equal(expected, result) +} diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 7c6c10935..f7bc456d5 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -2,76 +2,152 @@ package analyzer import ( "os" - "reflect" - multierror "github.com/hashicorp/go-multierror" + opentracing "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" + "github.com/src-d/go-mysql-server/sql" "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) +const debugAnalyzerKey = "DEBUG_ANALYZER" + const maxAnalysisIterations = 1000 // ErrMaxAnalysisIters is thrown when the analysis iterations are exceeded var ErrMaxAnalysisIters = errors.NewKind("exceeded max analysis iterations (%d)") -// Analyzer analyzes nodes of the execution plan and applies rules and validations -// to them. -type Analyzer struct { - Debug bool - // Rules to apply. - Rules []Rule - // ValidationRules to apply. - ValidationRules []ValidationRule - // Catalog of databases and registered functions. - Catalog *sql.Catalog - // CurrentDatabase in use. - CurrentDatabase string +// Builder provides an easy way to generate Analyzer with custom rules and options. +type Builder struct { + preAnalyzeRules []Rule + postAnalyzeRules []Rule + preValidationRules []Rule + postValidationRules []Rule + catalog *sql.Catalog + debug bool + parallelism int } -// RuleFunc is the function to be applied in a rule. 
-type RuleFunc func(*sql.Context, *Analyzer, sql.Node) (sql.Node, error) +// NewBuilder creates a new Builder from a specific catalog. +// This builder allow us add custom Rules and modify some internal properties. +func NewBuilder(c *sql.Catalog) *Builder { + return &Builder{catalog: c} +} -// ValidationRuleFunc is the function to be used in a validation rule. -type ValidationRuleFunc func(*sql.Context, sql.Node) error +// WithDebug activates debug on the Analyzer. +func (ab *Builder) WithDebug() *Builder { + ab.debug = true -// Rule to transform nodes. -type Rule struct { - // Name of the rule. - Name string - // Apply transforms a node. - Apply RuleFunc + return ab } -// ValidationRule validates the given nodes. -type ValidationRule struct { - // Name of the rule. - Name string - // Apply validates the given node. - Apply ValidationRuleFunc +// WithParallelism sets the parallelism level on the analyzer. +func (ab *Builder) WithParallelism(parallelism int) *Builder { + ab.parallelism = parallelism + return ab } -const debugAnalyzerKey = "DEBUG_ANALYZER" +// AddPreAnalyzeRule adds a new rule to the analyze before the standard analyzer rules. +func (ab *Builder) AddPreAnalyzeRule(name string, fn RuleFunc) *Builder { + ab.preAnalyzeRules = append(ab.preAnalyzeRules, Rule{name, fn}) + + return ab +} + +// AddPostAnalyzeRule adds a new rule to the analyzer after standard analyzer rules. +func (ab *Builder) AddPostAnalyzeRule(name string, fn RuleFunc) *Builder { + ab.postAnalyzeRules = append(ab.postAnalyzeRules, Rule{name, fn}) + + return ab +} + +// AddPreValidationRule adds a new rule to the analyzer before standard validation rules. +func (ab *Builder) AddPreValidationRule(name string, fn RuleFunc) *Builder { + ab.preValidationRules = append(ab.preValidationRules, Rule{name, fn}) -// New returns a new Analyzer given a catalog. 
-func New(catalog *sql.Catalog) *Analyzer { + return ab +} + +// AddPostValidationRule adds a new rule to the analyzer after standard validation rules. +func (ab *Builder) AddPostValidationRule(name string, fn RuleFunc) *Builder { + ab.postValidationRules = append(ab.postValidationRules, Rule{name, fn}) + + return ab +} + +// Build creates a new Analyzer using all previous data setted to the Builder +func (ab *Builder) Build() *Analyzer { _, debug := os.LookupEnv(debugAnalyzerKey) + var batches = []*Batch{ + &Batch{ + Desc: "pre-analyzer rules", + Iterations: maxAnalysisIterations, + Rules: ab.preAnalyzeRules, + }, + &Batch{ + Desc: "once execution rule before default", + Iterations: 1, + Rules: OnceBeforeDefault, + }, + &Batch{ + Desc: "analyzer rules", + Iterations: maxAnalysisIterations, + Rules: DefaultRules, + }, + &Batch{ + Desc: "once execution rules after default", + Iterations: 1, + Rules: OnceAfterDefault, + }, + &Batch{ + Desc: "post-analyzer rules", + Iterations: maxAnalysisIterations, + Rules: ab.postAnalyzeRules, + }, + &Batch{ + Desc: "pre-validation rules", + Iterations: 1, + Rules: ab.preValidationRules, + }, + &Batch{ + Desc: "validation rules", + Iterations: 1, + Rules: DefaultValidationRules, + }, + &Batch{ + Desc: "post-validation rules", + Iterations: 1, + Rules: ab.postValidationRules, + }, + &Batch{ + Desc: "after-all rules", + Iterations: 1, + Rules: OnceAfterAll, + }, + } + return &Analyzer{ - Debug: debug, - Rules: DefaultRules, - ValidationRules: DefaultValidationRules, - Catalog: catalog, + Debug: debug || ab.debug, + Batches: batches, + Catalog: ab.catalog, + Parallelism: ab.parallelism, } } -// AddRule adds a new rule to the analyzer. -func (a *Analyzer) AddRule(name string, fn RuleFunc) { - a.Rules = append(a.Rules, Rule{name, fn}) +// Analyzer analyzes nodes of the execution plan and applies rules and validations +// to them. +type Analyzer struct { + Debug bool + Parallelism int + // Batches of Rules to apply. 
+ Batches []*Batch + // Catalog of databases and registered functions. + Catalog *sql.Catalog } -// AddValidationRule adds a new validation rule to the analyzer. -func (a *Analyzer) AddValidationRule(name string, fn ValidationRuleFunc) { - a.ValidationRules = append(a.ValidationRules, ValidationRule{name, fn}) +// NewDefault creates a default Analyzer instance with all default Rules and configuration. +// To add custom rules, the easiest way is use the Builder. +func NewDefault(c *sql.Catalog) *Analyzer { + return NewBuilder(c).Build() } // Log prints an INFO message to stdout with the given message and args @@ -84,85 +160,34 @@ func (a *Analyzer) Log(msg string, args ...interface{}) { // Analyze the node and all its children. func (a *Analyzer) Analyze(ctx *sql.Context, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("analyze") - span.LogKV("plan", n.String()) + span, ctx := ctx.Span("analyze", opentracing.Tags{ + "plan": n.String(), + }) prev := n + var err error a.Log("starting analysis of node of type: %T", n) - cur, err := a.analyzeOnce(ctx, n) - defer func() { - if cur != nil { - span.SetTag("IsResolved", cur.Resolved()) + for _, batch := range a.Batches { + prev, err = batch.Eval(ctx, a, prev) + if ErrMaxAnalysisIters.Is(err) { + a.Log(err.Error()) + continue } - span.Finish() - }() - - if err != nil { - return nil, err - } - - for i := 0; !reflect.DeepEqual(prev, cur); { - a.Log("previous node does not match new node, analyzing again, iteration: %d", i) - prev = cur - cur, err = a.analyzeOnce(ctx, cur) if err != nil { return nil, err } - - i++ - if i >= maxAnalysisIterations { - return cur, ErrMaxAnalysisIters.New(maxAnalysisIterations) - } - } - - if errs := a.validate(ctx, cur); len(errs) != 0 { - for _, e := range errs { - err = multierror.Append(err, e) - } } - return cur, err -} - -func (a *Analyzer) analyzeOnce(ctx *sql.Context, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("analyze_once") - span.LogKV("plan", n.String()) - defer 
span.Finish() - - result := n - for _, rule := range a.Rules { - var err error - result, err = rule.Apply(ctx, a, result) - if err != nil { - return nil, err + defer func() { + if prev != nil { + span.SetTag("IsResolved", prev.Resolved()) } - } - return result, nil -} - -func (a *Analyzer) validate(ctx *sql.Context, n sql.Node) (validationErrors []error) { - span, ctx := ctx.Span("validate") - defer span.Finish() - - validationErrors = append(validationErrors, a.validateOnce(ctx, n)...) - - for _, node := range n.Children() { - validationErrors = append(validationErrors, a.validate(ctx, node)...) - } + span.Finish() + }() - return validationErrors + return prev, err } -func (a *Analyzer) validateOnce(ctx *sql.Context, n sql.Node) (validationErrors []error) { - span, ctx := ctx.Span("validate_once") - defer span.Finish() - - for _, rule := range a.ValidationRules { - err := rule.Apply(ctx, n) - if err != nil { - validationErrors = append(validationErrors, err) - } - } - - return validationErrors +type equaler interface { + Equal(sql.Node) bool } diff --git a/sql/analyzer/analyzer_test.go b/sql/analyzer/analyzer_test.go index f41d9c315..53f85fcd5 100644 --- a/sql/analyzer/analyzer_test.go +++ b/sql/analyzer/analyzer_test.go @@ -4,89 +4,105 @@ import ( "fmt" "testing" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" - + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" "github.com/stretchr/testify/require" ) func TestAnalyzer_Analyze(t *testing.T) { require := require.New(t) - table := mem.NewTable("mytable", sql.Schema{ + table := memory.NewTable("mytable", sql.Schema{ {Name: "i", Type: sql.Int32, Source: "mytable"}, {Name: "t", Type: sql.Text, Source: "mytable"}, }) - table2 := mem.NewTable("mytable2", 
sql.Schema{{Name: "i2", Type: sql.Int32, Source: "mytable2"}}) - db := mem.NewDatabase("mydb") + + table2 := memory.NewTable("mytable2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "mytable2"}, + }) + + db := memory.NewDatabase("mydb") db.AddTable("mytable", table) db.AddTable("mytable2", table2) - catalog := &sql.Catalog{Databases: []sql.Database{db}} - a := New(catalog) - a.CurrentDatabase = "mydb" + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + a := withoutProcessTracking(NewDefault(catalog)) - var notAnalyzed sql.Node = plan.NewUnresolvedTable("mytable") + var notAnalyzed sql.Node = plan.NewUnresolvedTable("mytable", "") analyzed, err := a.Analyze(sql.NewEmptyContext(), notAnalyzed) require.NoError(err) - require.Equal(table, analyzed) + require.Equal( + plan.NewResolvedTable(table), + analyzed, + ) - notAnalyzed = plan.NewUnresolvedTable("nonexistant") + notAnalyzed = plan.NewUnresolvedTable("nonexistant", "") analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) require.Error(err) require.Nil(analyzed) - analyzed, err = a.Analyze(sql.NewEmptyContext(), table) + analyzed, err = a.Analyze(sql.NewEmptyContext(), plan.NewResolvedTable(table)) require.NoError(err) - require.Equal(table, analyzed) + require.Equal( + plan.NewResolvedTable(table), + analyzed, + ) notAnalyzed = plan.NewProject( []sql.Expression{expression.NewUnresolvedColumn("o")}, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ) _, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) require.Error(err) notAnalyzed = plan.NewProject( []sql.Expression{expression.NewUnresolvedColumn("i")}, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) - var expected sql.Node = plan.NewProject( - []sql.Expression{expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false)}, - table, + var expected sql.Node = plan.NewResolvedTable( + 
table.WithProjection([]string{"i"}), ) require.NoError(err) require.Equal(expected, analyzed) notAnalyzed = plan.NewDescribe( - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) - expected = plan.NewDescribe(table) + expected = plan.NewDescribe( + plan.NewResolvedTable(table), + ) require.NoError(err) require.Equal(expected, analyzed) notAnalyzed = plan.NewProject( []sql.Expression{expression.NewStar()}, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) require.NoError(err) - require.Equal(table, analyzed) + require.Equal( + plan.NewResolvedTable(table.WithProjection([]string{"i", "t"})), + analyzed, + ) notAnalyzed = plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewProject( []sql.Expression{expression.NewStar()}, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) require.NoError(err) - require.Equal(table, analyzed) + require.Equal( + plan.NewResolvedTable(table.WithProjection([]string{"i", "t"})), + analyzed, + ) notAnalyzed = plan.NewProject( []sql.Expression{ @@ -95,7 +111,7 @@ func TestAnalyzer_Analyze(t *testing.T) { "foo", ), }, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) expected = plan.NewProject( @@ -105,7 +121,7 @@ func TestAnalyzer_Analyze(t *testing.T) { "foo", ), }, - table, + plan.NewResolvedTable(table.WithProjection([]string{"i"})), ) require.NoError(err) require.Equal(expected, analyzed) @@ -117,42 +133,36 @@ func TestAnalyzer_Analyze(t *testing.T) { expression.NewUnresolvedColumn("i"), expression.NewLiteral(int32(1), sql.Int32), ), - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ), ) analyzed, err = 
a.Analyze(sql.NewEmptyContext(), notAnalyzed) - expected = plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - }, - plan.NewFilter( + expected = plan.NewResolvedTable( + table.WithFilters([]sql.Expression{ expression.NewEquals( expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), expression.NewLiteral(int32(1), sql.Int32), ), - table, - ), + }).(*memory.Table).WithProjection([]string{"i"}), ) require.NoError(err) require.Equal(expected, analyzed) + // notAnalyzed = plan.NewProject( []sql.Expression{ expression.NewUnresolvedColumn("i"), expression.NewUnresolvedColumn("i2"), }, plan.NewCrossJoin( - plan.NewUnresolvedTable("mytable"), - plan.NewUnresolvedTable("mytable2"), + plan.NewUnresolvedTable("mytable", ""), + plan.NewUnresolvedTable("mytable2", ""), ), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) - expected = plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "i2", false), - }, - plan.NewCrossJoin(table, table2), + expected = plan.NewCrossJoin( + plan.NewResolvedTable(table.WithProjection([]string{"i"})), + plan.NewResolvedTable(table2.WithProjection([]string{"i2"})), ) require.NoError(err) require.Equal(expected, analyzed) @@ -162,58 +172,262 @@ func TestAnalyzer_Analyze(t *testing.T) { []sql.Expression{ expression.NewUnresolvedColumn("i"), }, - plan.NewUnresolvedTable("mytable"), + plan.NewUnresolvedTable("mytable", ""), ), ) analyzed, err = a.Analyze(sql.NewEmptyContext(), notAnalyzed) - expected = plan.NewLimit(int64(1), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - }, - table, - ), + expected = plan.NewLimit( + int64(1), + plan.NewResolvedTable(table.WithProjection([]string{"i"})), ) - require.Nil(err) + require.NoError(err) require.Equal(expected, analyzed) } -func 
TestAnalyzer_Analyze_MaxIterations(t *testing.T) { +func TestMaxIterations(t *testing.T) { require := require.New(t) + tName := "my-table" + table := memory.NewTable(tName, sql.Schema{ + {Name: "i", Type: sql.Int32, Source: tName}, + {Name: "t", Type: sql.Text, Source: tName}, + }) + db := memory.NewDatabase("mydb") + db.AddTable(tName, table) - catalog := &sql.Catalog{} - a := New(catalog) - a.CurrentDatabase = "mydb" + catalog := sql.NewCatalog() + catalog.AddDatabase(db) - i := 0 - a.Rules = []Rule{{ - Name: "infinite", - Apply: func(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - i++ - return plan.NewUnresolvedTable(fmt.Sprintf("table%d", i)), nil - }, - }} + count := 0 + a := withoutProcessTracking(NewBuilder(catalog).AddPostAnalyzeRule("loop", + func(c *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - notAnalyzed := plan.NewUnresolvedTable("mytable") + switch n.(type) { + case *plan.ResolvedTable: + count++ + name := fmt.Sprintf("mytable-%v", count) + table := memory.NewTable(name, sql.Schema{ + {Name: "i", Type: sql.Int32, Source: name}, + {Name: "t", Type: sql.Text, Source: name}, + }) + n = plan.NewResolvedTable(table) + } + + return n, nil + }).Build()) + + notAnalyzed := plan.NewUnresolvedTable(tName, "") analyzed, err := a.Analyze(sql.NewEmptyContext(), notAnalyzed) - require.NotNil(err) - require.Equal(plan.NewUnresolvedTable("table1001"), analyzed) + require.NoError(err) + require.Equal( + plan.NewResolvedTable( + memory.NewTable("mytable-1000", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable-1000"}, + {Name: "t", Type: sql.Text, Source: "mytable-1000"}, + }), + ), + analyzed, + ) + require.Equal(1000, count) } func TestAddRule(t *testing.T) { require := require.New(t) - a := New(nil) - require.Len(a.Rules, 13) - a.AddRule("foo", pushdown) - require.Len(a.Rules, 14) + defRulesCount := countRules(NewDefault(nil).Batches) + + a := NewBuilder(nil).AddPostAnalyzeRule("foo", pushdown).Build() + + 
require.Equal(countRules(a.Batches), defRulesCount+1) +} + +func TestAddPreValidationRule(t *testing.T) { + require := require.New(t) + + defRulesCount := countRules(NewDefault(nil).Batches) + + a := NewBuilder(nil).AddPreValidationRule("foo", pushdown).Build() + + require.Equal(countRules(a.Batches), defRulesCount+1) } -func TestAddValidationRule(t *testing.T) { +func TestAddPostValidationRule(t *testing.T) { require := require.New(t) - a := New(nil) - require.Len(a.ValidationRules, 6) - a.AddValidationRule("foo", validateGroupBy) - require.Len(a.ValidationRules, 7) + defRulesCount := countRules(NewDefault(nil).Batches) + + a := NewBuilder(nil).AddPostValidationRule("foo", pushdown).Build() + + require.Equal(countRules(a.Batches), defRulesCount+1) +} + +func countRules(batches []*Batch) int { + var count int + for _, b := range batches { + count = count + len(b.Rules) + } + return count + +} + +func TestMixInnerAndNaturalJoins(t *testing.T) { + var require = require.New(t) + + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable"}, + {Name: "f", Type: sql.Float64, Source: "mytable"}, + {Name: "t", Type: sql.Text, Source: "mytable"}, + }) + + table2 := memory.NewTable("mytable2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "mytable2"}, + {Name: "f2", Type: sql.Float64, Source: "mytable2"}, + {Name: "t2", Type: sql.Text, Source: "mytable2"}, + }) + + table3 := memory.NewTable("mytable3", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable3"}, + {Name: "f2", Type: sql.Float64, Source: "mytable3"}, + {Name: "t3", Type: sql.Text, Source: "mytable3"}, + }) + + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + db.AddTable("mytable2", table2) + db.AddTable("mytable3", table3) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + a := withoutProcessTracking(NewDefault(catalog)) + + node := plan.NewProject( + []sql.Expression{ + expression.NewStar(), + }, + plan.NewNaturalJoin( + 
plan.NewInnerJoin( + plan.NewUnresolvedTable("mytable", ""), + plan.NewUnresolvedTable("mytable2", ""), + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + expression.NewUnresolvedQualifiedColumn("mytable2", "i2"), + ), + ), + plan.NewUnresolvedTable("mytable3", ""), + ), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(3, sql.Float64, "mytable2", "f2", false), + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewGetFieldWithTable(2, sql.Text, "mytable", "t", false), + expression.NewGetFieldWithTable(4, sql.Int32, "mytable2", "i2", false), + expression.NewGetFieldWithTable(5, sql.Text, "mytable2", "t2", false), + expression.NewGetFieldWithTable(6, sql.Text, "mytable3", "t3", false), + }, + plan.NewInnerJoin( + plan.NewInnerJoin( + plan.NewResolvedTable(table.WithProjection([]string{"i", "f", "t"})), + plan.NewResolvedTable(table2.WithProjection([]string{"f2", "i2", "t2"})), + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(4, sql.Int32, "mytable2", "i2", false), + ), + ), + plan.NewResolvedTable(table3.WithProjection([]string{"t3", "i", "f2"})), + expression.NewAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewGetFieldWithTable(7, sql.Int32, "mytable3", "i", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(3, sql.Float64, "mytable2", "f2", false), + expression.NewGetFieldWithTable(8, sql.Float64, "mytable3", "f2", false), + ), + ), + ), + ) + + result, err := a.Analyze(sql.NewEmptyContext(), node) + require.NoError(err) + require.Equal(expected, result) +} + +func TestReorderProjectionUnresolvedChild(t *testing.T) { + require := require.New(t) + node := plan.NewProject( + []sql.Expression{ + 
expression.NewUnresolvedQualifiedColumn("rc", "commit_hash"), + expression.NewUnresolvedColumn("commit_author_when"), + }, + plan.NewFilter( + expression.JoinAnd( + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("rc", "repository_id"), + expression.NewLiteral("foo", sql.Text), + ), + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("rc", "ref_name"), + expression.NewLiteral("HEAD", sql.Text), + ), + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("rc", "history_index"), + expression.NewLiteral(int64(0), sql.Int64), + ), + ), + plan.NewNaturalJoin( + plan.NewInnerJoin( + plan.NewUnresolvedTable("refs", ""), + plan.NewTableAlias("rc", + plan.NewUnresolvedTable("ref_commits", ""), + ), + expression.NewAnd( + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("refs", "ref_name"), + expression.NewUnresolvedQualifiedColumn("rc", "ref_name"), + ), + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("refs", "repository_id"), + expression.NewUnresolvedQualifiedColumn("rc", "repository_id"), + ), + ), + ), + plan.NewTableAlias("c", + plan.NewUnresolvedTable("commits", ""), + ), + ), + ), + ) + + commits := memory.NewTable("commits", sql.Schema{ + {Name: "repository_id", Source: "commits", Type: sql.Text}, + {Name: "commit_hash", Source: "commits", Type: sql.Text}, + {Name: "commit_author_when", Source: "commits", Type: sql.Text}, + }) + + refs := memory.NewTable("refs", sql.Schema{ + {Name: "repository_id", Source: "refs", Type: sql.Text}, + {Name: "ref_name", Source: "refs", Type: sql.Text}, + }) + + refCommits := memory.NewTable("ref_commits", sql.Schema{ + {Name: "repository_id", Source: "ref_commits", Type: sql.Text}, + {Name: "ref_name", Source: "ref_commits", Type: sql.Text}, + {Name: "commit_hash", Source: "ref_commits", Type: sql.Text}, + {Name: "history_index", Source: "ref_commits", Type: sql.Int64}, + }) + + db := memory.NewDatabase("") + db.AddTable("refs", refs) + 
db.AddTable("ref_commits", refCommits) + db.AddTable("commits", commits) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + a := withoutProcessTracking(NewDefault(catalog)) + + result, err := a.Analyze(sql.NewEmptyContext(), node) + require.NoError(err) + require.True(result.Resolved()) } diff --git a/sql/analyzer/assign_catalog.go b/sql/analyzer/assign_catalog.go new file mode 100644 index 000000000..0a8f76ee8 --- /dev/null +++ b/sql/analyzer/assign_catalog.go @@ -0,0 +1,67 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +// assignCatalog sets the catalog in the required nodes. +func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("assign_catalog") + defer span.Finish() + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + if !n.Resolved() { + return n, nil + } + + switch node := n.(type) { + case *plan.CreateIndex: + nc := *node + nc.Catalog = a.Catalog + nc.CurrentDatabase = a.Catalog.CurrentDatabase() + return &nc, nil + case *plan.DropIndex: + nc := *node + nc.Catalog = a.Catalog + nc.CurrentDatabase = a.Catalog.CurrentDatabase() + return &nc, nil + case *plan.ShowIndexes: + nc := *node + nc.Registry = a.Catalog.IndexRegistry + return &nc, nil + case *plan.ShowDatabases: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil + case *plan.ShowCreateTable: + nc := *node + nc.Catalog = a.Catalog + nc.CurrentDatabase = a.Catalog.CurrentDatabase() + return &nc, nil + case *plan.ShowProcessList: + nc := *node + nc.Database = a.Catalog.CurrentDatabase() + nc.ProcessList = a.Catalog.ProcessList + return &nc, nil + case *plan.ShowTableStatus: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil + case *plan.Use: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil + case *plan.LockTables: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil + case *plan.UnlockTables: + nc := *node + nc.Catalog = a.Catalog 
+ return &nc, nil + default: + return n, nil + } + }) +} diff --git a/sql/analyzer/assign_catalog_test.go b/sql/analyzer/assign_catalog_test.go new file mode 100644 index 000000000..39aed1a4f --- /dev/null +++ b/sql/analyzer/assign_catalog_test.go @@ -0,0 +1,76 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestAssignCatalog(t *testing.T) { + require := require.New(t) + f := getRule("assign_catalog") + + db := memory.NewDatabase("foo") + c := sql.NewCatalog() + c.AddDatabase(db) + + a := NewDefault(c) + a.Catalog.IndexRegistry = sql.NewIndexRegistry() + + tbl := memory.NewTable("foo", nil) + + node, err := f.Apply(sql.NewEmptyContext(), a, + plan.NewCreateIndex("", plan.NewResolvedTable(tbl), nil, "", make(map[string]string))) + require.NoError(err) + + ci, ok := node.(*plan.CreateIndex) + require.True(ok) + require.Equal(c, ci.Catalog) + require.Equal("foo", ci.CurrentDatabase) + + node, err = f.Apply(sql.NewEmptyContext(), a, + plan.NewDropIndex("foo", plan.NewResolvedTable(tbl))) + require.NoError(err) + + di, ok := node.(*plan.DropIndex) + require.True(ok) + require.Equal(c, di.Catalog) + require.Equal("foo", di.CurrentDatabase) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewShowIndexes(db, "table-test", nil)) + require.NoError(err) + + si, ok := node.(*plan.ShowIndexes) + require.True(ok) + require.Equal(db, si.Database()) + require.Equal(c.IndexRegistry, si.Registry) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewShowProcessList()) + require.NoError(err) + + pl, ok := node.(*plan.ShowProcessList) + require.True(ok) + require.Equal(db.Name(), pl.Database) + require.Equal(c.ProcessList, pl.ProcessList) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewShowDatabases()) + require.NoError(err) + sd, ok := node.(*plan.ShowDatabases) + require.True(ok) + 
require.Equal(c, sd.Catalog) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewLockTables(nil)) + require.NoError(err) + lt, ok := node.(*plan.LockTables) + require.True(ok) + require.Equal(c, lt.Catalog) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewUnlockTables()) + require.NoError(err) + ut, ok := node.(*plan.UnlockTables) + require.True(ok) + require.Equal(c, ut.Catalog) +} diff --git a/sql/analyzer/assign_indexes.go b/sql/analyzer/assign_indexes.go new file mode 100644 index 000000000..b6bcb1327 --- /dev/null +++ b/sql/analyzer/assign_indexes.go @@ -0,0 +1,790 @@ +package analyzer + +import ( + "reflect" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errInvalidInRightEvaluation = errors.NewKind("expecting evaluation of IN expression right hand side to be a tuple, but it is %T") + +// indexLookup contains an sql.IndexLookup and all sql.Index that are involved +// in it. 
+type indexLookup struct { + lookup sql.IndexLookup + indexes []sql.Index +} + +func assignIndexes(a *Analyzer, node sql.Node) (map[string]*indexLookup, error) { + a.Log("assigning indexes, node of type: %T", node) + + var indexes map[string]*indexLookup + // release all unused indexes + defer func() { + if indexes == nil { + return + } + + for _, i := range indexes { + for _, index := range i.indexes { + a.Catalog.ReleaseIndex(index) + } + } + }() + + aliases := make(map[string]sql.Expression) + var ( + err error + fn func(node sql.Node) bool + ) + fn = func(n sql.Node) bool { + if n == nil { + return true + } + + if prj, ok := n.(*plan.Project); ok { + for _, ex := range prj.Expressions() { + if alias, ok := ex.(*expression.Alias); ok { + if _, ok := aliases[alias.Name()]; !ok { + aliases[alias.Name()] = alias.Child + } + } + } + } else { + for _, ch := range n.Children() { + plan.Inspect(ch, fn) + } + } + + return true + } + + plan.Inspect(node, func(node sql.Node) bool { + filter, ok := node.(*plan.Filter) + if !ok { + return true + } + fn(filter.Child) + + var result map[string]*indexLookup + result, err = getIndexes(filter.Expression, aliases, a) + if err != nil { + return false + } + + if indexes != nil { + indexes = indexesIntersection(a, indexes, result) + } else { + indexes = result + } + + return true + }) + + return indexes, err +} + +func getIndexes(e sql.Expression, aliases map[string]sql.Expression, a *Analyzer) (map[string]*indexLookup, error) { + var result = make(map[string]*indexLookup) + switch e := e.(type) { + case *expression.Or: + leftIndexes, err := getIndexes(e.Left, aliases, a) + if err != nil { + return nil, err + } + + rightIndexes, err := getIndexes(e.Right, aliases, a) + if err != nil { + return nil, err + } + + for table, idx := range leftIndexes { + if idx2, ok := rightIndexes[table]; ok && canMergeIndexes(idx.lookup, idx2.lookup) { + idx.lookup = idx.lookup.(sql.SetOperations).Union(idx2.lookup) + idx.indexes = append(idx.indexes, 
idx2.indexes...) + } + result[table] = idx + } + + // Put in the result map the indexes for tables we don't have indexes yet. + // The others were already handled by the previous loop. + for table, lookup := range rightIndexes { + if _, ok := result[table]; !ok { + result[table] = lookup + } + } + case *expression.In, *expression.NotIn: + c, ok := e.(expression.Comparer) + if !ok { + return nil, nil + } + + _, negate := e.(*expression.NotIn) + + // Take the index of a SOMETHING IN SOMETHING expression only if: + // the right branch is evaluable and the indexlookup supports set + // operations. + if !isEvaluable(c.Left()) && isEvaluable(c.Right()) { + idx := a.Catalog.IndexByExpression(a.Catalog.CurrentDatabase(), unifyExpressions(aliases, c.Left())...) + if idx != nil { + var nidx sql.NegateIndex + if negate { + nidx, ok = idx.(sql.NegateIndex) + if !ok { + return nil, nil + } + } + + // release the index if it was not used + defer func() { + if _, ok := result[idx.Table()]; !ok { + a.Catalog.ReleaseIndex(idx) + } + }() + + value, err := c.Right().Eval(sql.NewEmptyContext(), nil) + if err != nil { + return nil, err + } + + values, ok := value.([]interface{}) + if !ok { + return nil, errInvalidInRightEvaluation.New(value) + } + + var lookup sql.IndexLookup + var errLookup error + if negate { + lookup, errLookup = nidx.Not(values[0]) + } else { + lookup, errLookup = idx.Get(values[0]) + + } + + if errLookup != nil { + return nil, errLookup + } + + for _, v := range values[1:] { + var lookup2 sql.IndexLookup + var errLookup error + if negate { + lookup2, errLookup = nidx.Not(v) + } else { + lookup2, errLookup = idx.Get(v) + + } + + if errLookup != nil { + return nil, errLookup + } + + // if one of the indexes cannot be merged, return already + if !canMergeIndexes(lookup, lookup2) { + return result, nil + } + + if negate { + lookup = lookup.(sql.SetOperations).Intersection(lookup2) + } else { + lookup = lookup.(sql.SetOperations).Union(lookup2) + } + } + + result[idx.Table()] =
&indexLookup{ + indexes: []sql.Index{idx}, + lookup: lookup, + } + } + } + case *expression.Equals, + *expression.LessThan, + *expression.GreaterThan, + *expression.LessThanOrEqual, + *expression.GreaterThanOrEqual: + idx, lookup, err := getComparisonIndex(a, e.(expression.Comparer), aliases) + if err != nil || lookup == nil { + return result, err + } + + result[idx.Table()] = &indexLookup{ + indexes: []sql.Index{idx}, + lookup: lookup, + } + case *expression.Not: + r, err := getNegatedIndexes(a, e, aliases) + if err != nil { + return nil, err + } + + for table, indexLookup := range r { + result[table] = indexLookup + } + case *expression.Between: + if !isEvaluable(e.Val) && isEvaluable(e.Upper) && isEvaluable(e.Lower) { + idx := a.Catalog.IndexByExpression(a.Catalog.CurrentDatabase(), unifyExpressions(aliases, e.Val)...) + if idx != nil { + // release the index if it was not used + defer func() { + if _, ok := result[idx.Table()]; !ok { + a.Catalog.ReleaseIndex(idx) + } + }() + + upper, err := e.Upper.Eval(sql.NewEmptyContext(), nil) + if err != nil { + return nil, err + } + + lower, err := e.Lower.Eval(sql.NewEmptyContext(), nil) + if err != nil { + return nil, err + } + + lookup, err := betweenIndexLookup( + idx, + []interface{}{upper}, + []interface{}{lower}, + ) + if err != nil { + return nil, err + } + + if lookup != nil { + result[idx.Table()] = &indexLookup{ + indexes: []sql.Index{idx}, + lookup: lookup, + } + } + } + } + case *expression.And: + exprs := splitExpression(e) + used := make(map[sql.Expression]struct{}) + + result, err := getMultiColumnIndexes(exprs, a, used, aliases) + if err != nil { + return nil, err + } + + for _, e := range exprs { + if _, ok := used[e]; ok { + continue + } + + indexes, err := getIndexes(e, aliases, a) + if err != nil { + return nil, err + } + + result = indexesIntersection(a, result, indexes) + } + + return result, nil + } + + return result, nil +} + +func unifyExpressions(aliases map[string]sql.Expression, expr 
...sql.Expression) []sql.Expression { + expressions := make([]sql.Expression, len(expr)) + + for i, e := range expr { + uex := e + name := e.String() + if n, ok := e.(sql.Nameable); ok { + name = n.Name() + } + + if aliases != nil && len(aliases) > 0 { + if alias, ok := aliases[name]; ok { + uex = alias + } + } + + expressions[i] = uex + } + + return expressions +} + +func betweenIndexLookup(index sql.Index, upper, lower []interface{}) (sql.IndexLookup, error) { + ai, isAscend := index.(sql.AscendIndex) + di, isDescend := index.(sql.DescendIndex) + if isAscend && isDescend { + ascendLookup, err := ai.AscendRange(lower, upper) + if err != nil { + return nil, err + } + + descendLookup, err := di.DescendRange(upper, lower) + if err != nil { + return nil, err + } + + m, ok := ascendLookup.(sql.Mergeable) + if ok && m.IsMergeable(descendLookup) { + return ascendLookup.(sql.SetOperations).Union(descendLookup), nil + } + } + + return nil, nil +} + +// getComparisonIndex returns the index and index lookup for the given +// comparison if any index can be found. +// It works for the following comparisons: eq, lt, gt, gte and lte. +// TODO(erizocosmico): add support for BETWEEN once the appropiate interfaces +// can handle inclusiveness on both sides. +func getComparisonIndex( + a *Analyzer, + e expression.Comparer, + aliases map[string]sql.Expression, +) (sql.Index, sql.IndexLookup, error) { + left, right := e.Left(), e.Right() + // if the form is SOMETHING OP {INDEXABLE EXPR}, swap it, so it's {INDEXABLE EXPR} OP SOMETHING + if !isEvaluable(right) { + left, right = right, left + } + + if !isEvaluable(left) && isEvaluable(right) { + idx := a.Catalog.IndexByExpression(a.Catalog.CurrentDatabase(), unifyExpressions(aliases, left)...) 
+ if idx != nil { + value, err := right.Eval(sql.NewEmptyContext(), nil) + if err != nil { + a.Catalog.ReleaseIndex(idx) + return nil, nil, err + } + + lookup, err := comparisonIndexLookup(e, idx, value) + if err != nil || lookup == nil { + a.Catalog.ReleaseIndex(idx) + return nil, nil, err + } + + return idx, lookup, nil + } + } + + return nil, nil, nil +} + +func comparisonIndexLookup( + c expression.Comparer, + idx sql.Index, + values ...interface{}, +) (sql.IndexLookup, error) { + switch c.(type) { + case *expression.Equals: + return idx.Get(values...) + case *expression.GreaterThan: + index, ok := idx.(sql.DescendIndex) + if !ok { + return nil, nil + } + + return index.DescendGreater(values...) + case *expression.GreaterThanOrEqual: + index, ok := idx.(sql.AscendIndex) + if !ok { + return nil, nil + } + + return index.AscendGreaterOrEqual(values...) + case *expression.LessThan: + index, ok := idx.(sql.AscendIndex) + if !ok { + return nil, nil + } + + return index.AscendLessThan(values...) + case *expression.LessThanOrEqual: + index, ok := idx.(sql.DescendIndex) + if !ok { + return nil, nil + } + + return index.DescendLessOrEqual(values...) + } + + return nil, nil +} + +func getNegatedIndexes(a *Analyzer, not *expression.Not, aliases map[string]sql.Expression) (map[string]*indexLookup, error) { + switch e := not.Child.(type) { + case *expression.Not: + return getIndexes(e.Child, aliases, a) + case *expression.Equals: + left, right := e.Left(), e.Right() + // if the form is SOMETHING OP {INDEXABLE EXPR}, swap it, so it's {INDEXABLE EXPR} OP SOMETHING + if !isEvaluable(right) { + left, right = right, left + } + + if isEvaluable(left) || !isEvaluable(right) { + return nil, nil + } + + idx := a.Catalog.IndexByExpression(a.Catalog.CurrentDatabase(), unifyExpressions(aliases, left)...) 
+ if idx == nil { + return nil, nil + } + + index, ok := idx.(sql.NegateIndex) + if !ok { + return nil, nil + } + + value, err := right.Eval(sql.NewEmptyContext(), nil) + if err != nil { + a.Catalog.ReleaseIndex(idx) + return nil, err + } + + lookup, err := index.Not(value) + if err != nil || lookup == nil { + a.Catalog.ReleaseIndex(idx) + return nil, err + } + + result := map[string]*indexLookup{ + idx.Table(): &indexLookup{ + indexes: []sql.Index{idx}, + lookup: lookup, + }, + } + + return result, nil + case *expression.GreaterThan: + lte := expression.NewLessThanOrEqual(e.Left(), e.Right()) + return getIndexes(lte, aliases, a) + case *expression.GreaterThanOrEqual: + lt := expression.NewLessThan(e.Left(), e.Right()) + return getIndexes(lt, aliases, a) + case *expression.LessThan: + gte := expression.NewGreaterThanOrEqual(e.Left(), e.Right()) + return getIndexes(gte, aliases, a) + case *expression.LessThanOrEqual: + gt := expression.NewGreaterThan(e.Left(), e.Right()) + return getIndexes(gt, aliases, a) + case *expression.Between: + or := expression.NewOr( + expression.NewLessThan(e.Val, e.Lower), + expression.NewGreaterThan(e.Val, e.Upper), + ) + + return getIndexes(or, aliases, a) + case *expression.Or: + and := expression.NewAnd( + expression.NewNot(e.Left), + expression.NewNot(e.Right), + ) + + return getIndexes(and, aliases, a) + case *expression.And: + or := expression.NewOr( + expression.NewNot(e.Left), + expression.NewNot(e.Right), + ) + + return getIndexes(or, aliases, a) + default: + return nil, nil + + } +} + +func indexesIntersection( + a *Analyzer, + left, right map[string]*indexLookup, +) map[string]*indexLookup { + var result = make(map[string]*indexLookup) + + for table, idx := range left { + if idx2, ok := right[table]; ok && canMergeIndexes(idx.lookup, idx2.lookup) { + idx.lookup = idx.lookup.(sql.SetOperations).Intersection(idx2.lookup) + idx.indexes = append(idx.indexes, idx2.indexes...) 
+ } else if ok { + for _, idx := range idx2.indexes { + a.Catalog.ReleaseIndex(idx) + } + } + + result[table] = idx + } + + // Put in the result map the indexes for tables we don't have indexes yet. + // The others were already handled by the previous loop. + for table, lookup := range right { + if _, ok := result[table]; !ok { + result[table] = lookup + } + } + + return result +} + +func getMultiColumnIndexes( + exprs []sql.Expression, + a *Analyzer, + used map[sql.Expression]struct{}, + aliases map[string]sql.Expression, +) (map[string]*indexLookup, error) { + result := make(map[string]*indexLookup) + columnExprs := columnExprsByTable(exprs) + for table, exps := range columnExprs { + exprsByOp := groupExpressionsByOperator(exps) + for _, exps := range exprsByOp { + cols := make([]sql.Expression, len(exps)) + for i, e := range exps { + cols[i] = e.col + } + + exprList := a.Catalog.ExpressionsWithIndexes(a.Catalog.CurrentDatabase(), cols...) + + var selected []sql.Expression + for _, l := range exprList { + if len(l) > len(selected) { + selected = l + } + } + + if len(selected) > 0 { + index, lookup, err := getMultiColumnIndexForExpressions(a, selected, exps, used, aliases) + if err != nil || lookup == nil { + if index != nil { + a.Catalog.ReleaseIndex(index) + } + + if err != nil { + return nil, err + } + } + + if lookup != nil { + if _, ok := result[table]; ok { + result = indexesIntersection(a, result, map[string]*indexLookup{ + table: &indexLookup{lookup, []sql.Index{index}}, + }) + } else { + result[table] = &indexLookup{lookup, []sql.Index{index}} + } + } + } + } + } + + return result, nil +} + +func getMultiColumnIndexForExpressions( + a *Analyzer, + selected []sql.Expression, + exprs []columnExpr, + used map[sql.Expression]struct{}, + aliases map[string]sql.Expression, +) (index sql.Index, lookup sql.IndexLookup, err error) { + index = a.Catalog.IndexByExpression(a.Catalog.CurrentDatabase(), unifyExpressions(aliases, selected...)...) 
+ if index != nil { + var first sql.Expression + for _, e := range exprs { + if e.col == selected[0] { + first = e.expr + break + } + } + + if first == nil { + return + } + + switch e := first.(type) { + case *expression.Equals, + *expression.LessThan, + *expression.GreaterThan, + *expression.LessThanOrEqual, + *expression.GreaterThanOrEqual: + var values = make([]interface{}, len(index.Expressions())) + for i, e := range index.Expressions() { + col := findColumn(exprs, e) + used[col.expr] = struct{}{} + var val interface{} + val, err = col.val.Eval(sql.NewEmptyContext(), nil) + if err != nil { + return + } + values[i] = val + } + + lookup, err = comparisonIndexLookup(e.(expression.Comparer), index, values...) + case *expression.Between: + var lowers = make([]interface{}, len(index.Expressions())) + var uppers = make([]interface{}, len(index.Expressions())) + for i, e := range index.Expressions() { + col := findColumn(exprs, e) + used[col.expr] = struct{}{} + between := col.expr.(*expression.Between) + lowers[i], err = between.Lower.Eval(sql.NewEmptyContext(), nil) + if err != nil { + return + } + + uppers[i], err = between.Upper.Eval(sql.NewEmptyContext(), nil) + if err != nil { + return + } + } + + lookup, err = betweenIndexLookup(index, uppers, lowers) + } + } + + return +} + +func groupExpressionsByOperator(exprs []columnExpr) [][]columnExpr { + var result [][]columnExpr + + for _, e := range exprs { + var found bool + for i, group := range result { + t1 := reflect.TypeOf(group[0].expr) + t2 := reflect.TypeOf(e.expr) + if t1 == t2 { + result[i] = append(result[i], e) + found = true + break + } + } + + if !found { + result = append(result, []columnExpr{e}) + } + } + + return result +} + +type columnExpr struct { + col *expression.GetField + val sql.Expression + expr sql.Expression +} + +func findColumn(cols []columnExpr, column string) *columnExpr { + for _, col := range cols { + if col.col.String() == column { + return &col + } + } + return nil +} + +func 
columnExprsByTable(exprs []sql.Expression) map[string][]columnExpr { + var result = make(map[string][]columnExpr) + + for _, expr := range exprs { + table, colExpr := extractColumnExpr(expr) + if colExpr == nil { + continue + } + + result[table] = append(result[table], *colExpr) + } + + return result +} + +func extractColumnExpr(e sql.Expression) (string, *columnExpr) { + switch e := e.(type) { + case *expression.Not: + table, colExpr := extractColumnExpr(e.Child) + if colExpr != nil { + colExpr = &columnExpr{ + col: colExpr.col, + val: colExpr.val, + expr: expression.NewNot(colExpr.expr), + } + } + + return table, colExpr + case *expression.Equals, + *expression.GreaterThan, + *expression.LessThan, + *expression.GreaterThanOrEqual, + *expression.LessThanOrEqual: + cmp := e.(expression.Comparer) + left, right := cmp.Left(), cmp.Right() + if !isEvaluable(right) { + left, right = right, left + } + + if !isEvaluable(right) { + return "", nil + } + + col, ok := left.(*expression.GetField) + if !ok { + return "", nil + } + + return col.Table(), &columnExpr{col, right, e} + case *expression.Between: + if !isEvaluable(e.Upper) || !isEvaluable(e.Lower) || isEvaluable(e.Val) { + return "", nil + } + + col, ok := e.Val.(*expression.GetField) + if !ok { + return "", nil + } + + return col.Table(), &columnExpr{col, nil, e} + default: + return "", nil + } +} + +func containsColumns(e sql.Expression) bool { + var result bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := e.(*expression.GetField); ok { + result = true + } + return true + }) + return result +} + +func containsSubquery(e sql.Expression) bool { + var result bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := e.(*expression.Subquery); ok { + result = true + return false + } + return true + }) + return result +} + +func isEvaluable(e sql.Expression) bool { + return !containsColumns(e) && !containsSubquery(e) +} + +func canMergeIndexes(a, b sql.IndexLookup) bool { + m, ok := 
a.(sql.Mergeable) + if !ok { + return false + } + + if !m.IsMergeable(b) { + return false + } + + _, ok = a.(sql.SetOperations) + return ok +} diff --git a/sql/analyzer/assign_indexes_test.go b/sql/analyzer/assign_indexes_test.go new file mode 100644 index 000000000..d8e61f05a --- /dev/null +++ b/sql/analyzer/assign_indexes_test.go @@ -0,0 +1,1226 @@ +package analyzer + +import ( + "fmt" + "strings" + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestNegateIndex(t *testing.T) { + require := require.New(t) + + catalog := sql.NewCatalog() + idx1 := &dummyIndex{ + "t1", + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), + }, + } + done, ready, err := catalog.AddIndex(idx1) + require.NoError(err) + close(done) + <-ready + + a := NewDefault(catalog) + + t1 := memory.NewTable("t1", sql.Schema{ + {Name: "foo", Type: sql.Int64, Source: "t1"}, + }) + + node := plan.NewProject( + []sql.Expression{}, + plan.NewFilter( + expression.NewNot( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), + expression.NewLiteral(int64(1), sql.Int64), + ), + ), + plan.NewResolvedTable(t1), + ), + ) + + result, err := assignIndexes(a, node) + require.NoError(err) + + lookupIdxs, ok := result["t1"] + require.True(ok) + + negate, ok := lookupIdxs.lookup.(*negateIndexLookup) + require.True(ok) + require.True(negate.value == "1") +} + +func TestAssignIndexes(t *testing.T) { + require := require.New(t) + + catalog := sql.NewCatalog() + idx1 := &dummyIndex{ + "t2", + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t2", "bar", false), + }, + } + done, ready, err := catalog.AddIndex(idx1) + require.NoError(err) + close(done) + <-ready + + idx2 := &dummyIndex{ + "t1", + []sql.Expression{ + 
expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), + }, + } + done, ready, err = catalog.AddIndex(idx2) + + require.NoError(err) + close(done) + <-ready + + a := NewDefault(catalog) + + t1 := memory.NewTable("t1", sql.Schema{ + {Name: "foo", Type: sql.Int64, Source: "t1"}, + }) + + t2 := memory.NewTable("t2", sql.Schema{ + {Name: "bar", Type: sql.Int64, Source: "t2"}, + {Name: "baz", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewProject( + []sql.Expression{}, + plan.NewFilter( + expression.NewOr( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t2", "bar", false), + expression.NewLiteral(int64(1), sql.Int64), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), + expression.NewLiteral(int64(2), sql.Int64), + ), + ), + plan.NewInnerJoin( + plan.NewResolvedTable(t1), + plan.NewResolvedTable(t2), + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t2", "baz", false), + ), + ), + ), + ) + + result, err := assignIndexes(a, node) + require.NoError(err) + + lookupIdxs, ok := result["t1"] + require.True(ok) + + mergeable, ok := lookupIdxs.lookup.(*mergeableIndexLookup) + require.True(ok) + require.True(mergeable.id == "2") + + lookupIdxs, ok = result["t2"] + require.True(ok) + + mergeable, ok = lookupIdxs.lookup.(*mergeableIndexLookup) + require.True(ok) + require.True(mergeable.id == "1") +} + +func TestGetIndexes(t *testing.T) { + indexes := []*dummyIndex{ + { + "t1", + []sql.Expression{ + col(0, "t1", "bar"), + }, + }, + { + "t2", + []sql.Expression{ + col(0, "t2", "foo"), + col(0, "t2", "bar"), + }, + }, + { + "t2", + []sql.Expression{ + col(0, "t2", "bar"), + }, + }, + } + + testCases := []struct { + expr sql.Expression + expected map[string]*indexLookup + ok bool + }{ + { + eq( + col(0, "t1", "bar"), + col(1, "t1", "baz"), + ), + map[string]*indexLookup{}, + true, + }, + { + eq( + 
col(0, "t1", "bar"), + lit(1), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1"}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + or( + eq( + col(0, "t1", "bar"), + lit(1), + ), + eq( + col(0, "t1", "bar"), + lit(2), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1", unions: []string{"2"}}, + []sql.Index{ + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + and( + eq( + col(0, "t1", "bar"), + lit(1), + ), + eq( + col(0, "t1", "bar"), + lit(2), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1", intersections: []string{"2"}}, + []sql.Index{ + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + and( + or( + eq( + col(0, "t1", "bar"), + lit(1), + ), + eq( + col(0, "t1", "bar"), + lit(2), + ), + ), + or( + eq( + col(0, "t1", "bar"), + lit(3), + ), + eq( + col(0, "t1", "bar"), + lit(4), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1", unions: []string{"2", "4"}, intersections: []string{"3"}}, + []sql.Index{ + indexes[0], + indexes[0], + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + or( + or( + eq( + col(0, "t1", "bar"), + lit(1), + ), + eq( + col(0, "t1", "bar"), + lit(2), + ), + ), + or( + eq( + col(0, "t1", "bar"), + lit(3), + ), + eq( + col(0, "t1", "bar"), + lit(4), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1", unions: []string{"2", "3", "4"}}, + []sql.Index{ + indexes[0], + indexes[0], + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + expression.NewIn( + col(0, "t1", "bar"), + expression.NewTuple( + lit(1), + lit(2), + lit(3), + lit(4), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "1", unions: []string{"2", "3", "4"}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + and( + eq( + col(0, "t2", "foo"), + lit(1), + ), + and( + eq( + col(0, "t2", "baz"), + 
lit(4), + ), + and( + eq( + col(0, "t2", "bar"), + lit(2), + ), + eq( + col(0, "t1", "bar"), + lit(3), + ), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "3"}, + []sql.Index{indexes[0]}, + }, + "t2": &indexLookup{ + &mergeableIndexLookup{id: "1, 2"}, + []sql.Index{indexes[1]}, + }, + }, + true, + }, + { + or( + eq( + col(0, "t2", "bar"), + lit(5), + ), + and( + eq( + col(0, "t2", "foo"), + lit(1), + ), + and( + eq( + col(0, "t2", "baz"), + lit(4), + ), + and( + eq( + col(0, "t2", "bar"), + lit(2), + ), + eq( + col(0, "t1", "bar"), + lit(3), + ), + ), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "3"}, + []sql.Index{indexes[0]}, + }, + "t2": &indexLookup{ + &mergeableIndexLookup{id: "5", unions: []string{"1, 2"}}, + []sql.Index{ + indexes[2], + indexes[1], + }, + }, + }, + true, + }, + { + gt( + col(0, "t1", "bar"), + lit(1), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &descendIndexLookup{gt: []interface{}{int64(1)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + lt( + col(0, "t1", "bar"), + lit(1), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &ascendIndexLookup{lt: []interface{}{int64(1)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + gte( + col(0, "t1", "bar"), + lit(1), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &ascendIndexLookup{gte: []interface{}{int64(1)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + lte( + col(0, "t1", "bar"), + lit(1), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &descendIndexLookup{lte: []interface{}{int64(1)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + expression.NewBetween( + col(0, "t1", "bar"), + lit(1), + lit(5), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergedIndexLookup{ + []sql.IndexLookup{ + &ascendIndexLookup{ + gte: []interface{}{int64(1)}, + lt: []interface{}{int64(5)}, + }, + &descendIndexLookup{ + gt: []interface{}{int64(1)}, + 
lte: []interface{}{int64(5)}, + }, + }, + }, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + not( + eq( + col(0, "t1", "bar"), + lit(1), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &negateIndexLookup{ + value: "1", + }, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + + not( + gt( + col(0, "t1", "bar"), + lit(10), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &descendIndexLookup{lte: []interface{}{int64(10)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + + not( + gte( + col(0, "t1", "bar"), + lit(10), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &ascendIndexLookup{lt: []interface{}{int64(10)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + + not( + lte( + col(0, "t1", "bar"), + lit(10), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &descendIndexLookup{gt: []interface{}{int64(10)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + + not( + lt( + col(0, "t1", "bar"), + lit(10), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &ascendIndexLookup{gte: []interface{}{int64(10)}}, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + { + not( + and( + eq( + col(0, "t1", "bar"), + lit(10), + ), + eq( + col(0, "t1", "bar"), + lit(11), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergedIndexLookup{ + children: []sql.IndexLookup{ + &negateIndexLookup{ + value: "10", + }, + &negateIndexLookup{ + value: "11", + }, + }, + }, + []sql.Index{ + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + not( + or( + eq( + col(0, "t1", "bar"), + lit(10), + ), + eq( + col(0, "t1", "bar"), + lit(11), + ), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{ + id: "not 10", + intersections: []string{"not 11"}, + }, + []sql.Index{ + indexes[0], + indexes[0], + }, + }, + }, + true, + }, + { + // `NOT` doesn't work for multicolumn indexes, so the expression + // will use indexes if there are indexes for the single columns 
+ // involved. In this case there is a index for the column `t2.bar`. + not( + or( + eq( + col(0, "t2", "foo"), + lit(100), + ), + eq( + col(0, "t2", "bar"), + lit(110), + ), + ), + ), + map[string]*indexLookup{ + "t2": &indexLookup{ + &negateIndexLookup{ + value: "110", + }, + []sql.Index{ + indexes[2], + }, + }, + }, + true, + }, + { + expression.NewNotIn( + col(0, "t1", "bar"), + expression.NewTuple( + lit(1), + lit(2), + lit(3), + lit(4), + ), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{ + id: "not 1", + intersections: []string{"not 2", "not 3", "not 4"}, + }, + []sql.Index{indexes[0]}, + }, + }, + true, + }, + } + + catalog := sql.NewCatalog() + for _, idx := range indexes { + done, ready, err := catalog.AddIndex(idx) + require.NoError(t, err) + close(done) + <-ready + } + + a := NewDefault(catalog) + + for _, tt := range testCases { + t.Run(tt.expr.String(), func(t *testing.T) { + require := require.New(t) + + result, err := getIndexes(tt.expr, nil, a) + if tt.ok { + require.NoError(err) + require.Equal(tt.expected, result) + } else { + require.Error(err) + } + }) + } +} + +func TestGetMultiColumnIndexes(t *testing.T) { + require := require.New(t) + + catalog := sql.NewCatalog() + indexes := []*dummyIndex{ + { + "t1", + []sql.Expression{ + col(1, "t1", "foo"), + col(2, "t1", "bar"), + }, + }, + { + "t2", + []sql.Expression{ + col(0, "t2", "foo"), + col(1, "t2", "bar"), + col(2, "t2", "baz"), + }, + }, + { + "t2", + []sql.Expression{ + col(0, "t2", "foo"), + col(0, "t2", "bar"), + }, + }, + { + "t3", + []sql.Expression{col(0, "t3", "foo")}, + }, + { + "t4", + []sql.Expression{ + col(1, "t4", "foo"), + col(2, "t4", "bar"), + }, + }, + } + + for _, idx := range indexes { + done, ready, err := catalog.AddIndex(idx) + require.NoError(err) + close(done) + <-ready + } + + a := NewDefault(catalog) + + used := make(map[sql.Expression]struct{}) + exprs := []sql.Expression{ + eq( + col(2, "t2", "bar"), + lit(2), + ), + eq( + col(2, 
"t2", "foo"), + lit(1), + ), + eq( + lit(3), + col(2, "t2", "baz"), + ), + eq( + col(2, "t3", "foo"), + lit(4), + ), + eq( + col(2, "t1", "foo"), + lit(5), + ), + eq( + col(2, "t1", "bar"), + lit(6), + ), + expression.NewBetween( + col(2, "t4", "bar"), + lit(2), + lit(5), + ), + expression.NewBetween( + col(2, "t4", "foo"), + lit(1), + lit(6), + ), + } + result, err := getMultiColumnIndexes(exprs, a, used, nil) + require.NoError(err) + + expected := map[string]*indexLookup{ + "t1": &indexLookup{ + &mergeableIndexLookup{id: "5, 6"}, + []sql.Index{indexes[0]}, + }, + "t2": &indexLookup{ + &mergeableIndexLookup{id: "1, 2, 3"}, + []sql.Index{indexes[1]}, + }, + "t4": &indexLookup{ + &mergedIndexLookup{[]sql.IndexLookup{ + &ascendIndexLookup{ + gte: []interface{}{int64(1), int64(2)}, + lt: []interface{}{int64(6), int64(5)}, + }, + &descendIndexLookup{ + gt: []interface{}{int64(1), int64(2)}, + lte: []interface{}{int64(6), int64(5)}, + }, + }}, + []sql.Index{indexes[4]}, + }, + } + + require.Equal(expected, result) + + expectedUsed := map[sql.Expression]struct{}{ + exprs[0]: struct{}{}, + exprs[1]: struct{}{}, + exprs[2]: struct{}{}, + exprs[4]: struct{}{}, + exprs[5]: struct{}{}, + exprs[6]: struct{}{}, + exprs[7]: struct{}{}, + } + require.Equal(expectedUsed, used) +} + +func TestContainsSources(t *testing.T) { + testCases := []struct { + name string + haystack []string + needle []string + expected bool + }{ + { + "needle is in haystack", + []string{"a", "b", "c"}, + []string{"c", "b"}, + true, + }, + { + "needle is not in haystack", + []string{"a", "b", "c"}, + []string{"d", "b"}, + false, + }, + { + "no elements in needle", + []string{"a", "b", "c"}, + nil, + true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require.Equal( + t, + containsSources(tt.haystack, tt.needle), + tt.expected, + ) + }) + } +} + +func TestNodeSources(t *testing.T) { + sources := nodeSources( + plan.NewResolvedTable( + memory.NewTable("foo", sql.Schema{ + 
{Source: "foo"}, + {Source: "foo"}, + {Source: "bar"}, + {Source: "baz"}, + }), + ), + ) + + expected := []string{"foo", "bar", "baz"} + require.Equal(t, expected, sources) +} + +func TestExpressionSources(t *testing.T) { + sources := expressionSources(expression.JoinAnd( + col(0, "foo", "bar"), + col(0, "foo", "qux"), + and( + eq( + col(0, "bar", "baz"), + lit(1), + ), + eq( + col(0, "baz", "baz"), + lit(2), + ), + ), + )) + + expected := []string{"foo", "bar", "baz"} + require.Equal(t, expected, sources) +} + +type dummyIndexLookup struct{} + +func (dummyIndexLookup) Indexes() []string { return nil } + +func (dummyIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + return nil, nil +} + +type dummyIndex struct { + table string + expr []sql.Expression +} + +var _ sql.Index = (*dummyIndex)(nil) +var _ sql.AscendIndex = (*dummyIndex)(nil) +var _ sql.DescendIndex = (*dummyIndex)(nil) +var _ sql.NegateIndex = (*dummyIndex)(nil) + +func (dummyIndex) Database() string { return "" } +func (dummyIndex) Driver() string { return "" } +func (i dummyIndex) Expressions() []string { + var exprs []string + for _, e := range i.expr { + exprs = append(exprs, e.String()) + } + return exprs +} + +func (i dummyIndex) AscendGreaterOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + return &ascendIndexLookup{gte: keys}, nil +} + +func (i dummyIndex) AscendLessThan(keys ...interface{}) (sql.IndexLookup, error) { + return &ascendIndexLookup{lt: keys}, nil +} + +func (i dummyIndex) AscendRange(greaterOrEqual, lessThan []interface{}) (sql.IndexLookup, error) { + return &ascendIndexLookup{gte: greaterOrEqual, lt: lessThan}, nil +} + +func (i dummyIndex) DescendGreater(keys ...interface{}) (sql.IndexLookup, error) { + return &descendIndexLookup{gt: keys}, nil +} + +func (i dummyIndex) DescendLessOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + return &descendIndexLookup{lte: keys}, nil +} + +func (i dummyIndex) DescendRange(lessOrEqual, greaterThan []interface{}) 
(sql.IndexLookup, error) { + return &descendIndexLookup{gt: greaterThan, lte: lessOrEqual}, nil +} + +func (i dummyIndex) Not(keys ...interface{}) (sql.IndexLookup, error) { + lookup, err := i.Get(keys...) + if err != nil { + return nil, err + } + + mergeable, _ := lookup.(*mergeableIndexLookup) + return &negateIndexLookup{value: mergeable.id}, nil +} + +func (i dummyIndex) Get(key ...interface{}) (sql.IndexLookup, error) { + if len(key) != 1 { + var parts = make([]string, len(key)) + for i, p := range key { + parts[i] = fmt.Sprint(p) + } + + return &mergeableIndexLookup{id: strings.Join(parts, ", ")}, nil + } + + return &mergeableIndexLookup{id: fmt.Sprint(key[0])}, nil +} +func (i dummyIndex) Has(sql.Partition, ...interface{}) (bool, error) { + panic("not implemented") +} +func (i dummyIndex) ID() string { + if len(i.expr) == 1 { + return i.expr[0].String() + } + var parts = make([]string, len(i.expr)) + for i, e := range i.expr { + parts[i] = e.String() + } + + return "(" + strings.Join(parts, ", ") + ")" +} +func (i dummyIndex) Table() string { return i.table } + +type mergedIndexLookup struct { + children []sql.IndexLookup +} + +func (mergedIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + panic("mergedIndexLookup.Values is a placeholder") +} + +func (i *mergedIndexLookup) Indexes() []string { + var indexes []string + for _, c := range i.children { + indexes = append(indexes, c.Indexes()...) 
+ } + return indexes +} + +func (i *mergedIndexLookup) IsMergeable(sql.IndexLookup) bool { + return true +} + +func (i *mergedIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return &mergedIndexLookup{append(i.children, lookups...)} +} + +func (mergedIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("mergedIndexLookup.Difference is not implemented") +} + +func (mergedIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { + panic("mergedIndexLookup.Intersection is not implemented") +} + +type negateIndexLookup struct { + value string + intersections []string + unions []string +} + +func (l *negateIndexLookup) ID() string { return "not " + l.value } +func (l *negateIndexLookup) Unions() []string { return l.unions } +func (l *negateIndexLookup) Intersections() []string { return l.intersections } + +func (*negateIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + panic("negateIndexLookup.Values is a placeholder") +} + +func (l *negateIndexLookup) Indexes() []string { + return []string{l.ID()} +} + +func (*negateIndexLookup) IsMergeable(sql.IndexLookup) bool { + return true +} + +func (l *negateIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} +} + +func (*negateIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("negateIndexLookup.Difference is not implemented") +} + +func (l *negateIndexLookup) Intersection(indexes ...sql.IndexLookup) sql.IndexLookup { + var intersections, unions []string + for _, idx := range indexes { + intersections = append(intersections, idx.(mergeableLookup).ID()) + intersections = append(intersections, idx.(mergeableLookup).Intersections()...) + unions = append(unions, idx.(mergeableLookup).Unions()...) 
+ } + return &mergeableIndexLookup{ + l.ID(), + append(l.unions, unions...), + append(l.intersections, intersections...), + } +} + +type ascendIndexLookup struct { + id string + gte []interface{} + lt []interface{} +} + +func (ascendIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + panic("ascendIndexLookup.Values is a placeholder") +} + +func (l *ascendIndexLookup) Indexes() []string { + return []string{l.id} +} + +func (l *ascendIndexLookup) IsMergeable(sql.IndexLookup) bool { + return true +} + +func (l *ascendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} +} + +func (ascendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("ascendIndexLookup.Difference is not implemented") +} + +func (ascendIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { + panic("ascendIndexLookup.Intersection is not implemented") +} + +type descendIndexLookup struct { + id string + gt []interface{} + lte []interface{} +} + +func (descendIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + panic("descendIndexLookup.Values is a placeholder") +} + +func (l *descendIndexLookup) Indexes() []string { + return []string{l.id} +} + +func (l *descendIndexLookup) IsMergeable(sql.IndexLookup) bool { + return true +} + +func (l *descendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} +} + +func (descendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("descendIndexLookup.Difference is not implemented") +} + +func (descendIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { + panic("descendIndexLookup.Intersection is not implemented") +} + +func TestIndexesIntersection(t *testing.T) { + require := require.New(t) + + idx1, idx2 := &dummyIndex{table: "bar"}, &dummyIndex{table: "foo"} + + left := map[string]*indexLookup{ + "a": 
&indexLookup{&mergeableIndexLookup{id: "a"}, nil}, + "b": &indexLookup{&mergeableIndexLookup{id: "b"}, []sql.Index{idx1}}, + "c": &indexLookup{new(dummyIndexLookup), nil}, + } + + right := map[string]*indexLookup{ + "b": &indexLookup{&mergeableIndexLookup{id: "b2"}, []sql.Index{idx2}}, + "c": &indexLookup{&mergeableIndexLookup{id: "c"}, nil}, + "d": &indexLookup{&mergeableIndexLookup{id: "d"}, nil}, + } + + require.Equal( + map[string]*indexLookup{ + "a": &indexLookup{&mergeableIndexLookup{id: "a"}, nil}, + "b": &indexLookup{ + &mergeableIndexLookup{ + id: "b", + intersections: []string{"b2"}, + }, + []sql.Index{idx1, idx2}, + }, + "c": &indexLookup{new(dummyIndexLookup), nil}, + "d": &indexLookup{&mergeableIndexLookup{id: "d"}, nil}, + }, + indexesIntersection(NewDefault(sql.NewCatalog()), left, right), + ) +} + +func TestCanMergeIndexes(t *testing.T) { + require := require.New(t) + + require.False(canMergeIndexes(new(mergeableIndexLookup), new(dummyIndexLookup))) + require.True(canMergeIndexes(new(mergeableIndexLookup), new(mergeableIndexLookup))) +} + +type mergeableLookup interface { + ID() string + Unions() []string + Intersections() []string +} + +type mergeableIndexLookup struct { + id string + unions []string + intersections []string +} + +var _ sql.Mergeable = (*mergeableIndexLookup)(nil) +var _ sql.SetOperations = (*mergeableIndexLookup)(nil) + +func (i *mergeableIndexLookup) ID() string { return i.id } +func (i *mergeableIndexLookup) Unions() []string { return i.unions } +func (i *mergeableIndexLookup) Intersections() []string { return i.intersections } + +func (i *mergeableIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(mergeableLookup) + return ok +} + +func (i *mergeableIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { + panic("not implemented") +} + +func (i *mergeableIndexLookup) Indexes() []string { + return []string{i.ID()} +} + +func (i *mergeableIndexLookup) Difference(indexes ...sql.IndexLookup) 
sql.IndexLookup { + panic("not implemented") +} + +func (i *mergeableIndexLookup) Intersection(indexes ...sql.IndexLookup) sql.IndexLookup { + var intersections, unions []string + for _, idx := range indexes { + intersections = append(intersections, idx.(mergeableLookup).ID()) + intersections = append(intersections, idx.(mergeableLookup).Intersections()...) + unions = append(unions, idx.(mergeableLookup).Unions()...) + } + return &mergeableIndexLookup{ + i.id, + append(i.unions, unions...), + append(i.intersections, intersections...), + } +} + +func (i *mergeableIndexLookup) Union(indexes ...sql.IndexLookup) sql.IndexLookup { + var intersections, unions []string + for _, idx := range indexes { + unions = append(unions, idx.(*mergeableIndexLookup).id) + unions = append(unions, idx.(*mergeableIndexLookup).unions...) + intersections = append(intersections, idx.(*mergeableIndexLookup).intersections...) + } + return &mergeableIndexLookup{ + i.id, + append(i.unions, unions...), + append(i.intersections, intersections...), + } +} diff --git a/sql/analyzer/batch.go b/sql/analyzer/batch.go new file mode 100644 index 000000000..8153e32c7 --- /dev/null +++ b/sql/analyzer/batch.go @@ -0,0 +1,86 @@ +package analyzer + +import ( + "reflect" + + "github.com/src-d/go-mysql-server/sql" +) + +// RuleFunc is the function to be applied in a rule. +type RuleFunc func(*sql.Context, *Analyzer, sql.Node) (sql.Node, error) + +// Rule to transform nodes. +type Rule struct { + // Name of the rule. + Name string + // Apply transforms a node. + Apply RuleFunc +} + +// Batch executes a set of rules a specific number of times. +// When this number of times is reached, the actual node +// and ErrMaxAnalysisIters is returned. +type Batch struct { + Desc string + Iterations int + Rules []Rule +} + +// Eval executes the actual rules the specified number of times on the Batch. 
+// If max number of iterations is reached, this method will return the actual +// processed Node and ErrMaxAnalysisIters error. +func (b *Batch) Eval(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if b.Iterations == 0 { + return n, nil + } + + prev := n + cur, err := b.evalOnce(ctx, a, n) + if err != nil { + return nil, err + } + + if b.Iterations == 1 { + return cur, nil + } + + for i := 1; !nodesEqual(prev, cur); { + prev = cur + cur, err = b.evalOnce(ctx, a, cur) + if err != nil { + return nil, err + } + + i++ + if i >= b.Iterations { + return cur, ErrMaxAnalysisIters.New(b.Iterations) + } + } + + return cur, nil +} + +func (b *Batch) evalOnce(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + result := n + for _, rule := range b.Rules { + var err error + result, err = rule.Apply(ctx, a, result) + if err != nil { + return nil, err + } + } + + return result, nil +} + +func nodesEqual(a, b sql.Node) bool { + if e, ok := a.(equaler); ok { + return e.Equal(b) + } + + if e, ok := b.(equaler); ok { + return e.Equal(a) + } + + return reflect.DeepEqual(a, b) +} diff --git a/sql/analyzer/common_test.go b/sql/analyzer/common_test.go new file mode 100644 index 000000000..a5beb26c9 --- /dev/null +++ b/sql/analyzer/common_test.go @@ -0,0 +1,73 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +func not(e sql.Expression) sql.Expression { + return expression.NewNot(e) +} + +func gt(left, right sql.Expression) sql.Expression { + return expression.NewGreaterThan(left, right) +} + +func gte(left, right sql.Expression) sql.Expression { + return expression.NewGreaterThanOrEqual(left, right) +} + +func lt(left, right sql.Expression) sql.Expression { + return expression.NewLessThan(left, right) +} + +func lte(left, right sql.Expression) sql.Expression { + return expression.NewLessThanOrEqual(left, right) +} + +func or(left, right sql.Expression) sql.Expression { + return 
expression.NewOr(left, right) +} + +func and(left, right sql.Expression) sql.Expression { + return expression.NewAnd(left, right) +} + +func col(idx int, table, col string) sql.Expression { + return expression.NewGetFieldWithTable(idx, sql.Int64, table, col, false) +} + +func eq(left, right sql.Expression) sql.Expression { + return expression.NewEquals(left, right) +} + +func lit(n int64) sql.Expression { + return expression.NewLiteral(n, sql.Int64) +} + +var analyzeRules = [][]Rule{ + OnceBeforeDefault, + DefaultRules, + OnceAfterDefault, +} + +func getRule(name string) Rule { + for _, rules := range analyzeRules { + rule := getRuleFrom(rules, name) + if rule != nil { + return *rule + } + } + + panic("missing rule") +} + +func getRuleFrom(rules []Rule, name string) *Rule { + for _, rule := range rules { + if rule.Name == name { + return &rule + } + } + + return nil +} diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go new file mode 100644 index 000000000..c7af55431 --- /dev/null +++ b/sql/analyzer/convert_dates.go @@ -0,0 +1,186 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" +) + +// convertDates wraps all expressions of date and datetime type with converts +// to ensure the date range is validated. +func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if !n.Resolved() { + return n, nil + } + + // Replacements contains a mapping from columns to the alias they will be + // replaced by. + var replacements = make(map[tableCol]string) + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + exp, ok := n.(sql.Expressioner) + if !ok { + return n, nil + } + + // nodeReplacements are all the replacements found in the current node. 
+ // These replacements are not applied to the current node, only to + // parent nodes. + var nodeReplacements = make(map[tableCol]string) + + var expressions = make(map[string]bool) + switch exp := exp.(type) { + case *plan.Project: + for _, e := range exp.Projections { + expressions[e.String()] = true + } + case *plan.GroupBy: + for _, e := range exp.Aggregate { + expressions[e.String()] = true + } + } + + var result sql.Node + var err error + switch exp := exp.(type) { + case *plan.GroupBy: + var aggregate = make([]sql.Expression, len(exp.Aggregate)) + for i, a := range exp.Aggregate { + agg, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) + }) + if err != nil { + return nil, err + } + + aggregate[i] = agg + + if _, ok := agg.(*expression.Alias); !ok && agg.String() != a.String() { + nodeReplacements[tableCol{"", a.String()}] = agg.String() + } + } + + var grouping = make([]sql.Expression, len(exp.Grouping)) + for i, g := range exp.Grouping { + gr, err := expression.TransformUp(g, func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false) + }) + if err != nil { + return nil, err + } + grouping[i] = gr + } + + result = plan.NewGroupBy(aggregate, grouping, exp.Child) + case *plan.Project: + var projections = make([]sql.Expression, len(exp.Projections)) + for i, e := range exp.Projections { + expr, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true) + }) + if err != nil { + return nil, err + } + + projections[i] = expr + + if _, ok := expr.(*expression.Alias); !ok && expr.String() != e.String() { + nodeReplacements[tableCol{"", e.String()}] = expr.String() + } + } + + result = plan.NewProject(projections, exp.Child) + default: + result, err = 
plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { + return addDateConvert(e, n, replacements, nodeReplacements, expressions, false) + }) + } + + if err != nil { + return nil, err + } + + // We're done with this node, so copy all the replacements found in + // this node to the global replacements in order to make the necesssary + // changes in parent nodes. + for tc, n := range nodeReplacements { + replacements[tc] = n + } + + return result, err + }) +} + +func addDateConvert( + e sql.Expression, + node sql.Node, + replacements, nodeReplacements map[tableCol]string, + expressions map[string]bool, + aliasRootProjections bool, +) (sql.Expression, error) { + var result sql.Expression + + // No need to wrap expressions that already validate times, such as + // convert, date_add, etc and those expressions whose Type method + // cannot be called because they are placeholders. + switch e := e.(type) { + case *aggregation.Max: + child, err := addDateConvert(e.Child, node, replacements, nodeReplacements, expressions, false) + if err != nil { + return nil, err + } + + return aggregation.NewMax(child), nil + case *aggregation.Min: + child, err := addDateConvert(e.Child, node, replacements, nodeReplacements, expressions, false) + if err != nil { + return nil, err + } + + return aggregation.NewMin(child), nil + case *expression.Convert, + *expression.Arithmetic, + *function.DateAdd, + *function.DateSub, + *expression.Star, + *expression.DefaultColumn, + *expression.Alias: + return e, nil + default: + // If it's a replacement, just replace it with the correct GetField + // because we know that it's already converted to a correct date + // and there is no point to do so again. 
+ if gf, ok := e.(*expression.GetField); ok { + if name, ok := replacements[tableCol{gf.Table(), gf.Name()}]; ok { + return expression.NewGetField(gf.Index(), gf.Type(), name, gf.IsNullable()), nil + } + } + + switch e.Type() { + case sql.Date: + result = expression.NewConvert(e, expression.ConvertToDate) + case sql.Timestamp: + result = expression.NewConvert(e, expression.ConvertToDatetime) + default: + result = e + } + } + + // Only do this if it's a root expression in a project or group by. + switch node.(type) { + case *plan.Project, *plan.GroupBy: + // If it was originally a GetField, and it's not anymore it's + // because we wrapped it in a convert. We need to make it an alias + // and propagate the changes up the chain. + if gf, ok := e.(*expression.GetField); ok && expressions[e.String()] && aliasRootProjections { + if _, ok := result.(*expression.GetField); !ok { + result = expression.NewAlias(result, gf.Name()) + nodeReplacements[tableCol{gf.Table(), gf.Name()}] = gf.Name() + } + } + } + + return result, nil +} diff --git a/sql/analyzer/convert_dates_test.go b/sql/analyzer/convert_dates_test.go new file mode 100644 index 000000000..72ba53868 --- /dev/null +++ b/sql/analyzer/convert_dates_test.go @@ -0,0 +1,293 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestConvertDates(t *testing.T) { + testCases := []struct { + name string + in sql.Expression + out sql.Expression + }{ + { + "arithmetic with dates", + expression.NewPlus(expression.NewLiteral("", sql.Timestamp), expression.NewLiteral("", sql.Timestamp)), + expression.NewPlus( + expression.NewConvert( + expression.NewLiteral("", 
sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + ), + }, + { + "star", + expression.NewStar(), + expression.NewStar(), + }, + { + "default column", + expression.NewDefaultColumn("foo"), + expression.NewDefaultColumn("foo"), + }, + { + "convert to date", + expression.NewConvert( + expression.NewPlus( + expression.NewLiteral("", sql.Timestamp), + expression.NewLiteral("", sql.Timestamp), + ), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewPlus( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + ), + expression.ConvertToDatetime, + ), + }, + { + "min aggregation", + aggregation.NewMin( + expression.NewGetField(0, sql.Timestamp, "foo", false), + ), + aggregation.NewMin( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + ), + }, + { + "max aggregation", + aggregation.NewMax( + expression.NewGetField(0, sql.Timestamp, "foo", false), + ), + aggregation.NewMax( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + ), + }, + { + "convert to other type", + expression.NewConvert( + expression.NewLiteral("", sql.Text), + expression.ConvertToBinary, + ), + expression.NewConvert( + expression.NewLiteral("", sql.Text), + expression.ConvertToBinary, + ), + }, + { + "datetime col in alias", + expression.NewAlias( + expression.NewLiteral("", sql.Timestamp), + "foo", + ), + expression.NewAlias( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + "foo", + ), + }, + { + "date col in alias", + expression.NewAlias( + expression.NewLiteral("", sql.Date), + "foo", + ), + expression.NewAlias( + 
expression.NewConvert( + expression.NewLiteral("", sql.Date), + expression.ConvertToDate, + ), + "foo", + ), + }, + { + "date add", + newDateAdd( + expression.NewLiteral("", sql.Timestamp), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + newDateAdd( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + }, + { + "date sub", + newDateSub( + expression.NewLiteral("", sql.Timestamp), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + newDateSub( + expression.NewConvert( + expression.NewLiteral("", sql.Timestamp), + expression.ConvertToDatetime, + ), + expression.NewInterval(expression.NewLiteral(int64(1), sql.Int64), "DAY"), + ), + }, + } + + table := plan.NewResolvedTable(memory.NewTable("t", nil)) + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + input := plan.NewProject([]sql.Expression{tt.in}, table) + expected := plan.NewProject([]sql.Expression{tt.out}, table) + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) + }) + } +} + +func TestConvertDatesProject(t *testing.T) { + table := plan.NewResolvedTable(memory.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, table), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + expression.NewAlias( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + "foo", + ), + 
}, table), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestConvertDatesGroupBy(t *testing.T) { + table := plan.NewResolvedTable(memory.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, + []sql.Expression{ + expression.NewGetField(0, sql.Timestamp, "foo", false), + }, table, + ), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + "foo", + ), + }, + []sql.Expression{ + expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + }, + table, + ), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestConvertDatesFieldReference(t *testing.T) { + table := plan.NewResolvedTable(memory.NewTable("t", nil)) + input := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "DAYOFWEEK(foo)", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + function.NewDayOfWeek( + expression.NewGetField(0, sql.Timestamp, "foo", false), + ), + }, table), + ) + expected := plan.NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "DAYOFWEEK(convert(foo, datetime))", false), + expression.NewLiteral("2019-06-06 00:00:00", sql.Text), + ), + plan.NewProject([]sql.Expression{ + function.NewDayOfWeek( + 
expression.NewConvert( + expression.NewGetField(0, sql.Timestamp, "foo", false), + expression.ConvertToDatetime, + ), + ), + }, table), + ) + + result, err := convertDates(sql.NewEmptyContext(), nil, input) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func newDateAdd(l, r sql.Expression) sql.Expression { + e, _ := function.NewDateAdd(l, r) + return e +} + +func newDateSub(l, r sql.Expression) sql.Expression { + e, _ := function.NewDateSub(l, r) + return e +} diff --git a/sql/analyzer/filters.go b/sql/analyzer/filters.go index 946cd2ba9..bbe54adc7 100644 --- a/sql/analyzer/filters.go +++ b/sql/analyzer/filters.go @@ -3,8 +3,8 @@ package analyzer import ( "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) type filters map[string][]sql.Expression @@ -18,18 +18,22 @@ func (f filters) merge(f2 filters) { func exprToTableFilters(expr sql.Expression) filters { filtersByTable := make(filters) for _, expr := range splitExpression(expr) { - var tables []string - _, _ = expr.TransformUp(func(e sql.Expression) (sql.Expression, error) { + var seenTables = make(map[string]struct{}) + var lastTable string + expression.Inspect(expr, func(e sql.Expression) bool { f, ok := e.(*expression.GetField) if ok { - tables = append(tables, f.Table()) + if _, ok := seenTables[f.Table()]; !ok { + seenTables[f.Table()] = struct{}{} + lastTable = f.Table() + } } - return e, nil + return true }) - if len(tables) == 1 { - filtersByTable[tables[0]] = append(filtersByTable[tables[0]], expr) + if len(seenTables) == 1 { + filtersByTable[lastTable] = append(filtersByTable[lastTable], expr) } } diff --git a/sql/analyzer/filters_test.go b/sql/analyzer/filters_test.go index 54d5cb360..999b95a6a 100644 --- a/sql/analyzer/filters_test.go +++ b/sql/analyzer/filters_test.go @@ -3,9 +3,9 @@ package analyzer import ( "testing" + 
"github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestFiltersMerge(t *testing.T) { @@ -108,6 +108,21 @@ func TestExprToTableFilters(t *testing.T) { require := require.New(t) expr := expression.NewAnd( expression.NewAnd( + expression.NewAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), + expression.NewLiteral(3.14, sql.Float64), + ), + expression.NewGreaterThan( + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), + expression.NewLiteral(3., sql.Float64), + ), + ), + expression.NewIsNull( + expression.NewGetFieldWithTable(0, sql.Int64, "mytable2", "i2", false), + ), + ), + expression.NewOr( expression.NewEquals( expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), expression.NewLiteral(3.14, sql.Float64), @@ -117,9 +132,6 @@ func TestExprToTableFilters(t *testing.T) { expression.NewLiteral(3., sql.Float64), ), ), - expression.NewIsNull( - expression.NewGetFieldWithTable(0, sql.Int64, "mytable2", "i2", false), - ), ) expected := filters{ @@ -132,6 +144,16 @@ func TestExprToTableFilters(t *testing.T) { expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), expression.NewLiteral(3., sql.Float64), ), + expression.NewOr( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), + expression.NewLiteral(3.14, sql.Float64), + ), + expression.NewGreaterThan( + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false), + expression.NewLiteral(3., sql.Float64), + ), + ), }, "mytable2": []sql.Expression{ expression.NewIsNull( diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go new file mode 100644 index 000000000..88283cf9f --- /dev/null +++ b/sql/analyzer/optimization_rules.go @@ -0,0 +1,443 @@ +package analyzer + 
+import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" +) + +func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + span, _ := ctx.Span("erase_projection") + defer span.Finish() + + if !node.Resolved() { + return node, nil + } + + a.Log("erase projection, node of type: %T", node) + + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { + project, ok := node.(*plan.Project) + if ok && project.Schema().Equals(project.Child.Schema()) { + a.Log("project erased") + return project.Child, nil + } + + return node, nil + }) +} + +func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + span, _ := ctx.Span("optimize_distinct") + defer span.Finish() + + a.Log("optimize distinct, node of type: %T", node) + if n, ok := node.(*plan.Distinct); ok { + var sortField *expression.GetField + plan.Inspect(n, func(node sql.Node) bool { + a.Log("checking for optimization in node of type: %T", node) + if sort, ok := node.(*plan.Sort); ok && sortField == nil { + if col, ok := sort.SortFields[0].Column.(*expression.GetField); ok { + sortField = col + } + return false + } + return true + }) + + if sortField != nil && n.Schema().Contains(sortField.Name(), sortField.Table()) { + a.Log("distinct optimized for ordered output") + return plan.NewOrderedDistinct(n.Child), nil + } + } + + return node, nil +} + +var errInvalidNodeType = errors.NewKind("reorder projection: invalid node of type: %T") + +func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, ctx := ctx.Span("reorder_projection") + defer span.Finish() + + if n.Resolved() { + return n, nil + } + + a.Log("reorder projection, node of type: %T", n) + + // Then we transform the projection + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { + project, ok := node.(*plan.Project) + // 
When we transform the projection, the children will always be + // unresolved in the case we want to fix, as the reorder happens just + // so some columns can be resolved. + // For that, we need to account for NaturalJoin, whose schema can't be + // obtained until it's resolved and ignore the projection for the + // moment until the resolve_natural_joins has finished resolving the + // node and we can tackle it in the next iteration. + // Without this check, it would cause a panic, because NaturalJoin's + // schema method is just a placeholder that should not be called. + if !ok || hasNaturalJoin(project.Child) { + return node, nil + } + + // We must find all columns that may need to be moved inside the + // projection. + var newColumns = make(map[string]sql.Expression) + for _, col := range project.Projections { + alias, ok := col.(*expression.Alias) + if ok { + newColumns[alias.Name()] = col + } + } + + // And add projection nodes where needed in the child tree. + var didNeedReorder bool + child, err := plan.TransformUp(project.Child, func(node sql.Node) (sql.Node, error) { + var requiredColumns []string + switch node := node.(type) { + case *plan.Sort, *plan.Filter: + for _, expr := range node.(sql.Expressioner).Expressions() { + expression.Inspect(expr, func(e sql.Expression) bool { + if e != nil && e.Resolved() { + return true + } + + uc, ok := e.(column) + if ok && uc.Table() == "" { + if _, ok := newColumns[uc.Name()]; ok { + requiredColumns = append(requiredColumns, uc.Name()) + } + } + + return true + }) + } + default: + return node, nil + } + + if len(requiredColumns) == 0 { + return node, nil + } + + didNeedReorder = true + + // Only add the required columns for that node in the projection. 
+ child := node.Children()[0] + schema := child.Schema() + var projections = make([]sql.Expression, 0, len(schema)+len(requiredColumns)) + for i, col := range schema { + projections = append(projections, expression.NewGetFieldWithTable( + i, col.Type, col.Source, col.Name, col.Nullable, + )) + } + + for _, col := range requiredColumns { + if c, ok := newColumns[col]; ok { + projections = append(projections, c) + delete(newColumns, col) + } + } + + child = plan.NewProject(projections, child) + switch node := node.(type) { + case *plan.Filter: + return plan.NewFilter(node.Expression, child), nil + case *plan.Sort: + return plan.NewSort(node.SortFields, child), nil + default: + return nil, errInvalidNodeType.New(node) + } + }) + + if err != nil { + return nil, err + } + + if !didNeedReorder { + return project, nil + } + + child, err = resolveColumns(ctx, a, child) + if err != nil { + return nil, err + } + + childSchema := child.Schema() + // Finally, replace the columns we moved with GetFields since they + // have already been projected. 
+ var projections = make([]sql.Expression, len(project.Projections)) + for i, p := range project.Projections { + if alias, ok := p.(*expression.Alias); ok { + var found bool + for idx, col := range childSchema { + if col.Name == alias.Name() { + projections[i] = expression.NewGetField( + idx, col.Type, col.Name, col.Nullable, + ) + found = true + break + } + } + + if !found { + projections[i] = p + } + } else { + projections[i] = p + } + } + + return plan.NewProject(projections, child), nil + }) +} + +func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if !n.Resolved() { + a.Log("node is not resolved, skip moving join conditions to filter") + return n, nil + } + + a.Log("moving join conditions to filter, node of type: %T", n) + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + join, ok := n.(*plan.InnerJoin) + if !ok { + return n, nil + } + + leftSources := nodeSources(join.Left) + rightSources := nodeSources(join.Right) + var leftFilters, rightFilters, condFilters []sql.Expression + for _, e := range splitExpression(join.Cond) { + sources := expressionSources(e) + + canMoveLeft := containsSources(leftSources, sources) + if canMoveLeft { + leftFilters = append(leftFilters, e) + } + + canMoveRight := containsSources(rightSources, sources) + if canMoveRight { + rightFilters = append(rightFilters, e) + } + + if !canMoveLeft && !canMoveRight { + condFilters = append(condFilters, e) + } + } + + var left, right sql.Node = join.Left, join.Right + if len(leftFilters) > 0 { + leftFilters, err := fixFieldIndexes(left.Schema(), expression.JoinAnd(leftFilters...)) + if err != nil { + return nil, err + } + + left = plan.NewFilter(leftFilters, left) + } + + if len(rightFilters) > 0 { + rightFilters, err := fixFieldIndexes(right.Schema(), expression.JoinAnd(rightFilters...)) + if err != nil { + return nil, err + } + + right = plan.NewFilter(rightFilters, right) + } + + if len(condFilters) > 0 { + return 
plan.NewInnerJoin( + left, right, + expression.JoinAnd(condFilters...), + ), nil + } + + // if there are no cond filters left we can just convert it to a cross join + return plan.NewCrossJoin(left, right), nil + }) +} + +func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("remove_unnecessary_converts") + defer span.Finish() + + if !n.Resolved() { + return n, nil + } + + a.Log("removing unnecessary converts, node of type: %T", n) + + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { + if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() { + return c.Child, nil + } + + return e, nil + }) +} + +// containsSources checks that all `needle` sources are contained inside `haystack`. +func containsSources(haystack, needle []string) bool { + for _, s := range needle { + var found bool + for _, s2 := range haystack { + if s2 == s { + found = true + break + } + } + + if !found { + return false + } + } + + return true +} + +func nodeSources(node sql.Node) []string { + var sources = make(map[string]struct{}) + var result []string + + for _, col := range node.Schema() { + if _, ok := sources[col.Source]; !ok { + sources[col.Source] = struct{}{} + result = append(result, col.Source) + } + } + + return result +} + +func expressionSources(expr sql.Expression) []string { + var sources = make(map[string]struct{}) + var result []string + + expression.Inspect(expr, func(expr sql.Expression) bool { + f, ok := expr.(*expression.GetField) + if ok { + if _, ok := sources[f.Table()]; !ok { + sources[f.Table()] = struct{}{} + result = append(result, f.Table()) + } + } + + return true + }) + + return result +} + +func evalFilter(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + if !node.Resolved() { + return node, nil + } + + a.Log("evaluating filters, node of type: %T", node) + + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { + filter, 
ok := node.(*plan.Filter) + if !ok { + return node, nil + } + + e, err := expression.TransformUp(filter.Expression, func(e sql.Expression) (sql.Expression, error) { + switch e := e.(type) { + case *expression.Or: + if isTrue(e.Left) { + return e.Left, nil + } + + if isTrue(e.Right) { + return e.Right, nil + } + + if isFalse(e.Left) { + return e.Right, nil + } + + if isFalse(e.Right) { + return e.Left, nil + } + + return e, nil + case *expression.And: + if isFalse(e.Left) { + return e.Left, nil + } + + if isFalse(e.Right) { + return e.Right, nil + } + + if isTrue(e.Left) { + return e.Right, nil + } + + if isTrue(e.Right) { + return e.Left, nil + } + + return e, nil + case *expression.Literal, expression.Tuple: + return e, nil + default: + if !isEvaluable(e) { + return e, nil + } + + // All other expressions types can be evaluated once and turned into literals for the rest of query execution + val, err := e.Eval(ctx, nil) + if err != nil { + return e, nil + } + return expression.NewLiteral(val, e.Type()), nil + } + }) + if err != nil { + return nil, err + } + + if isFalse(e) { + return plan.EmptyTable, nil + } + + if isTrue(e) { + return filter.Child, nil + } + + return plan.NewFilter(e, filter.Child), nil + }) +} + +func isFalse(e sql.Expression) bool { + lit, ok := e.(*expression.Literal) + return ok && + lit.Type() == sql.Boolean && + !lit.Value().(bool) +} + +func isTrue(e sql.Expression) bool { + lit, ok := e.(*expression.Literal) + return ok && + lit.Type() == sql.Boolean && + lit.Value().(bool) +} + +// hasNaturalJoin checks whether there is a natural join at some point in the +// given node and its children. 
+func hasNaturalJoin(node sql.Node) bool { + var found bool + plan.Inspect(node, func(node sql.Node) bool { + if _, ok := node.(*plan.NaturalJoin); ok { + found = true + return false + } + return true + }) + return found +} diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go new file mode 100644 index 000000000..60dee3b0f --- /dev/null +++ b/sql/analyzer/optimization_rules_test.go @@ -0,0 +1,475 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestReorderProjection(t *testing.T) { + f := getRule("reorder_projection") + + table := memory.NewTable("mytable", sql.Schema{{ + Name: "i", Source: "mytable", Type: sql.Int64, + }}) + + testCases := []struct { + name string + project sql.Node + expected sql.Node + }{ + { + "sort", + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), + expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("foo")}, + }, + plan.NewFilter( + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewUnresolvedColumn("bar"), + ), + plan.NewResolvedTable(table), + ), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetField(2, sql.Int64, "foo", false), + expression.NewGetField(1, sql.Int64, "bar", false), + }, + plan.NewSort( + []plan.SortField{{Column: expression.NewGetField(2, sql.Int64, "foo", false)}}, + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetField(1, sql.Int64, 
"bar", false), + expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), + }, + plan.NewFilter( + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewGetField(1, sql.Int64, "bar", false), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), + }, + plan.NewResolvedTable(table), + ), + ), + ), + ), + ), + }, + { + "use alias twice", + plan.NewProject( + []sql.Expression{ + expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), + }, + plan.NewFilter( + expression.NewOr( + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewUnresolvedColumn("foo"), + ), + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewUnresolvedColumn("foo"), + ), + ), + plan.NewResolvedTable(table), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetField(1, sql.Int64, "foo", false), + }, + plan.NewFilter( + expression.NewOr( + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewGetField(1, sql.Int64, "foo", false), + ), + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewGetField(1, sql.Int64, "foo", false), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), + }, + plan.NewResolvedTable(table), + ), + ), + ), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := f.Apply(sql.NewEmptyContext(), NewDefault(nil), tt.project) + require.NoError(err) + + require.Equal(tt.expected, result) + }) + } +} + +func TestEraseProjection(t *testing.T) { + require := require.New(t) + f := getRule("erase_projection") + + table := memory.NewTable("mytable", sql.Schema{{ + Name: "i", Source: "mytable", Type: sql.Int64, + 
}}) + + expected := plan.NewSort( + []plan.SortField{{Column: expression.NewGetField(2, sql.Int64, "foo", false)}}, + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), + }, + plan.NewFilter( + expression.NewEquals( + expression.NewLiteral(1, sql.Int64), + expression.NewGetField(1, sql.Int64, "bar", false), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), + }, + plan.NewResolvedTable(table), + ), + ), + ), + ) + + node := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewGetField(2, sql.Int64, "foo", false), + }, + expected, + ) + + result, err := f.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + require.Equal(expected, result) + + result, err = f.Apply(sql.NewEmptyContext(), NewDefault(nil), expected) + require.NoError(err) + + require.Equal(expected, result) +} + +func TestOptimizeDistinct(t *testing.T) { + t1 := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + }) + + testCases := []struct { + name string + child sql.Node + optimized bool + }{ + { + "without sort", + plan.NewResolvedTable(t1), + false, + }, + { + "sort but column not projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "c")}, + }, + plan.NewResolvedTable(t1), + ), + false, + }, + { + "sort and column projected", + plan.NewSort( + []plan.SortField{ + {Column: gf(0, "foo", "a")}, + }, + plan.NewResolvedTable(t1), + ), + true, + }, + } + + rule := getRule("optimize_distinct") + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + node, err := 
rule.Apply(sql.NewEmptyContext(), nil, plan.NewDistinct(tt.child)) + require.NoError(t, err) + + _, ok := node.(*plan.OrderedDistinct) + require.Equal(t, tt.optimized, ok) + }) + } +} + +func TestMoveJoinConditionsToFilter(t *testing.T) { + t1 := memory.NewTable("t1", sql.Schema{ + {Name: "a", Source: "t1", Type: sql.Int64}, + {Name: "b", Source: "t1", Type: sql.Int64}, + }) + + t2 := memory.NewTable("t2", sql.Schema{ + {Name: "c", Source: "t2", Type: sql.Int64}, + {Name: "d", Source: "t2", Type: sql.Int64}, + }) + + t3 := memory.NewTable("t3", sql.Schema{ + {Name: "e", Source: "t3", Type: sql.Int64}, + {Name: "f", Source: "t3", Type: sql.Int64}, + }) + + rule := getRule("move_join_conds_to_filter") + require := require.New(t) + + node := plan.NewInnerJoin( + plan.NewResolvedTable(t1), + plan.NewCrossJoin( + plan.NewResolvedTable(t2), + plan.NewResolvedTable(t3), + ), + expression.JoinAnd( + eq(col(0, "t1", "a"), col(2, "t2", "c")), + eq(col(0, "t1", "a"), col(4, "t3", "e")), + eq(col(2, "t2", "c"), col(4, "t3", "e")), + eq(col(0, "t1", "a"), lit(5)), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + var expected sql.Node = plan.NewInnerJoin( + plan.NewFilter( + eq(col(0, "t1", "a"), lit(5)), + plan.NewResolvedTable(t1), + ), + plan.NewFilter( + eq(col(0, "t2", "c"), col(2, "t3", "e")), + plan.NewCrossJoin( + plan.NewResolvedTable(t2), + plan.NewResolvedTable(t3), + ), + ), + and( + eq(col(0, "t1", "a"), col(2, "t2", "c")), + eq(col(0, "t1", "a"), col(4, "t3", "e")), + ), + ) + + require.Equal(expected, result) + + node = plan.NewInnerJoin( + plan.NewResolvedTable(t1), + plan.NewCrossJoin( + plan.NewResolvedTable(t2), + plan.NewResolvedTable(t3), + ), + expression.JoinAnd( + eq(col(0, "t2", "c"), col(0, "t3", "e")), + eq(col(0, "t1", "a"), lit(5)), + ), + ) + + result, err = rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected = plan.NewCrossJoin( + plan.NewFilter( 
+ eq(col(0, "t1", "a"), lit(5)), + plan.NewResolvedTable(t1), + ), + plan.NewFilter( + eq(col(0, "t2", "c"), col(2, "t3", "e")), + plan.NewCrossJoin( + plan.NewResolvedTable(t2), + plan.NewResolvedTable(t3), + ), + ), + ) + + require.Equal(result, expected) +} + +func TestEvalFilter(t *testing.T) { + inner := memory.NewTable("foo", nil) + rule := getRule("eval_filter") + + testCases := []struct { + filter sql.Expression + expected sql.Node + }{ + { + and( + eq(lit(5), lit(5)), + eq(col(0, "foo", "bar"), lit(5)), + ), + plan.NewFilter( + eq(col(0, "foo", "bar"), lit(5)), + plan.NewResolvedTable(inner), + ), + }, + { + and( + eq(col(0, "foo", "bar"), lit(5)), + eq(lit(5), lit(5)), + ), + plan.NewFilter( + eq(col(0, "foo", "bar"), lit(5)), + plan.NewResolvedTable(inner), + ), + }, + { + and( + eq(lit(5), lit(4)), + eq(col(0, "foo", "bar"), lit(5)), + ), + plan.EmptyTable, + }, + { + and( + eq(col(0, "foo", "bar"), lit(5)), + eq(lit(5), lit(4)), + ), + plan.EmptyTable, + }, + { + and( + eq(lit(4), lit(4)), + eq(lit(5), lit(5)), + ), + plan.NewResolvedTable(inner), + }, + { + or( + eq(lit(5), lit(4)), + eq(col(0, "foo", "bar"), lit(5)), + ), + plan.NewFilter( + eq(col(0, "foo", "bar"), lit(5)), + plan.NewResolvedTable(inner), + ), + }, + { + or( + eq(col(0, "foo", "bar"), lit(5)), + eq(lit(5), lit(4)), + ), + plan.NewFilter( + eq(col(0, "foo", "bar"), lit(5)), + plan.NewResolvedTable(inner), + ), + }, + { + or( + eq(lit(5), lit(5)), + eq(col(0, "foo", "bar"), lit(5)), + ), + plan.NewResolvedTable(inner), + }, + { + or( + eq(col(0, "foo", "bar"), lit(5)), + eq(lit(5), lit(5)), + ), + plan.NewResolvedTable(inner), + }, + { + or( + eq(lit(5), lit(4)), + eq(lit(5), lit(4)), + ), + plan.EmptyTable, + }, + } + + for _, tt := range testCases { + t.Run(tt.filter.String(), func(t *testing.T) { + require := require.New(t) + node := plan.NewFilter(tt.filter, plan.NewResolvedTable(inner)) + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + 
require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} + +func TestRemoveUnnecessaryConverts(t *testing.T) { + testCases := []struct { + name string + childExpr sql.Expression + castType string + expected sql.Expression + }{ + { + "unnecessary cast", + expression.NewLiteral([]byte{}, sql.Blob), + "binary", + expression.NewLiteral([]byte{}, sql.Blob), + }, + { + "necessary cast", + expression.NewLiteral("foo", sql.Text), + "signed", + expression.NewConvert( + expression.NewLiteral("foo", sql.Text), + "signed", + ), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + node := plan.NewProject([]sql.Expression{ + expression.NewConvert(tt.childExpr, tt.castType), + }, + plan.NewResolvedTable(memory.NewTable("foo", nil)), + ) + + result, err := removeUnnecessaryConverts( + sql.NewEmptyContext(), + NewDefault(nil), + node, + ) + require.NoError(err) + + resultExpr := result.(*plan.Project).Projections[0] + require.Equal(tt.expected, resultExpr) + }) + } +} diff --git a/sql/analyzer/parallelize.go b/sql/analyzer/parallelize.go new file mode 100644 index 000000000..a56b9479e --- /dev/null +++ b/sql/analyzer/parallelize.go @@ -0,0 +1,109 @@ +package analyzer + +import ( + "strconv" + + "github.com/go-kit/kit/metrics/discard" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +var ( + // ParallelQueryCounter describes a metric that accumulates + // number of parallel queries monotonically. + ParallelQueryCounter = discard.NewCounter() +) + +func shouldParallelize(node sql.Node) bool { + // Do not try to parallelize index operations. 
+ switch node.(type) { + case *plan.CreateIndex, *plan.DropIndex, *plan.Describe: + return false + default: + return true + } +} + +func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + if a.Parallelism <= 1 || !node.Resolved() { + return node, nil + } + + proc, ok := node.(*plan.QueryProcess) + if (ok && !shouldParallelize(proc.Child)) || !shouldParallelize(node) { + return node, nil + } + + node, err := plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { + if !isParallelizable(node) { + return node, nil + } + ParallelQueryCounter.With("parallelism", strconv.Itoa(a.Parallelism)).Add(1) + + return plan.NewExchange(a.Parallelism, node), nil + }) + + if err != nil { + return nil, err + } + + return plan.TransformUp(node, removeRedundantExchanges) +} + +// removeRedundantExchanges removes all the exchanges except for the topmost +// of all. +func removeRedundantExchanges(node sql.Node) (sql.Node, error) { + exchange, ok := node.(*plan.Exchange) + if !ok { + return node, nil + } + + child, err := plan.TransformUp(exchange.Child, func(node sql.Node) (sql.Node, error) { + if exchange, ok := node.(*plan.Exchange); ok { + return exchange.Child, nil + } + return node, nil + }) + if err != nil { + return nil, err + } + + return exchange.WithChildren(child) +} + +func isParallelizable(node sql.Node) bool { + var ok = true + var tableSeen bool + var lastWasTable bool + + plan.Inspect(node, func(node sql.Node) bool { + if node == nil { + return true + } + + lastWasTable = false + if plan.IsBinary(node) { + ok = false + return false + } + + switch node.(type) { + // These are the only unary nodes that can be parallelized. Any other + // unary nodes will not. 
+ case *plan.Filter, + *plan.Project, + *plan.TableAlias, + *plan.Exchange: + case sql.Table: + lastWasTable = true + tableSeen = true + default: + ok = false + return false + } + + return true + }) + + return ok && tableSeen && lastWasTable +} diff --git a/sql/analyzer/parallelize_test.go b/sql/analyzer/parallelize_test.go new file mode 100644 index 000000000..5f554a442 --- /dev/null +++ b/sql/analyzer/parallelize_test.go @@ -0,0 +1,228 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestParallelize(t *testing.T) { + require := require.New(t) + table := memory.NewTable("t", nil) + rule := getRuleFrom(OnceAfterAll, "parallelize") + node := plan.NewProject( + nil, + plan.NewInnerJoin( + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + expression.NewLiteral(1, sql.Int64), + ), + ) + + expected := plan.NewProject( + nil, + plan.NewInnerJoin( + plan.NewExchange( + 2, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + ), + plan.NewExchange( + 2, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + ), + expression.NewLiteral(1, sql.Int64), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), &Analyzer{Parallelism: 2}, node) + require.NoError(err) + require.Equal(expected, result) +} + +func TestParallelizeCreateIndex(t *testing.T) { + require := require.New(t) + table := memory.NewTable("t", nil) + rule := getRuleFrom(OnceAfterAll, "parallelize") + node := plan.NewCreateIndex( + "", + plan.NewResolvedTable(table), + nil, + "", + nil, + ) + + result, err := rule.Apply(sql.NewEmptyContext(), 
&Analyzer{Parallelism: 1}, node) + require.NoError(err) + require.Equal(node, result) +} + +func TestIsParallelizable(t *testing.T) { + table := memory.NewTable("t", nil) + + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "just table", + plan.NewResolvedTable(table), + true, + }, + { + "filter", + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + true, + }, + { + "project", + plan.NewProject( + nil, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + ), + true, + }, + { + "join", + plan.NewInnerJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table), + expression.NewLiteral(1, sql.Int64), + ), + false, + }, + { + "group by", + plan.NewGroupBy( + nil, + nil, + plan.NewResolvedTable(nil), + ), + false, + }, + { + "limit", + plan.NewLimit( + 5, + plan.NewResolvedTable(nil), + ), + false, + }, + { + "offset", + plan.NewOffset( + 5, + plan.NewResolvedTable(nil), + ), + false, + }, + { + "sort", + plan.NewSort( + nil, + plan.NewResolvedTable(nil), + ), + false, + }, + { + "distinct", + plan.NewDistinct( + plan.NewResolvedTable(nil), + ), + false, + }, + { + "ordered distinct", + plan.NewOrderedDistinct( + plan.NewResolvedTable(nil), + ), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.ok, isParallelizable(tt.node)) + }) + } +} + +func TestRemoveRedundantExchanges(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("t", nil) + + node := plan.NewProject( + nil, + plan.NewInnerJoin( + plan.NewExchange( + 1, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewExchange( + 1, + plan.NewResolvedTable(table), + ), + ), + ), + plan.NewExchange( + 1, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewExchange( + 1, + plan.NewResolvedTable(table), + ), + ), + ), + expression.NewLiteral(1, sql.Int64), + ), + ) + + expected := 
plan.NewProject( + nil, + plan.NewInnerJoin( + plan.NewExchange( + 1, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + ), + plan.NewExchange( + 1, + plan.NewFilter( + expression.NewLiteral(1, sql.Int64), + plan.NewResolvedTable(table), + ), + ), + expression.NewLiteral(1, sql.Int64), + ), + ) + + result, err := plan.TransformUp(node, removeRedundantExchanges) + require.NoError(err) + require.Equal(expected, result) +} diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go new file mode 100644 index 000000000..aa0d9d0c7 --- /dev/null +++ b/sql/analyzer/process.go @@ -0,0 +1,103 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +// trackProcess will wrap the query in a process node and add progress items +// to the already existing process. +func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if !n.Resolved() { + return n, nil + } + + if _, ok := n.(*plan.QueryProcess); ok { + return n, nil + } + + processList := a.Catalog.ProcessList + + var seen = make(map[string]struct{}) + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.ResolvedTable: + switch n.Table.(type) { + case *plan.ProcessTable, *plan.ProcessIndexableTable: + return n, nil + } + + name := n.Table.Name() + if _, ok := seen[name]; ok { + return n, nil + } + + var total int64 = -1 + if counter, ok := n.Table.(sql.PartitionCounter); ok { + count, err := counter.PartitionCount(ctx) + if err != nil { + return nil, err + } + total = count + } + processList.AddTableProgress(ctx.Pid(), name, total) + + seen[name] = struct{}{} + + onPartitionDone := func(partitionName string) { + processList.UpdateTableProgress(ctx.Pid(), name, 1) + processList.RemovePartitionProgress(ctx.Pid(), name, partitionName) + } + + onPartitionStart := func(partitionName string) { + processList.AddPartitionProgress(ctx.Pid(), 
name, partitionName, -1) + } + + onRowNext := func(partitionName string) { + processList.UpdatePartitionProgress(ctx.Pid(), name, partitionName, 1) + } + + var t sql.Table + switch table := n.Table.(type) { + case sql.IndexableTable: + t = plan.NewProcessIndexableTable(table, onPartitionDone, onPartitionStart, onRowNext) + default: + t = plan.NewProcessTable(table, onPartitionDone, onPartitionStart, onRowNext) + } + + return plan.NewResolvedTable(t), nil + default: + return n, nil + } + }) + if err != nil { + return nil, err + } + + // Don't wrap CreateIndex in a QueryProcess, as it is a CreateIndexProcess. + // CreateIndex will take care of marking the process as done on its own. + if _, ok := n.(*plan.CreateIndex); ok { + return n, nil + } + + // Remove QueryProcess nodes from the subqueries. Otherwise, the process + // will be marked as done as soon as a subquery finishes. + node, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + if sq, ok := n.(*plan.SubqueryAlias); ok { + if qp, ok := sq.Child.(*plan.QueryProcess); ok { + return plan.NewSubqueryAlias(sq.Name(), qp.Child), nil + } + } + return n, nil + }) + if err != nil { + return nil, err + } + + return plan.NewQueryProcess(node, func() { + processList.Done(ctx.Pid()) + if span := ctx.RootSpan(); span != nil { + span.Finish() + } + }), nil +} diff --git a/sql/analyzer/process_test.go b/sql/analyzer/process_test.go new file mode 100644 index 000000000..91271af4c --- /dev/null +++ b/sql/analyzer/process_test.go @@ -0,0 +1,125 @@ +package analyzer + +import ( + "context" + "testing" + "time" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestTrackProcess(t *testing.T) { + require := require.New(t) + rule := getRuleFrom(OnceAfterAll, "track_process") + catalog := sql.NewCatalog() + a := NewDefault(catalog) + + 
node := plan.NewInnerJoin( + plan.NewResolvedTable(&table{memory.NewPartitionedTable("foo", nil, 2)}), + plan.NewResolvedTable(memory.NewPartitionedTable("bar", nil, 4)), + expression.NewLiteral(int64(1), sql.Int64), + ) + + ctx := sql.NewContext(context.Background(), sql.WithPid(1)) + ctx, err := catalog.AddProcess(ctx, sql.QueryProcess, "SELECT foo") + require.NoError(err) + + result, err := rule.Apply(ctx, a, node) + require.NoError(err) + + processes := catalog.Processes() + require.Len(processes, 1) + require.Equal("SELECT foo", processes[0].Query) + require.Equal(sql.QueryProcess, processes[0].Type) + require.Equal( + map[string]sql.TableProgress{ + "foo": sql.TableProgress{ + Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + "bar": sql.TableProgress{ + Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, + PartitionsProgress: map[string]sql.PartitionProgress{}}, + }, + processes[0].Progress) + + proc, ok := result.(*plan.QueryProcess) + require.True(ok) + + join, ok := proc.Child.(*plan.InnerJoin) + require.True(ok) + + lhs, ok := join.Left.(*plan.ResolvedTable) + require.True(ok) + _, ok = lhs.Table.(*plan.ProcessTable) + require.True(ok) + + rhs, ok := join.Right.(*plan.ResolvedTable) + require.True(ok) + _, ok = rhs.Table.(*plan.ProcessIndexableTable) + require.True(ok) + + iter, err := proc.RowIter(ctx) + require.NoError(err) + _, err = sql.RowIterToRows(iter) + require.NoError(err) + + require.Len(catalog.Processes(), 0) + + select { + case <-ctx.Done(): + case <-time.After(5 * time.Millisecond): + t.Errorf("expecting context to be cancelled") + } +} + +func TestTrackProcessSubquery(t *testing.T) { + require := require.New(t) + rule := getRuleFrom(OnceAfterAll, "track_process") + catalog := sql.NewCatalog() + a := NewDefault(catalog) + + node := plan.NewProject( + nil, + plan.NewSubqueryAlias("f", + plan.NewQueryProcess( + plan.NewResolvedTable(memory.NewTable("foo", nil)), + nil, + 
), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), a, node) + require.NoError(err) + + expectedChild := plan.NewProject( + nil, + plan.NewSubqueryAlias("f", + plan.NewResolvedTable(memory.NewTable("foo", nil)), + ), + ) + + proc, ok := result.(*plan.QueryProcess) + require.True(ok) + require.Equal(expectedChild, proc.Child) +} + +func withoutProcessTracking(a *Analyzer) *Analyzer { + afterAll := a.Batches[len(a.Batches)-1] + afterAll.Rules = afterAll.Rules[1:] + return a +} + +// wrapper around sql.Table to make it not indexable +type table struct { + sql.Table +} + +var _ sql.PartitionCounter = (*table)(nil) + +func (t *table) PartitionCount(ctx *sql.Context) (int64, error) { + return t.Table.(sql.PartitionCounter).PartitionCount(ctx) +} diff --git a/sql/analyzer/prune_columns.go b/sql/analyzer/prune_columns.go new file mode 100644 index 000000000..0b4121ebe --- /dev/null +++ b/sql/analyzer/prune_columns.go @@ -0,0 +1,288 @@ +package analyzer + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" +) + +type usedColumns map[string]map[string]struct{} + +func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + a.Log("pruning columns, node of type %T", n) + if !n.Resolved() { + return n, nil + } + + if describe, ok := n.(*plan.DescribeQuery); ok { + pruned, err := pruneColumns(ctx, a, describe.Child) + if err != nil { + return nil, err + } + + return plan.NewDescribeQuery(describe.Format, pruned), nil + } + + columns := make(usedColumns) + + // All the columns required for the output of the query must be mark as + // used, otherwise the schema would change. 
+ for _, col := range n.Schema() { + if _, ok := columns[col.Source]; !ok { + columns[col.Source] = make(map[string]struct{}) + } + columns[col.Source][col.Name] = struct{}{} + } + + findUsedColumns(columns, n) + + n, err := pruneUnusedColumns(n, columns) + if err != nil { + return nil, err + } + + n, err = pruneSubqueries(ctx, a, n, columns) + if err != nil { + return nil, err + } + + return fixRemainingFieldsIndexes(n) +} + +func pruneSubqueryColumns( + ctx *sql.Context, + a *Analyzer, + n *plan.SubqueryAlias, + parentColumns usedColumns, +) (sql.Node, error) { + a.Log("pruning columns of subquery with alias %q", n.Name()) + + columns := make(usedColumns) + + // The columns coming from the parent have the subquery alias name as the + // source. We need to find the real table in order to prune the subquery + // correctly. + tableByCol := make(map[string]string) + for _, col := range n.Child.Schema() { + tableByCol[col.Name] = col.Source + } + + for col := range parentColumns[n.Name()] { + table, ok := tableByCol[col] + if !ok { + // This should never happen, but better be safe than sorry. + return nil, fmt.Errorf("this is likely a bug: missing projected column %q on subquery %q", col, n.Name()) + } + + if _, ok := columns[table]; !ok { + columns[table] = make(map[string]struct{}) + } + + columns[table][col] = struct{}{} + } + + findUsedColumns(columns, n.Child) + + node, err := pruneUnusedColumns(n.Child, columns) + if err != nil { + return nil, err + } + + node, err = pruneSubqueries(ctx, a, node, columns) + if err != nil { + return nil, err + } + + // There is no need to fix the field indexes after pruning here + // because the main query will take care of fixing the indexes of all the + // nodes in the tree. 
+ + return plan.NewSubqueryAlias(n.Name(), node), nil +} + +func findUsedColumns(columns usedColumns, n sql.Node) { + plan.Inspect(n, func(n sql.Node) bool { + switch n := n.(type) { + case *plan.Project: + addUsedProjectColumns(columns, n.Projections) + return true + case *plan.GroupBy: + addUsedProjectColumns(columns, n.Aggregate) + addUsedColumns(columns, n.Grouping) + return true + case *plan.SubqueryAlias: + return false + } + + exp, ok := n.(sql.Expressioner) + if ok { + addUsedColumns(columns, exp.Expressions()) + } + + return true + }) +} + +func pruneSubqueries( + ctx *sql.Context, + a *Analyzer, + n sql.Node, + parentColumns usedColumns, +) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + subq, ok := n.(*plan.SubqueryAlias) + if !ok { + return n, nil + } + + return pruneSubqueryColumns(ctx, a, subq, parentColumns) + }) +} + +func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.Project: + return pruneProject(n, columns), nil + case *plan.GroupBy: + return pruneGroupBy(n, columns), nil + default: + return n, nil + } + }) +} + +func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.SubqueryAlias: + child, err := fixRemainingFieldsIndexes(n.Child) + if err != nil { + return nil, err + } + + return plan.NewSubqueryAlias(n.Name(), child), nil + default: + if _, ok := n.(sql.Expressioner); !ok { + return n, nil + } + + var schema sql.Schema + for _, c := range n.Children() { + schema = append(schema, c.Schema()...) 
+ } + + if len(schema) == 0 { + return n, nil + } + + indexes := make(map[tableCol]int) + for i, col := range schema { + indexes[tableCol{col.Source, col.Name}] = i + } + + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { + gf, ok := e.(*expression.GetField) + if !ok { + return e, nil + } + + idx, ok := indexes[tableCol{gf.Table(), gf.Name()}] + if !ok { + return nil, ErrColumnTableNotFound.New(gf.Table(), gf.Name()) + } + + if idx == gf.Index() { + return gf, nil + } + + ngf := *gf + return ngf.WithIndex(idx), nil + }) + } + }) +} + +func addUsedProjectColumns( + columns usedColumns, + projection []sql.Expression, +) { + var candidates []sql.Expression + for _, e := range projection { + // Only check for expressions that are not directly a GetField. This + // is because in a projection we only care about those that were used + // to compute new columns, such as aliases and so on. The fields that + // are just passed up in the tree will already be in some other part + // if they are really used. 
+ if _, ok := e.(*expression.GetField); !ok { + candidates = append(candidates, e) + } + } + + addUsedColumns(columns, candidates) +} + +func addUsedColumns(columns usedColumns, exprs []sql.Expression) { + for _, e := range exprs { + expression.Inspect(e, func(e sql.Expression) bool { + if gf, ok := e.(*expression.GetField); ok { + if _, ok := columns[gf.Table()]; !ok { + columns[gf.Table()] = make(map[string]struct{}) + } + columns[gf.Table()][gf.Name()] = struct{}{} + } + return true + }) + } +} + +func pruneProject(n *plan.Project, columns usedColumns) sql.Node { + var remaining []sql.Expression + for _, e := range n.Projections { + if !shouldPruneExpr(e, columns) { + remaining = append(remaining, e) + } + } + + if len(remaining) == 0 { + return n.Child + } + + return plan.NewProject(remaining, n.Child) +} + +func pruneGroupBy(n *plan.GroupBy, columns usedColumns) sql.Node { + var remaining []sql.Expression + for _, e := range n.Aggregate { + if !shouldPruneExpr(e, columns) { + remaining = append(remaining, e) + } + } + + if len(remaining) == 0 { + return n.Child + } + + return plan.NewGroupBy(remaining, n.Grouping, n.Child) +} + +func shouldPruneExpr(e sql.Expression, cols usedColumns) bool { + gf, ok := e.(*expression.GetField) + if !ok { + return false + } + + if gf.Table() == "" { + return false + } + + if c, ok := cols[gf.Table()]; ok { + if _, ok := c[gf.Name()]; ok { + return false + } + } + + return true +} diff --git a/sql/analyzer/prune_columns_test.go b/sql/analyzer/prune_columns_test.go new file mode 100644 index 000000000..e2a1c2abd --- /dev/null +++ b/sql/analyzer/prune_columns_test.go @@ -0,0 +1,302 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestPruneColumns(t *testing.T) { + rule := getRuleFrom(OnceAfterDefault, 
"prune_columns") + a := NewDefault(nil) + + t1 := plan.NewResolvedTable(memory.NewTable("t1", sql.Schema{ + {Name: "foo", Type: sql.Int64, Source: "t1"}, + {Name: "bar", Type: sql.Int64, Source: "t1"}, + {Name: "bax", Type: sql.Int64, Source: "t1"}, + })) + + t2 := plan.NewResolvedTable(memory.NewTable("t2", sql.Schema{ + {Name: "foo", Type: sql.Int64, Source: "t2"}, + {Name: "baz", Type: sql.Int64, Source: "t2"}, + {Name: "bux", Type: sql.Int64, Source: "t2"}, + })) + + testCases := []struct { + name string + input sql.Node + expected sql.Node + }{ + { + "natural join", + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + expression.NewAlias(gf(1, "t1", "bar"), "some_alias"), + }, + plan.NewFilter( + eq(gf(0, "t1", "foo"), gf(4, "t2", "baz")), + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(2, "t1", "bax"), + gf(4, "t2", "baz"), + gf(5, "t2", "bux"), + }, + plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + expression.NewAlias(gf(1, "t1", "bar"), "some_alias"), + }, + plan.NewFilter( + eq(gf(0, "t1", "foo"), gf(2, "t2", "baz")), + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(4, "t2", "baz"), + }, + plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + }, + + { + "subquery", + plan.NewProject( + []sql.Expression{ + gf(0, "t", "foo"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t", "foo"), + expression.NewAlias(gf(1, "t", "bar"), "some_alias"), + }, + plan.NewFilter( + eq(gf(0, "t", "foo"), gf(4, "t", "baz")), + plan.NewSubqueryAlias("t", + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(2, "t1", "bax"), + gf(4, "t2", "baz"), + gf(5, "t2", "bux"), + }, + 
plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + ), + + plan.NewProject( + []sql.Expression{ + gf(0, "t", "foo"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t", "foo"), + expression.NewAlias(gf(1, "t", "bar"), "some_alias"), + }, + plan.NewFilter( + eq(gf(0, "t", "foo"), gf(2, "t", "baz")), + plan.NewSubqueryAlias("t", + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(4, "t2", "baz"), + }, + plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + ), + }, + + { + "group by", + plan.NewGroupBy( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "", "some_alias"), + }, + []sql.Expression{ + gf(0, "t1", "foo"), + gf(5, "t2", "bux"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + expression.NewAlias(gf(1, "t1", "bar"), "some_alias"), + gf(5, "t2", "bux"), + }, + plan.NewFilter( + eq(gf(0, "t1", "foo"), gf(4, "t2", "baz")), + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(2, "t1", "bax"), + gf(4, "t2", "baz"), + gf(5, "t2", "bux"), + }, + plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + + plan.NewGroupBy( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "", "some_alias"), + }, + []sql.Expression{ + gf(0, "t1", "foo"), + gf(2, "t2", "bux"), + gf(1, "", "some_alias"), + }, + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + expression.NewAlias(gf(1, "t1", "bar"), "some_alias"), + gf(3, "t2", "bux"), + }, + plan.NewFilter( + eq(gf(0, "t1", "foo"), gf(2, "t2", "baz")), + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(4, "t2", "baz"), + gf(5, "t2", "bux"), + }, + plan.NewCrossJoin(t1, t2), + ), + ), + ), + ), + }, + + { + "used inside subquery and not outside", + plan.NewProject( + []sql.Expression{ + gf(0, "sq", "foo"), + }, + plan.NewSubqueryAlias("sq", + plan.NewProject( + []sql.Expression{gf(0, "t1", "foo")}, + plan.NewInnerJoin( + plan.NewProject( + []sql.Expression{ + 
gf(0, "t1", "foo"), + gf(1, "t1", "bar"), + gf(2, "t1", "bax"), + }, + t1, + ), + plan.NewProject( + []sql.Expression{ + gf(0, "t2", "foo"), + gf(1, "t2", "baz"), + gf(2, "t2", "bux"), + }, + t2, + ), + expression.NewEquals( + gf(0, "t1", "foo"), + gf(3, "t2", "foo"), + ), + ), + ), + ), + ), + plan.NewProject( + []sql.Expression{ + gf(0, "sq", "foo"), + }, + plan.NewSubqueryAlias("sq", + plan.NewProject( + []sql.Expression{gf(0, "t1", "foo")}, + plan.NewInnerJoin( + plan.NewProject( + []sql.Expression{ + gf(0, "t1", "foo"), + }, + t1, + ), + plan.NewProject( + []sql.Expression{ + gf(0, "t2", "foo"), + }, + t2, + ), + expression.NewEquals( + gf(0, "t1", "foo"), + gf(1, "t2", "foo"), + ), + ), + ), + ), + ), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + // Hack so the result and expected schema can be compared. + // SubqueryAlias keeps the schema cached, which will not be + // in the case of `expected` because it was constructed by hand. + ensureSubquerySchema(tt.expected) + + require := require.New(t) + result, err := rule.Apply(sql.NewEmptyContext(), a, tt.input) + require.NoError(err) + require.Equal(tt.expected.Schema(), result.Schema()) + require.Equal(tt.expected, result) + }) + } +} + +func ensureSubquerySchema(n sql.Node) { + plan.Inspect(n, func(n sql.Node) bool { + if _, ok := n.(*plan.SubqueryAlias); ok { + _ = n.Schema() + } + return true + }) +} + +func gf(idx int, table, name string) *expression.GetField { + return expression.NewGetFieldWithTable(idx, sql.Int64, table, name, false) +} diff --git a/sql/analyzer/pushdown.go b/sql/analyzer/pushdown.go new file mode 100644 index 000000000..07049c413 --- /dev/null +++ b/sql/analyzer/pushdown.go @@ -0,0 +1,379 @@ +package analyzer + +import ( + "reflect" + "sync" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) 
(sql.Node, error) { + span, ctx := ctx.Span("pushdown") + defer span.Finish() + + a.Log("pushdown, node of type: %T", n) + if !n.Resolved() { + return n, nil + } + + // don't do pushdown on certain queries + switch n.(type) { + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex: + return n, nil + } + + a.Log("finding used columns in node") + + colSpan, _ := ctx.Span("find_pushdown_columns") + + // First step is to find all col exprs and group them by the table they mention. + // Even if they appear multiple times, only the first one will be used. + fieldsByTable := findFieldsByTable(n) + + colSpan.Finish() + + a.Log("finding filters in node") + filters := findFilters(ctx, n) + + indexSpan, _ := ctx.Span("assign_indexes") + indexes, err := assignIndexes(a, n) + if err != nil { + return nil, err + } + indexSpan.Finish() + + a.Log("transforming nodes with pushdown of filters, projections and indexes") + + return transformPushdown(a, n, filters, indexes, fieldsByTable) +} + +// fixFieldIndexesOnExpressions executes fixFieldIndexes on a list of exprs. +func fixFieldIndexesOnExpressions(schema sql.Schema, expressions ...sql.Expression) ([]sql.Expression, error) { + var result = make([]sql.Expression, len(expressions)) + for i, e := range expressions { + var err error + result[i], err = fixFieldIndexes(schema, e) + if err != nil { + return nil, err + } + } + return result, nil +} + +// fixFieldIndexes transforms the given expression setting correct indexes +// for GetField expressions according to the schema of the row in the table +// and not the one where the filter came from. 
+func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, error) { + return expression.TransformUp(exp, func(e sql.Expression) (sql.Expression, error) { + switch e := e.(type) { + case *expression.GetField: + // we need to rewrite the indexes for the table row + for i, col := range schema { + if e.Name() == col.Name && e.Table() == col.Source { + return expression.NewGetFieldWithTable( + i, + e.Type(), + e.Table(), + e.Name(), + e.IsNullable(), + ), nil + } + } + + return nil, ErrFieldMissing.New(e.Name()) + } + + return e, nil + }) +} + +func findFieldsByTable(n sql.Node) map[string][]string { + var fieldsByTable = make(map[string][]string) + plan.InspectExpressions(n, func(e sql.Expression) bool { + if gf, ok := e.(*expression.GetField); ok { + if !stringContains(fieldsByTable[gf.Table()], gf.Name()) { + fieldsByTable[gf.Table()] = append(fieldsByTable[gf.Table()], gf.Name()) + } + } + return true + }) + return fieldsByTable +} + +func findFilters(ctx *sql.Context, n sql.Node) filters { + span, _ := ctx.Span("find_pushdown_filters") + defer span.Finish() + + // Find all filters, also by table. Note that filters that mention + // more than one table will not be passed to neither. + filters := make(filters) + plan.Inspect(n, func(node sql.Node) bool { + switch node := node.(type) { + case *plan.Filter: + fs := exprToTableFilters(node.Expression) + filters.merge(fs) + } + return true + }) + + return filters +} + +func transformPushdown( + a *Analyzer, + n sql.Node, + filters filters, + indexes map[string]*indexLookup, + fieldsByTable map[string][]string, +) (sql.Node, error) { + // Now all nodes can be transformed. Since traversal of the tree is done + // from inner to outer the filters have to be processed first so they get + // to the tables. 
+ var handledFilters []sql.Expression + var queryIndexes []sql.Index + + node, err := plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", node) + switch node := node.(type) { + case *plan.Filter: + return pushdownFilter(a, node, handledFilters) + case *plan.ResolvedTable: + return pushdownTable( + a, + node, + filters, + &handledFilters, + &queryIndexes, + fieldsByTable, + indexes, + ) + default: + return transformExpressioners(node) + } + }) + + release := func() { + for _, idx := range queryIndexes { + a.Catalog.ReleaseIndex(idx) + } + } + + if err != nil { + release() + return nil, err + } + + if len(queryIndexes) > 0 { + return &releaser{node, release}, nil + } + + return node, nil +} + +func transformExpressioners(node sql.Node) (sql.Node, error) { + if _, ok := node.(sql.Expressioner); !ok { + return node, nil + } + + var schemas []sql.Schema + for _, child := range node.Children() { + schemas = append(schemas, child.Schema()) + } + + if len(schemas) < 1 { + return node, nil + } + + n, err := plan.TransformExpressions(node, func(e sql.Expression) (sql.Expression, error) { + for _, schema := range schemas { + fixed, err := fixFieldIndexes(schema, e) + if err == nil { + return fixed, nil + } + + if ErrFieldMissing.Is(err) { + continue + } + + return nil, err + } + + return e, nil + }) + + if err != nil { + return nil, err + } + + switch j := n.(type) { + case *plan.InnerJoin: + cond, err := fixFieldIndexes(j.Schema(), j.Cond) + if err != nil { + return nil, err + } + + n = plan.NewInnerJoin(j.Left, j.Right, cond) + case *plan.RightJoin: + cond, err := fixFieldIndexes(j.Schema(), j.Cond) + if err != nil { + return nil, err + } + + n = plan.NewRightJoin(j.Left, j.Right, cond) + case *plan.LeftJoin: + cond, err := fixFieldIndexes(j.Schema(), j.Cond) + if err != nil { + return nil, err + } + + n = plan.NewLeftJoin(j.Left, j.Right, cond) + } + + return n, nil +} + +func pushdownTable( + a *Analyzer, + node 
*plan.ResolvedTable, + filters filters, + handledFilters *[]sql.Expression, + queryIndexes *[]sql.Index, + fieldsByTable map[string][]string, + indexes map[string]*indexLookup, +) (sql.Node, error) { + var table = node.Table + + if ft, ok := table.(sql.FilteredTable); ok { + tableFilters := filters[node.Name()] + handled := ft.HandledFilters(tableFilters) + *handledFilters = append(*handledFilters, handled...) + schema := node.Schema() + handled, err := fixFieldIndexesOnExpressions(schema, handled...) + if err != nil { + return nil, err + } + + table = ft.WithFilters(handled) + a.Log( + "table %q transformed with pushdown of filters, %d filters handled of %d", + node.Name(), + len(handled), + len(tableFilters), + ) + } + + if pt, ok := table.(sql.ProjectedTable); ok { + table = pt.WithProjection(fieldsByTable[node.Name()]) + a.Log("table %q transformed with pushdown of projection", node.Name()) + } + + if it, ok := table.(sql.IndexableTable); ok { + indexLookup, ok := indexes[node.Name()] + if ok { + *queryIndexes = append(*queryIndexes, indexLookup.indexes...) 
+ table = it.WithIndexLookup(indexLookup.lookup) + a.Log("table %q transformed with pushdown of index", node.Name()) + } + } + + return plan.NewResolvedTable(table), nil +} + +func pushdownFilter( + a *Analyzer, + node *plan.Filter, + handledFilters []sql.Expression, +) (sql.Node, error) { + if len(handledFilters) == 0 { + a.Log("no handled filters, leaving filter untouched") + return node, nil + } + + unhandled := getUnhandledFilters( + splitExpression(node.Expression), + handledFilters, + ) + + if len(unhandled) == 0 { + a.Log("filter node has no unhandled filters, so it will be removed") + return node.Child, nil + } + + a.Log( + "%d handled filters removed from filter node, filter has now %d filters", + len(handledFilters), + len(unhandled), + ) + + return plan.NewFilter(expression.JoinAnd(unhandled...), node.Child), nil +} + +type releaser struct { + Child sql.Node + Release func() +} + +func (r *releaser) Resolved() bool { + return r.Child.Resolved() +} + +func (r *releaser) Children() []sql.Node { + return []sql.Node{r.Child} +} + +func (r *releaser) RowIter(ctx *sql.Context) (sql.RowIter, error) { + iter, err := r.Child.RowIter(ctx) + if err != nil { + r.Release() + return nil, err + } + + return &releaseIter{child: iter, release: r.Release}, nil +} + +func (r *releaser) Schema() sql.Schema { + return r.Child.Schema() +} + +func (r *releaser) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) + } + return &releaser{children[0], r.Release}, nil +} + +func (r *releaser) String() string { + return r.Child.String() +} + +func (r *releaser) Equal(n sql.Node) bool { + if r2, ok := n.(*releaser); ok { + return reflect.DeepEqual(r.Child, r2.Child) + } + return false +} + +type releaseIter struct { + child sql.RowIter + release func() + once sync.Once +} + +func (i *releaseIter) Next() (sql.Row, error) { + row, err := i.child.Next() + if err != nil { + _ = i.Close() + 
return nil, err + } + return row, nil +} + +func (i *releaseIter) Close() (err error) { + i.once.Do(i.release) + if i.child != nil { + err = i.child.Close() + } + return err +} diff --git a/sql/analyzer/pushdown_test.go b/sql/analyzer/pushdown_test.go new file mode 100644 index 000000000..818eb40a0 --- /dev/null +++ b/sql/analyzer/pushdown_test.go @@ -0,0 +1,226 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestPushdownProjectionAndFilters(t *testing.T) { + require := require.New(t) + f := getRule("pushdown") + + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable"}, + {Name: "f", Type: sql.Float64, Source: "mytable"}, + {Name: "t", Type: sql.Text, Source: "mytable"}, + }) + + table2 := memory.NewTable("mytable2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "mytable2"}, + {Name: "f2", Type: sql.Float64, Source: "mytable2"}, + {Name: "t2", Type: sql.Text, Source: "mytable2"}, + }) + + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + db.AddTable("mytable2", table2) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + a := NewDefault(catalog) + + node := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + }, + plan.NewFilter( + expression.NewAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewLiteral(3.14, sql.Float64), + ), + expression.NewIsNull( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), + ), + ), + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", 
"i", false), + }, + plan.NewCrossJoin( + plan.NewResolvedTable( + table.WithFilters([]sql.Expression{ + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewLiteral(3.14, sql.Float64), + ), + }).(*memory.Table).WithProjection([]string{"i", "f"}), + ), + plan.NewResolvedTable( + table2.WithFilters([]sql.Expression{ + expression.NewIsNull( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), + ), + }).(*memory.Table).WithProjection([]string{"i2"}), + ), + ), + ) + + result, err := f.Apply(sql.NewEmptyContext(), a, node) + require.NoError(err) + require.Equal(expected, result) +} + +func TestPushdownIndexable(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int32, Source: "mytable"}, + {Name: "f", Type: sql.Float64, Source: "mytable"}, + {Name: "t", Type: sql.Text, Source: "mytable"}, + }) + + table2 := memory.NewTable("mytable2", sql.Schema{ + {Name: "i2", Type: sql.Int32, Source: "mytable2"}, + {Name: "f2", Type: sql.Float64, Source: "mytable2"}, + {Name: "t2", Type: sql.Text, Source: "mytable2"}, + }) + + db := memory.NewDatabase("") + db.AddTable("mytable", table) + db.AddTable("mytable2", table2) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + idx1 := &dummyIndex{ + "mytable", + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + }, + } + done, ready, err := catalog.AddIndex(idx1) + require.NoError(err) + close(done) + <-ready + + idx2 := &dummyIndex{ + "mytable", + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + }, + } + done, ready, err = catalog.AddIndex(idx2) + require.NoError(err) + close(done) + <-ready + + idx3 := &dummyIndex{ + "mytable2", + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), + }, + } + done, ready, err = catalog.AddIndex(idx3) + + require.NoError(err) 
+ close(done) + <-ready + + a := withoutProcessTracking(NewDefault(catalog)) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + }, + plan.NewFilter( + expression.NewAnd( + expression.NewAnd( + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("mytable", "f"), + expression.NewLiteral(3.14, sql.Float64), + ), + expression.NewGreaterThan( + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + expression.NewLiteral(1, sql.Int32), + ), + ), + expression.NewNot( + expression.NewEquals( + expression.NewUnresolvedQualifiedColumn("mytable2", "i2"), + expression.NewLiteral(2, sql.Int32), + ), + ), + ), + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + ) + + expected := &releaser{ + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + }, + plan.NewCrossJoin( + plan.NewResolvedTable( + table.WithFilters([]sql.Expression{ + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), + expression.NewLiteral(3.14, sql.Float64), + ), + expression.NewGreaterThan( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), + expression.NewLiteral(1, sql.Int32), + ), + }).(*memory.Table). + WithProjection([]string{"i", "f"}).(*memory.Table). + WithIndexLookup(&mergeableIndexLookup{id: "3.14"}), + ), + plan.NewResolvedTable( + table2.WithFilters([]sql.Expression{ + expression.NewNot( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), + expression.NewLiteral(2, sql.Int32), + ), + ), + }).(*memory.Table). + WithProjection([]string{"i2"}).(*memory.Table). 
+ WithIndexLookup(&negateIndexLookup{value: "2"}), + ), + ), + ), + nil, + } + + result, err := a.Analyze(sql.NewEmptyContext(), node) + require.NoError(err) + + // we need to remove the release function to compare, otherwise it will fail + result, err = plan.TransformUp(result, func(node sql.Node) (sql.Node, error) { + switch node := node.(type) { + case *releaser: + return &releaser{Child: node.Child}, nil + default: + return node, nil + } + }) + require.NoError(err) + + require.Equal(expected, result) +} diff --git a/sql/analyzer/resolve_columns.go b/sql/analyzer/resolve_columns.go new file mode 100644 index 000000000..79e976739 --- /dev/null +++ b/sql/analyzer/resolve_columns.go @@ -0,0 +1,569 @@ +package analyzer + +import ( + "fmt" + "sort" + "strings" + + "github.com/src-d/go-mysql-server/internal/similartext" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" + "vitess.io/vitess/go/vt/sqlparser" +) + +func checkAliases(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("check_aliases") + defer span.Finish() + + a.Log("check aliases") + + var err error + plan.Inspect(n, func(node sql.Node) bool { + p, ok := node.(*plan.Project) + if !ok { + return true + } + + aliases := lookForAliasDeclarations(p) + for alias := range aliases { + if isAliasUsed(p, alias) { + err = ErrMisusedAlias.New(alias) + } + } + + return true + }) + + return n, err +} + +func lookForAliasDeclarations(node sql.Expressioner) map[string]struct{} { + var ( + aliases = map[string]struct{}{} + in = struct{}{} + ) + + for _, e := range node.Expressions() { + expression.Inspect(e, func(expr sql.Expression) bool { + if alias, ok := expr.(*expression.Alias); ok { + aliases[alias.Name()] = in + } + + return true + }) + } + + return aliases +} + +func isAliasUsed(node sql.Expressioner, alias string) bool { + var found bool + for _, e := range 
node.Expressions() { + expression.Inspect(e, func(expr sql.Expression) bool { + if a, ok := expr.(*expression.Alias); ok { + if a.Name() == alias { + return false + } + + return true + } + + if n, ok := expr.(sql.Nameable); ok && n.Name() == alias { + found = true + return false + } + + return true + }) + + if found { + break + } + } + + return found +} + +// deferredColumn is a wrapper on UnresolvedColumn used only to defer the +// resolution of the column because it may require some work done by +// other analyzer phases. +type deferredColumn struct { + *expression.UnresolvedColumn +} + +// IsNullable implements the Expression interface. +func (deferredColumn) IsNullable() bool { + return true +} + +// Children implements the Expression interface. +func (deferredColumn) Children() []sql.Expression { return nil } + +// WithChildren implements the Expression interface. +func (e deferredColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } + return e, nil +} + +type tableCol struct { + table string + col string +} + +type indexedCol struct { + *sql.Column + index int +} + +// column is the common interface that groups UnresolvedColumn and deferredColumn. 
+type column interface { + sql.Nameable + sql.Tableable + sql.Expression +} + +func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + if _, ok := n.(sql.Expressioner); !ok || n.Resolved() { + return n, nil + } + + columns := getNodeAvailableColumns(n) + tables := getNodeAvailableTables(n) + + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { + return qualifyExpression(e, columns, tables) + }) + }) +} + +func qualifyExpression( + e sql.Expression, + columns map[string][]string, + tables map[string]string, +) (sql.Expression, error) { + switch col := e.(type) { + case column: + // Skip this step for global and session variables + if isGlobalOrSessionColumn(col) { + return col, nil + } + + name, table := strings.ToLower(col.Name()), strings.ToLower(col.Table()) + availableTables := dedupStrings(columns[name]) + if table != "" { + table, ok := tables[table] + if !ok { + if len(tables) == 0 { + return nil, sql.ErrTableNotFound.New(col.Table()) + } + + similar := similartext.FindFromMap(tables, col.Table()) + return nil, sql.ErrTableNotFound.New(col.Table() + similar) + } + + // If the table exists but it's not available for this node it + // means some work is still needed, so just return the column + // and let it be resolved in the next pass. + if !stringContains(availableTables, table) { + return col, nil + } + + return expression.NewUnresolvedQualifiedColumn(table, col.Name()), nil + } + + switch len(availableTables) { + case 0: + // If there are no tables that have any column with the column + // name let's just return it as it is. This may be an alias, so + // we'll wait for the reorder of the projection. 
+ return col, nil + case 1: + return expression.NewUnresolvedQualifiedColumn( + availableTables[0], + col.Name(), + ), nil + default: + return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(availableTables, ", ")) + } + case *expression.Star: + if col.Table != "" { + if real, ok := tables[strings.ToLower(col.Table)]; ok { + col = expression.NewQualifiedStar(real) + } + + if _, ok := tables[strings.ToLower(col.Table)]; !ok { + return nil, sql.ErrTableNotFound.New(col.Table) + } + } + return col, nil + default: + // If any other kind of expression has a star, just replace it + // with an unqualified star because it cannot be expanded. + return expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { + if _, ok := e.(*expression.Star); ok { + return expression.NewStar(), nil + } + return e, nil + }) + } +} + +func getNodeAvailableColumns(n sql.Node) map[string][]string { + var columns = make(map[string][]string) + getColumnsInNodes(n.Children(), columns) + return columns +} + +func getColumnsInNodes(nodes []sql.Node, columns map[string][]string) { + indexCol := func(table, col string) { + col = strings.ToLower(col) + columns[col] = append(columns[col], strings.ToLower(table)) + } + + indexExpressions := func(exprs []sql.Expression) { + for _, e := range exprs { + switch e := e.(type) { + case *expression.Alias: + indexCol("", e.Name()) + case *expression.GetField: + indexCol(e.Table(), e.Name()) + case *expression.UnresolvedColumn: + indexCol(e.Table(), e.Name()) + } + } + } + + for _, node := range nodes { + switch n := node.(type) { + case *plan.ResolvedTable, *plan.SubqueryAlias: + for _, col := range n.Schema() { + indexCol(col.Source, col.Name) + } + case *plan.Project: + indexExpressions(n.Projections) + case *plan.GroupBy: + indexExpressions(n.Aggregate) + default: + getColumnsInNodes(n.Children(), columns) + } + } +} + +func getNodeAvailableTables(n sql.Node) map[string]string { + tables := make(map[string]string) + 
getNodesAvailableTables(tables, n.Children()...) + return tables +} + +func getNodesAvailableTables(tables map[string]string, nodes ...sql.Node) { + for _, n := range nodes { + switch n := n.(type) { + case *plan.SubqueryAlias, *plan.ResolvedTable: + name := strings.ToLower(n.(sql.Nameable).Name()) + tables[name] = name + case *plan.TableAlias: + switch t := n.Child.(type) { + case *plan.ResolvedTable, *plan.UnresolvedTable: + name := strings.ToLower(t.(sql.Nameable).Name()) + alias := strings.ToLower(n.Name()) + tables[alias] = name + // Also add the name of the table because you can refer to a + // table with either the alias or the name. + tables[name] = name + } + default: + getNodesAvailableTables(tables, n.Children()...) + } + } +} + +var errGlobalVariablesNotSupported = errors.NewKind("can't resolve global variable, %s was requested") + +const ( + sessionTable = "@@" + sqlparser.SessionStr + sessionPrefix = sqlparser.SessionStr + "." + globalPrefix = sqlparser.GlobalStr + "." +) + +func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, ctx := ctx.Span("resolve_columns") + defer span.Finish() + + a.Log("resolve columns, node of type: %T", n) + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + if n.Resolved() { + return n, nil + } + + if _, ok := n.(sql.Expressioner); !ok { + return n, nil + } + + // We need to use the schema, so all children must be resolved. 
+ for _, c := range n.Children() { + if !c.Resolved() { + return n, nil + } + } + + columns := findChildIndexedColumns(n) + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { + a.Log("transforming expression of type: %T", e) + + uc, ok := e.(column) + if !ok || e.Resolved() { + return e, nil + } + + if isGlobalOrSessionColumn(uc) { + return resolveGlobalOrSessionColumn(ctx, uc) + } + + return resolveColumnExpression(ctx, uc, columns) + }) + }) +} + +func findChildIndexedColumns(n sql.Node) map[tableCol]indexedCol { + var idx int + var columns = make(map[tableCol]indexedCol) + + for _, child := range n.Children() { + for _, col := range child.Schema() { + columns[tableCol{ + table: strings.ToLower(col.Source), + col: strings.ToLower(col.Name), + }] = indexedCol{col, idx} + idx++ + } + } + + return columns +} + +func resolveGlobalOrSessionColumn(ctx *sql.Context, col column) (sql.Expression, error) { + if col.Table() != "" && strings.ToLower(col.Table()) != sessionTable { + return nil, errGlobalVariablesNotSupported.New(col) + } + + name := strings.TrimLeft(col.Name(), "@") + name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix) + typ, value := ctx.Get(name) + return expression.NewGetSessionField(name, typ, value), nil +} + +func resolveColumnExpression( + ctx *sql.Context, + e column, + columns map[tableCol]indexedCol, +) (sql.Expression, error) { + name := strings.ToLower(e.Name()) + table := strings.ToLower(e.Table()) + col, ok := columns[tableCol{table, name}] + if !ok { + switch uc := e.(type) { + case *expression.UnresolvedColumn: + // Defer the resolution of the column to give the analyzer more + // time to resolve other parts so this can be resolved. 
+ return &deferredColumn{uc}, nil + default: + if table != "" { + return nil, ErrColumnTableNotFound.New(e.Table(), e.Name()) + } + + return nil, ErrColumnNotFound.New(e.Name()) + } + } + + return expression.NewGetFieldWithTable( + col.index, + col.Type, + col.Source, + col.Name, + col.Nullable, + ), nil +} + +// resolveGroupingColumns reorders the aggregation in a groupby so aliases +// defined in it can be resolved in the grouping of the groupby. To do so, +// all aliases are pushed down to a projection node under the group by. +func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + a.Log("resoving group columns") + if n.Resolved() { + return n, nil + } + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + g, ok := n.(*plan.GroupBy) + if n.Resolved() || !ok || len(g.Grouping) == 0 { + return n, nil + } + + // The reason we have two sets of columns, one for grouping and + // one for aggregate is because an alias can redefine a column name + // of the child schema. In the grouping, if that column is referenced + // it refers to the alias, and not the one in the child. However, + // in the aggregate, aliases in that same aggregate cannot be used, + // so it refers to the column in the child node. + var groupingColumns = make(map[string]struct{}) + for _, g := range g.Grouping { + for _, n := range findAllColumns(g) { + groupingColumns[strings.ToLower(n)] = struct{}{} + } + } + + var aggregateColumns = make(map[string]struct{}) + for _, agg := range g.Aggregate { + // This alias is going to be pushed down, so don't bother gathering + // its requirements. 
+ if alias, ok := agg.(*expression.Alias); ok { + if _, ok := groupingColumns[strings.ToLower(alias.Name())]; ok { + continue + } + } + + for _, n := range findAllColumns(agg) { + aggregateColumns[strings.ToLower(n)] = struct{}{} + } + } + + var newAggregate []sql.Expression + var projection []sql.Expression + // Aliases will keep the aliases that have been pushed down and their + // index in the new aggregate. + var aliases = make(map[string]int) + + var needsReorder bool + for _, a := range g.Aggregate { + alias, ok := a.(*expression.Alias) + // Note that aliases of aggregations cannot be used in the grouping + // because the grouping is needed before computing the aggregation. + if !ok || containsAggregation(alias) { + newAggregate = append(newAggregate, a) + continue + } + + name := strings.ToLower(alias.Name()) + // Only if the alias is required in the grouping set needsReorder + // to true. If it's not required, there's no need for a reorder if + // no other alias is required. + _, ok = groupingColumns[name] + if ok { + aliases[name] = len(newAggregate) + needsReorder = true + delete(groupingColumns, name) + + projection = append(projection, a) + newAggregate = append(newAggregate, expression.NewUnresolvedColumn(alias.Name())) + } else { + newAggregate = append(newAggregate, a) + } + } + + if !needsReorder { + return n, nil + } + + // Instead of iterating columns directly, we want them sorted so the + // executions of the rule are consistent. + var missingCols = make([]string, 0, len(aggregateColumns)+len(groupingColumns)) + for col := range aggregateColumns { + missingCols = append(missingCols, col) + } + for col := range groupingColumns { + missingCols = append(missingCols, col) + } + sort.Strings(missingCols) + + var renames = make(map[string]string) + // All columns required by expressions in both grouping and aggregation + // must also be projected in the new projection node or they will not + // be able to resolve. 
+ for _, col := range missingCols { + name := col + // If an alias has been pushed down with the same name as a missing + // column, there will be a conflict of names. We must find an unique name + // for the missing column. + if _, ok := aliases[col]; ok { + for i := 1; ; i++ { + name = fmt.Sprintf("%s_%02d", col, i) + if !stringContains(missingCols, name) { + break + } + } + } + + if name == col { + projection = append(projection, expression.NewUnresolvedColumn(col)) + } else { + renames[col] = name + projection = append(projection, expression.NewAlias( + expression.NewUnresolvedColumn(col), + name, + )) + } + } + + // If there is any name conflict between columns we need to rename every + // usage inside the aggregate. + if len(renames) > 0 { + for i, expr := range newAggregate { + var err error + newAggregate[i], err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { + col, ok := e.(*expression.UnresolvedColumn) + if ok { + // We need to make sure we don't rename the reference to the + // pushed down alias. 
+ if to, ok := renames[col.Name()]; ok && aliases[col.Name()] != i { + return expression.NewUnresolvedColumn(to), nil + } + } + + return e, nil + }) + if err != nil { + return nil, err + } + } + } + + return plan.NewGroupBy( + newAggregate, g.Grouping, + plan.NewProject(projection, g.Child), + ), nil + }) +} + +func findAllColumns(e sql.Expression) []string { + var cols []string + expression.Inspect(e, func(e sql.Expression) bool { + col, ok := e.(*expression.UnresolvedColumn) + if ok { + cols = append(cols, col.Name()) + } + return true + }) + return cols +} + +func dedupStrings(in []string) []string { + var seen = make(map[string]struct{}) + var result []string + for _, s := range in { + if _, ok := seen[s]; !ok { + seen[s] = struct{}{} + result = append(result, s) + } + } + return result +} + +func isGlobalOrSessionColumn(col column) bool { + return strings.HasPrefix(col.Name(), "@@") || strings.HasPrefix(col.Table(), "@@") +} diff --git a/sql/analyzer/resolve_columns_test.go b/sql/analyzer/resolve_columns_test.go new file mode 100644 index 000000000..4b111e943 --- /dev/null +++ b/sql/analyzer/resolve_columns_test.go @@ -0,0 +1,378 @@ +package analyzer + +import ( + "context" + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestQualifyColumnsProject(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Text, Source: "foo"}, + }) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + expression.NewUnresolvedColumn("b"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("foo", "a"), + }, + plan.NewResolvedTable(table), + ), + ) + + result, err := qualifyColumns(sql.NewEmptyContext(), 
NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("foo", "a"), + // b is not qualified because it's not projected + expression.NewUnresolvedColumn("b"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("foo", "a"), + }, + plan.NewResolvedTable(table), + ), + ) + + require.Equal(expected, result) +} + +func TestMisusedAlias(t *testing.T) { + require := require.New(t) + f := getRule("check_aliases") + + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Type: sql.Int32}, + }) + + node := plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewUnresolvedColumn("i"), + "alias_i", + ), + expression.NewUnresolvedColumn("alias_i"), + }, + plan.NewResolvedTable(table), + ) + + _, err := f.Apply(sql.NewEmptyContext(), nil, node) + require.EqualError(err, ErrMisusedAlias.New("alias_i").Error()) +} + +func TestQualifyColumns(t *testing.T) { + require := require.New(t) + f := getRule("qualify_columns") + + table := memory.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable"}}) + table2 := memory.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable2"}}) + sessionTable := memory.NewTable("@@session", sql.Schema{{Name: "autocommit", Type: sql.Int64, Source: "@@session"}}) + globalTable := memory.NewTable("@@global", sql.Schema{{Name: "max_allowed_packet", Type: sql.Int64, Source: "@@global"}}) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("@@max_allowed_packet"), + }, + plan.NewResolvedTable(globalTable), + ) + col, ok := node.Projections[0].(*expression.UnresolvedColumn) + require.True(ok) + require.Truef(isGlobalOrSessionColumn(col), "@@max_allowed_packet is not global or session column") + + expected := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("", "@@max_allowed_packet"), + }, + 
plan.NewResolvedTable(globalTable), + ) + + result, err := f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("@@autocommit"), + }, + plan.NewResolvedTable(sessionTable), + ) + col, ok = node.Projections[0].(*expression.UnresolvedColumn) + require.True(ok) + require.Truef(isGlobalOrSessionColumn(col), "@@autocommit is not global or session column") + + expected = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("", "@@autocommit"), + }, + plan.NewResolvedTable(sessionTable), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("i"), + }, + plan.NewResolvedTable(table), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + }, + plan.NewResolvedTable(table), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + }, + plan.NewResolvedTable(table), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("a", "i"), + }, + plan.NewTableAlias("a", plan.NewResolvedTable(table)), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + }, + plan.NewTableAlias("a", plan.NewResolvedTable(table)), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("z"), + }, + 
plan.NewTableAlias("a", plan.NewResolvedTable(table)), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(node, result) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("foo", "i"), + }, + plan.NewTableAlias("a", plan.NewResolvedTable(table)), + ) + + _, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.Error(err) + require.True(sql.ErrTableNotFound.Is(err)) + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("i"), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ) + + _, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.Error(err) + require.True(ErrAmbiguousColumnName.Is(err)) + + subquery := plan.NewSubqueryAlias( + "b", + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + }, + plan.NewResolvedTable(table), + ), + ) + // preload schema + _ = subquery.Schema() + + node = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("a", "i"), + }, + plan.NewCrossJoin( + plan.NewTableAlias("a", plan.NewResolvedTable(table)), + subquery, + ), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("mytable", "i"), + }, + plan.NewCrossJoin( + plan.NewTableAlias("a", plan.NewResolvedTable(table)), + subquery, + ), + ) + + result, err = f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) +} + +func TestQualifyColumnsQualifiedStar(t *testing.T) { + require := require.New(t) + f := getRule("qualify_columns") + + table := memory.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedFunction( + "count", + true, + expression.NewQualifiedStar("mytable"), + ), + }, + plan.NewResolvedTable(table), + ) + + expected := plan.NewProject( + 
[]sql.Expression{ + expression.NewUnresolvedFunction( + "count", + true, + expression.NewStar(), + ), + }, + plan.NewResolvedTable(table), + ) + + result, err := f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + require.Equal(expected, result) +} + +func TestResolveColumnsSession(t *testing.T) { + require := require.New(t) + + ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession())) + ctx.Set("foo_bar", sql.Int64, int64(42)) + ctx.Set("autocommit", sql.Boolean, true) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("@@foo_bar"), + expression.NewUnresolvedColumn("@@bar_baz"), + expression.NewUnresolvedColumn("@@autocommit"), + }, + plan.NewResolvedTable(dualTable), + ) + + result, err := resolveColumns(ctx, NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetSessionField("foo_bar", sql.Int64, int64(42)), + expression.NewGetSessionField("bar_baz", sql.Null, nil), + expression.NewGetSessionField("autocommit", sql.Boolean, true), + }, + plan.NewResolvedTable(dualTable), + ) + + require.Equal(expected, result) +} + +func TestResolveGroupingColumns(t *testing.T) { + require := require.New(t) + + a := NewDefault(nil) + node := plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewUnresolvedFunction("foo", true, + expression.NewUnresolvedColumn("c"), + ), + "c", + ), + expression.NewAlias( + expression.NewUnresolvedColumn("d"), + "b", + ), + expression.NewUnresolvedFunction("bar", false, + expression.NewUnresolvedColumn("b"), + ), + }, + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + expression.NewUnresolvedColumn("b"), + }, + plan.NewResolvedTable(memory.NewTable("table", nil)), + ) + + expected := plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewUnresolvedFunction("foo", true, + expression.NewUnresolvedColumn("c"), + ), + "c", + ), + expression.NewUnresolvedColumn("b"), + 
expression.NewUnresolvedFunction("bar", false, + expression.NewUnresolvedColumn("b_01"), + ), + }, + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + expression.NewUnresolvedColumn("b"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewUnresolvedColumn("d"), + "b", + ), + expression.NewUnresolvedColumn("a"), + expression.NewAlias( + expression.NewUnresolvedColumn("b"), + "b_01", + ), + expression.NewUnresolvedColumn("c"), + }, + plan.NewResolvedTable(memory.NewTable("table", nil)), + ), + ) + + result, err := resolveGroupingColumns(sql.NewEmptyContext(), a, node) + require.NoError(err) + + require.Equal(expected, result) +} diff --git a/sql/analyzer/resolve_database.go b/sql/analyzer/resolve_database.go new file mode 100644 index 000000000..2c5d0f628 --- /dev/null +++ b/sql/analyzer/resolve_database.go @@ -0,0 +1,38 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_database") + defer span.Finish() + + a.Log("resolve database, node of type: %T", n) + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + d, ok := n.(sql.Databaser) + if !ok { + return n, nil + } + + var dbName = a.Catalog.CurrentDatabase() + if db := d.Database(); db != nil { + if _, ok := db.(sql.UnresolvedDatabase); !ok { + return n, nil + } + + if db.Name() != "" { + dbName = db.Name() + } + } + + db, err := a.Catalog.Database(dbName) + if err != nil { + return nil, err + } + + return d.WithDatabase(db) + }) +} diff --git a/sql/analyzer/resolve_functions.go b/sql/analyzer/resolve_functions.go new file mode 100644 index 000000000..34bac11b0 --- /dev/null +++ b/sql/analyzer/resolve_functions.go @@ -0,0 +1,47 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + 
"github.com/src-d/go-mysql-server/sql/plan" +) + +func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_functions") + defer span.Finish() + + a.Log("resolve functions, node of type %T", n) + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + if n.Resolved() { + return n, nil + } + + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { + a.Log("transforming expression of type: %T", e) + if e.Resolved() { + return e, nil + } + + uf, ok := e.(*expression.UnresolvedFunction) + if !ok { + return e, nil + } + + n := uf.Name() + f, err := a.Catalog.Function(n) + if err != nil { + return nil, err + } + + rf, err := f.Call(uf.Arguments...) + if err != nil { + return nil, err + } + + a.Log("resolved function %q", n) + + return rf, nil + }) + }) +} diff --git a/sql/analyzer/resolve_generators.go b/sql/analyzer/resolve_generators.go new file mode 100644 index 000000000..4635e24d6 --- /dev/null +++ b/sql/analyzer/resolve_generators.go @@ -0,0 +1,97 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" +) + +var ( + errMultipleGenerators = errors.NewKind("there can't be more than 1 instance of EXPLODE in a SELECT") + errExplodeNotArray = errors.NewKind("argument of type %q given to EXPLODE, expecting array") +) + +func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + p, ok := n.(*plan.Project) + if !ok { + return n, nil + } + + projection := p.Projections + + g, err := findGenerator(projection) + if err != nil { + return nil, err + } + + // There might be no generator in the project, in that case we don't + // have to do 
anything. + if g == nil { + return n, nil + } + + projection[g.idx] = g.expr + + var name string + if n, ok := g.expr.(sql.Nameable); ok { + name = n.Name() + } else { + name = g.expr.String() + } + + return plan.NewGenerate( + plan.NewProject(projection, p.Child), + expression.NewGetField(g.idx, g.expr.Type(), name, g.expr.IsNullable()), + ), nil + }) +} + +type generator struct { + idx int + expr sql.Expression +} + +// findGenerator will find in the given projection a generator column. If there +// is no generator, it will return nil. +// If there are is than one generator or the argument to explode is not an +// array it will fail. +// All occurrences of Explode will be replaced with Generate. +func findGenerator(exprs []sql.Expression) (*generator, error) { + var g = &generator{idx: -1} + for i, e := range exprs { + var found bool + switch e := e.(type) { + case *function.Explode: + found = true + g.expr = function.NewGenerate(e.Child) + case *expression.Alias: + if exp, ok := e.Child.(*function.Explode); ok { + found = true + g.expr = expression.NewAlias( + function.NewGenerate(exp.Child), + e.Name(), + ) + } + } + + if found { + if g.idx >= 0 { + return nil, errMultipleGenerators.New() + } + g.idx = i + + if !sql.IsArray(g.expr.Type()) { + return nil, errExplodeNotArray.New(g.expr.Type()) + } + } + } + + if g.expr == nil { + return nil, nil + } + + return g, nil +} diff --git a/sql/analyzer/resolve_generators_test.go b/sql/analyzer/resolve_generators_test.go new file mode 100644 index 000000000..8a7106df8 --- /dev/null +++ b/sql/analyzer/resolve_generators_test.go @@ -0,0 +1,117 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +func TestResolveGenerators(t *testing.T) { + testCases := 
[]struct { + name string + node sql.Node + expected sql.Node + err *errors.Kind + }{ + { + name: "regular explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewGenerate(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "EXPLODE(b)", false), + ), + err: nil, + }, + { + name: "explode with alias", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewExplode( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(1, sql.Array(sql.Int64), "b", false), + ), + "x", + ), + expression.NewGetField(2, sql.Int64, "c", false), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expression.NewGetField(1, sql.Array(sql.Int64), "x", false), + ), + err: nil, + }, + { + name: "non array type on explode", + node: plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Int64, "b", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errExplodeNotArray, + }, + { + name: "more than one generator", + node: plan.NewProject( + []sql.Expression{ + 
expression.NewGetField(0, sql.Int64, "a", false), + function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)), + function.NewExplode(expression.NewGetField(2, sql.Array(sql.Int64), "c", false)), + }, + plan.NewUnresolvedTable("foo", ""), + ), + expected: nil, + err: errMultipleGenerators, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := resolveGenerators(sql.NewEmptyContext(), nil, tt.node) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + }) + } +} diff --git a/sql/analyzer/resolve_having.go b/sql/analyzer/resolve_having.go new file mode 100644 index 000000000..fdac4f67c --- /dev/null +++ b/sql/analyzer/resolve_having.go @@ -0,0 +1,502 @@ +package analyzer + +import ( + "reflect" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" +) + +func resolveHaving(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + return plan.TransformUp(node, func(node sql.Node) (sql.Node, error) { + having, ok := node.(*plan.Having) + if !ok { + return node, nil + } + + if !having.Child.Resolved() { + return node, nil + } + + originalSchema := having.Schema() + + var requiresProjection bool + if containsAggregation(having.Cond) { + var err error + having, requiresProjection, err = replaceAggregations(having) + if err != nil { + return nil, err + } + } + + missingCols := findMissingColumns(having, having.Cond) + // If all the columns required by the having are available, do nothing about it. 
+ if len(missingCols) > 0 { + var err error + having, err = pushMissingColumnsUp(having, missingCols) + if err != nil { + return nil, err + } + requiresProjection = true + } + + if !requiresProjection { + return having, nil + } + + return projectOriginalAggregation(having, originalSchema), nil + }) +} + +func findMissingColumns(node sql.Node, expr sql.Expression) []string { + var schemaCols []string + for _, col := range node.Schema() { + schemaCols = append(schemaCols, strings.ToLower(col.Name)) + } + + var missingCols []string + for _, n := range findExprNameables(expr) { + name := strings.ToLower(n.Name()) + if !stringContains(schemaCols, name) { + missingCols = append(missingCols, n.Name()) + } + } + + return missingCols +} + +func projectOriginalAggregation(having *plan.Having, schema sql.Schema) *plan.Project { + var projection []sql.Expression + for i, col := range schema { + projection = append( + projection, + expression.NewGetFieldWithTable(i, col.Type, col.Source, col.Name, col.Nullable), + ) + } + + return plan.NewProject(projection, having) +} + +var errHavingChildMissingRef = errors.NewKind("cannot find column %s referenced in HAVING clause in either GROUP BY or its child") + +func pushMissingColumnsUp( + having *plan.Having, + missingCols []string, +) (*plan.Having, error) { + groupBy, err := findGroupBy(having) + if err != nil { + return nil, err + } + + schema := groupBy.Child.Schema() + var newAggregate []sql.Expression + for _, c := range missingCols { + idx := -1 + for i, col := range schema { + if strings.ToLower(c) == strings.ToLower(col.Name) { + idx = i + break + } + } + if idx < 0 { + return nil, errHavingChildMissingRef.New(c) + } + col := schema[idx] + newAggregate = append( + newAggregate, + expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), + ) + } + + node, err := addColumnsToGroupBy(having, newAggregate) + if err != nil { + return nil, err + } + + return node.(*plan.Having), nil +} + +func 
findGroupBy(n sql.Node) (*plan.GroupBy, error) { + children := n.Children() + if len(children) != 1 { + return nil, errHavingNeedsGroupBy.New() + } + + if g, ok := children[0].(*plan.GroupBy); ok { + return g, nil + } + + return findGroupBy(children[0]) +} + +func addColumnsToGroupBy(node sql.Node, columns []sql.Expression) (sql.Node, error) { + switch node := node.(type) { + case *plan.Project: + child, err := addColumnsToGroupBy(node.Child, columns) + if err != nil { + return nil, err + } + + var newProjections = make([]sql.Expression, len(columns)) + for i, col := range columns { + var name = col.String() + var table string + if n, ok := col.(sql.Nameable); ok { + name = n.Name() + } + + if t, ok := col.(sql.Tableable); ok { + table = t.Table() + } + + newProjections[i] = expression.NewGetFieldWithTable( + len(child.Schema())-len(columns)+i, + col.Type(), + table, + name, + col.IsNullable(), + ) + } + + return plan.NewProject(append(node.Projections, newProjections...), child), nil + case *plan.Filter, + *plan.Sort, + *plan.Limit, + *plan.Offset, + *plan.Distinct, + *plan.Having: + child, err := addColumnsToGroupBy(node.Children()[0], columns) + if err != nil { + return nil, err + } + return node.WithChildren(child) + case *plan.GroupBy: + return plan.NewGroupBy(append(node.Aggregate, columns...), node.Grouping, node.Child), nil + default: + return nil, errHavingNeedsGroupBy.New() + } +} + +// pushColumnsUp pushes up the group by columns with the given indexes. +// It returns the resultant node, the indexes of those pushed up columns in the +// resultant node and an error, if any. 
+func pushColumnsUp(node sql.Node, columns []int) (sql.Node, []int, error) { + switch node := node.(type) { + case *plan.Project: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + + var seen = make(map[int]int) + for i, col := range node.Projections { + switch col := col.(type) { + case *expression.Alias: + if f, ok := col.Child.(*expression.GetField); ok { + seen[f.Index()] = i + } + case *expression.GetField: + seen[col.Index()] = i + } + } + + var newProjections = make([]sql.Expression, len(node.Projections)) + copy(newProjections, node.Projections) + schema := child.Schema() + var newColumns []int + + for _, idx := range columns { + if newIdx, ok := seen[idx]; ok { + newColumns = append(newColumns, newIdx) + continue + } + + col := schema[idx] + newIdx := len(newProjections) + newProjections = append(newProjections, expression.NewGetFieldWithTable( + newIdx, + col.Type, + col.Source, + col.Name, + col.Nullable, + )) + newColumns = append(newColumns, newIdx) + } + + return plan.NewProject(newProjections, child), newColumns, nil + case *plan.Filter: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return plan.NewFilter(node.Expression, child), columns, nil + case *plan.Sort: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return plan.NewSort(node.SortFields, child), columns, nil + case *plan.Limit: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return plan.NewLimit(node.Limit, child), columns, nil + case *plan.Offset: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return plan.NewOffset(node.Offset, child), columns, nil + case *plan.Distinct: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return 
plan.NewDistinct(child), columns, nil + case *plan.GroupBy: + return node, columns, nil + case *plan.Having: + child, columns, err := pushColumnsUp(node.Child, columns) + if err != nil { + return nil, nil, err + } + return plan.NewHaving(node.Cond, child), columns, nil + default: + return nil, nil, errHavingNeedsGroupBy.New() + } +} + +func replaceAggregations(having *plan.Having) (*plan.Having, bool, error) { + groupBy, err := findGroupBy(having) + if err != nil { + return nil, false, err + } + + var newAggregate []sql.Expression + + var pushUp []int + var tokenToIdx = make(map[int]int) + var pushUpToken = -1 + + // We need to find all aggregations inside the having condition. The ones + // that are already present in the group by will be pushed up and the ones + // that are not, will be added to the group by and pushed up. + // + // To push up already existing aggregations we need to change all possible + // projections between the having and the group by, so we will need to + // assign some fake token indexes to replace later with the actual column + // indexes after they have been pushed up. This is because some of these + // may have already been projected in some projection and we cannot ensure + // from here what the final index will be. 
+ cond, err := expression.TransformUp(having.Cond, func(e sql.Expression) (sql.Expression, error) { + agg, ok := e.(sql.Aggregation) + if !ok { + return e, nil + } + + for i, expr := range groupBy.Aggregate { + if aggregationEquals(agg, expr) { + token := pushUpToken + pushUpToken-- + pushUp = append(pushUp, i) + tokenToIdx[token] = len(pushUp) - 1 + return expression.NewGetField( + token, + expr.Type(), + expr.String(), + expr.IsNullable(), + ), nil + } + } + + newAggregate = append(newAggregate, agg) + return expression.NewGetField( + len(having.Child.Schema())+len(newAggregate)-1, + agg.Type(), + agg.String(), + agg.IsNullable(), + ), nil + }) + if err != nil { + return nil, false, err + } + + // The new aggregations will be added to the group by and pushed up until + // the topmost node. + having = plan.NewHaving(cond, having.Child) + node, err := addColumnsToGroupBy(having, newAggregate) + if err != nil { + return nil, false, err + } + + // Then, the ones that already existed are pushed up and we get the final + // indexes at the topmost node (the having) in the same order. + node, pushedUpColumns, err := pushColumnsUp(node, pushUp) + if err != nil { + return nil, false, err + } + + newSchema := node.Schema() + requiresProjection := len(newSchema) != len(having.Schema()) + having = node.(*plan.Having) + + // Now, the tokens are replaced with the actual columns, now that we know + // what the indexes are. 
+ cond, err = expression.TransformUp(having.Cond, func(e sql.Expression) (sql.Expression, error) { + f, ok := e.(*expression.GetField) + if !ok { + return e, nil + } + + idx, ok := tokenToIdx[f.Index()] + if !ok { + return e, nil + } + + idx = pushedUpColumns[idx] + col := newSchema[idx] + return expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), nil + }) + if err != nil { + return nil, false, err + } + + return plan.NewHaving(cond, having.Child), requiresProjection, nil +} + +func aggregationEquals(a, b sql.Expression) bool { + // First unwrap aliases + if alias, ok := b.(*expression.Alias); ok { + b = alias.Child + } else if alias, ok := a.(*expression.Alias); ok { + a = alias.Child + } + + switch a := a.(type) { + case *aggregation.Count: + // it doesn't matter what's inside a Count, the result will be + // the same. + _, ok := b.(*aggregation.Count) + return ok + case *aggregation.CountDistinct: + // it doesn't matter what's inside a Count, the result will be + // the same. 
+ _, ok := b.(*aggregation.CountDistinct) + return ok + case *aggregation.Sum: + b, ok := b.(*aggregation.Sum) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + case *aggregation.Avg: + b, ok := b.(*aggregation.Avg) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + case *aggregation.Min: + b, ok := b.(*aggregation.Min) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + case *aggregation.Max: + b, ok := b.(*aggregation.Max) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + case *aggregation.First: + b, ok := b.(*aggregation.First) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + case *aggregation.Last: + b, ok := b.(*aggregation.Last) + if !ok { + return false + } + + return aggregationChildEquals(a.Child, b.Child) + default: + return false + } +} + +// aggregationChildEquals checks if expression a coming from the having +// matches expression b coming from the group by. To do that, columns in +// a need to be replaced to match the ones in b if their name or table and +// name match. 
+func aggregationChildEquals(a, b sql.Expression) bool { + var fieldsByName = make(map[string]sql.Expression) + var fieldsByTableCol = make(map[tableCol]sql.Expression) + expression.Inspect(b, func(e sql.Expression) bool { + gf, ok := e.(*expression.GetField) + if ok { + fieldsByTableCol[tableCol{ + strings.ToLower(gf.Table()), + strings.ToLower(gf.Name()), + }] = e + fieldsByName[strings.ToLower(gf.Name())] = e + } + return true + }) + + a, err := expression.TransformUp(a, func(e sql.Expression) (sql.Expression, error) { + var table, name string + switch e := e.(type) { + case *expression.UnresolvedColumn: + table = strings.ToLower(e.Table()) + name = strings.ToLower(e.Name()) + case *expression.GetField: + table = strings.ToLower(e.Table()) + name = strings.ToLower(e.Name()) + } + + if table == "" { + f, ok := fieldsByName[name] + if !ok { + return e, nil + } + return f, nil + } + + f, ok := fieldsByTableCol[tableCol{table, name}] + if !ok { + return e, nil + } + return f, nil + }) + if err != nil { + return false + } + + return reflect.DeepEqual(a, b) +} + +var errHavingNeedsGroupBy = errors.NewKind("found HAVING clause with no GROUP BY") + +func hasAggregations(expr sql.Expression) bool { + var has bool + expression.Inspect(expr, func(e sql.Expression) bool { + _, ok := e.(sql.Aggregation) + if ok { + has = true + return false + } + return true + }) + return has +} diff --git a/sql/analyzer/resolve_having_test.go b/sql/analyzer/resolve_having_test.go new file mode 100644 index 000000000..cc3707671 --- /dev/null +++ b/sql/analyzer/resolve_having_test.go @@ -0,0 +1,311 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +func 
TestResolveHaving(t *testing.T) { + testCases := []struct { + name string + input sql.Node + expected sql.Node + err *errors.Kind + }{ + { + "replace existing aggregation in group by", + plan.NewHaving( + expression.NewGreaterThan( + aggregation.NewAvg(expression.NewUnresolvedColumn("foo")), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false)), "x"), + expression.NewGetField(0, sql.Int64, "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + plan.NewHaving( + expression.NewGreaterThan( + expression.NewGetField(0, sql.Float64, "x", true), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false)), "x"), + expression.NewGetField(0, sql.Int64, "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + nil, + }, + { + "push down aggregation to group by", + plan.NewHaving( + expression.NewGreaterThan( + aggregation.NewCount(expression.NewStar()), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetField(0, sql.Int64, "foo", false)), "x"), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Float64, "x", true), + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + }, + plan.NewHaving( + expression.NewGreaterThan( + expression.NewGetField(2, sql.Int64, "COUNT(*)", 
false), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetField(0, sql.Int64, "foo", false)), "x"), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + aggregation.NewCount(expression.NewStar()), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + ), + nil, + }, + { + "push up missing column", + plan.NewHaving( + expression.NewGreaterThan( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + }, + []sql.Expression{expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", sql.Schema{ + {Type: sql.Int64, Name: "i", Source: "t"}, + {Type: sql.Int64, Name: "i", Source: "foo"}, + })), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + }, + plan.NewHaving( + expression.NewGreaterThan( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "i", false), + }, + []sql.Expression{expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", sql.Schema{ + {Type: sql.Int64, Name: "i", Source: "t"}, + {Type: sql.Int64, Name: "i", Source: "foo"}, + })), + ), + ), + ), + nil, + }, + { + "push up missing column with nodes in between", + plan.NewHaving( + expression.NewGreaterThan( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + }, 
+ plan.NewGroupBy( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + }, + []sql.Expression{expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", sql.Schema{ + {Type: sql.Int64, Name: "i", Source: "t"}, + {Type: sql.Int64, Name: "i", Source: "foo"}, + })), + ), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + }, + plan.NewHaving( + expression.NewGreaterThan( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + expression.NewGetFieldWithTable(1, sql.Int64, "t", "i", false), + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "i", false), + }, + []sql.Expression{expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", sql.Schema{ + {Type: sql.Int64, Name: "i", Source: "t"}, + {Type: sql.Int64, Name: "i", Source: "foo"}, + })), + ), + ), + ), + ), + nil, + }, + { + "push down aggregations with nodes in between", + plan.NewHaving( + expression.NewGreaterThan( + aggregation.NewCount(expression.NewStar()), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewAlias(expression.NewGetField(0, sql.Float64, "avg(foo)", false), "x"), + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + }, + plan.NewGroupBy( + []sql.Expression{ + aggregation.NewAvg(expression.NewGetField(0, sql.Int64, "foo", false)), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + ), + plan.NewProject( + 
[]sql.Expression{ + expression.NewGetField(0, sql.Float64, "x", false), + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + }, + plan.NewHaving( + expression.NewGreaterThan( + expression.NewGetField(2, sql.Int64, "COUNT(*)", false), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewAlias(expression.NewGetField(0, sql.Float64, "avg(foo)", false), "x"), + expression.NewGetFieldWithTable(1, sql.Int64, "t", "foo", false), + expression.NewGetField(2, sql.Int64, "COUNT(*)", false), + }, + plan.NewGroupBy( + []sql.Expression{ + aggregation.NewAvg(expression.NewGetField(0, sql.Int64, "foo", false)), + expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false), + aggregation.NewCount(expression.NewStar()), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + ), + ), + nil, + }, + { + "replace existing aggregation in group by with nodes in between", + plan.NewHaving( + expression.NewGreaterThan( + aggregation.NewAvg(expression.NewUnresolvedColumn("foo")), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Float64, "x", false), + expression.NewGetField(1, sql.Int64, "foo", false), + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false)), "x"), + expression.NewGetField(0, sql.Int64, "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + ), + plan.NewHaving( + expression.NewGreaterThan( + expression.NewGetField(0, sql.Float64, "x", false), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Float64, "x", false), + expression.NewGetField(1, sql.Int64, "foo", false), + }, + 
plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias(aggregation.NewAvg(expression.NewGetFieldWithTable(0, sql.Int64, "t", "foo", false)), "x"), + expression.NewGetField(0, sql.Int64, "foo", false), + }, + []sql.Expression{expression.NewGetField(0, sql.Int64, "foo", false)}, + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + ), + ), + nil, + }, + { + "missing groupby", + plan.NewHaving( + expression.NewGreaterThan( + aggregation.NewCount(expression.NewStar()), + expression.NewLiteral(int64(5), sql.Int64), + ), + plan.NewResolvedTable(memory.NewTable("t", nil)), + ), + nil, + errHavingNeedsGroupBy, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := resolveHaving(sql.NewEmptyContext(), nil, tt.input) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + }) + } +} diff --git a/sql/analyzer/resolve_natural_joins.go b/sql/analyzer/resolve_natural_joins.go new file mode 100644 index 000000000..a6cf1fdb9 --- /dev/null +++ b/sql/analyzer/resolve_natural_joins.go @@ -0,0 +1,137 @@ +package analyzer + +import ( + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_natural_joins") + defer span.Finish() + + var replacements = make(map[tableCol]tableCol) + var tableAliases = make(map[string]string) + + return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) { + switch n := node.(type) { + case *plan.TableAlias: + alias := n.Name() + table := n.Child.(*plan.ResolvedTable).Name() + tableAliases[strings.ToLower(alias)] = table + return n, nil + case *plan.NaturalJoin: + return resolveNaturalJoin(n, replacements) + case sql.Expressioner: + return replaceExpressions(node, 
replacements, tableAliases) + default: + return n, nil + } + }) +} + +func resolveNaturalJoin( + n *plan.NaturalJoin, + replacements map[tableCol]tableCol, +) (sql.Node, error) { + // Both sides of the natural join need to be resolved in order to resolve + // the natural join itself. + if !n.Left.Resolved() || !n.Right.Resolved() { + return n, nil + } + + leftSchema := n.Left.Schema() + rightSchema := n.Right.Schema() + + var conditions, common, left, right []sql.Expression + for i, lcol := range leftSchema { + leftCol := expression.NewGetFieldWithTable( + i, + lcol.Type, + lcol.Source, + lcol.Name, + lcol.Nullable, + ) + if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil { + common = append(common, leftCol) + replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{ + strings.ToLower(lcol.Source), strings.ToLower(lcol.Name), + } + + conditions = append( + conditions, + expression.NewEquals( + leftCol, + expression.NewGetFieldWithTable( + len(leftSchema)+idx, + rcol.Type, + rcol.Source, + rcol.Name, + rcol.Nullable, + ), + ), + ) + } else { + left = append(left, leftCol) + } + } + + if len(conditions) == 0 { + return plan.NewCrossJoin(n.Left, n.Right), nil + } + + for i, col := range rightSchema { + source := strings.ToLower(col.Source) + name := strings.ToLower(col.Name) + if _, ok := replacements[tableCol{source, name}]; !ok { + right = append( + right, + expression.NewGetFieldWithTable( + len(leftSchema)+i, + col.Type, + col.Source, + col.Name, + col.Nullable, + ), + ) + } + } + + return plan.NewProject( + append(append(common, left...), right...), + plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)), + ), nil +} + +func findCol(s sql.Schema, name string) (int, *sql.Column) { + for i, c := range s { + if strings.ToLower(c.Name) == strings.ToLower(name) { + return i, c + } + } + return -1, nil +} + +func replaceExpressions( + n sql.Node, + replacements map[tableCol]tableCol, + tableAliases 
map[string]string, +) (sql.Node, error) { + return plan.TransformExpressions(n, func(e sql.Expression) (sql.Expression, error) { + switch e := e.(type) { + case *expression.GetField, *expression.UnresolvedColumn: + var tableName = e.(sql.Tableable).Table() + if t, ok := tableAliases[strings.ToLower(tableName)]; ok { + tableName = t + } + + name := e.(sql.Nameable).Name() + if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok { + return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil + } + } + return e, nil + }) +} diff --git a/sql/analyzer/resolve_natural_joins_test.go b/sql/analyzer/resolve_natural_joins_test.go new file mode 100644 index 000000000..4030f8f93 --- /dev/null +++ b/sql/analyzer/resolve_natural_joins_test.go @@ -0,0 +1,378 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestResolveNaturalJoins(t *testing.T) { + require := require.New(t) + + left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "d", Type: sql.Int64, Source: "t2"}, + {Name: "c", Type: sql.Int64, Source: "t2"}, + {Name: "b", Type: sql.Int64, Source: "t2"}, + {Name: "e", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + ) + rule := getRule("resolve_natural_joins") + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", 
false), + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t2", "d", false), + expression.NewGetFieldWithTable(6, sql.Int64, "t2", "e", false), + }, + plan.NewInnerJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "c", false), + ), + ), + ), + ) + + require.Equal(expected, result) +} + +func TestResolveNaturalJoinsColumns(t *testing.T) { + rule := getRule("resolve_natural_joins") + require := require.New(t) + + left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "d", Type: sql.Int64, Source: "t2"}, + {Name: "c", Type: sql.Int64, Source: "t2"}, + {Name: "b", Type: sql.Int64, Source: "t2"}, + {Name: "e", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t2", "b"), + }, + plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t1", "b"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t2", "d", false), + 
expression.NewGetFieldWithTable(6, sql.Int64, "t2", "e", false), + }, + plan.NewInnerJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "c", false), + ), + ), + ), + ), + ) + + require.Equal(expected, result) +} + +func TestResolveNaturalJoinsTableAlias(t *testing.T) { + rule := getRule("resolve_natural_joins") + require := require.New(t) + + left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "d", Type: sql.Int64, Source: "t2"}, + {Name: "c", Type: sql.Int64, Source: "t2"}, + {Name: "b", Type: sql.Int64, Source: "t2"}, + {Name: "e", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t2", "b"), + expression.NewUnresolvedQualifiedColumn("t2-alias", "c"), + }, + plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewTableAlias("t2-alias", plan.NewResolvedTable(right)), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t1", "b"), + expression.NewUnresolvedQualifiedColumn("t1", "c"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t2", "d", false), + 
expression.NewGetFieldWithTable(6, sql.Int64, "t2", "e", false), + }, + plan.NewInnerJoin( + plan.NewResolvedTable(left), + plan.NewTableAlias("t2-alias", plan.NewResolvedTable(right)), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "c", false), + ), + ), + ), + ), + ) + + require.Equal(expected, result) +} + +func TestResolveNaturalJoinsChained(t *testing.T) { + rule := getRule("resolve_natural_joins") + require := require.New(t) + + left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + {Name: "f", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "d", Type: sql.Int64, Source: "t2"}, + {Name: "c", Type: sql.Int64, Source: "t2"}, + {Name: "b", Type: sql.Int64, Source: "t2"}, + {Name: "e", Type: sql.Int64, Source: "t2"}, + }) + + upperRight := memory.NewTable("t3", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t3"}, + {Name: "b", Type: sql.Int64, Source: "t3"}, + {Name: "f", Type: sql.Int64, Source: "t3"}, + {Name: "g", Type: sql.Int64, Source: "t3"}, + }) + + node := plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t2", "b"), + expression.NewUnresolvedQualifiedColumn("t2-alias", "c"), + expression.NewUnresolvedQualifiedColumn("t3-alias", "f"), + }, + plan.NewNaturalJoin( + plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewTableAlias("t2-alias", plan.NewResolvedTable(right)), + ), + plan.NewTableAlias("t3-alias", plan.NewResolvedTable(upperRight)), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := 
plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedQualifiedColumn("t1", "b"), + expression.NewUnresolvedQualifiedColumn("t1", "c"), + expression.NewUnresolvedQualifiedColumn("t1", "f"), + }, + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t1", "f", false), + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "d", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "e", false), + expression.NewGetFieldWithTable(9, sql.Int64, "t3", "g", false), + }, + plan.NewInnerJoin( + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t1", "f", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "d", false), + expression.NewGetFieldWithTable(7, sql.Int64, "t2", "e", false), + }, + plan.NewInnerJoin( + plan.NewResolvedTable(left), + plan.NewTableAlias("t2-alias", plan.NewResolvedTable(right)), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(6, sql.Int64, "t2", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "c", false), + ), + ), + ), + ), + plan.NewTableAlias("t3-alias", plan.NewResolvedTable(upperRight)), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(7, sql.Int64, "t3", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "a", 
false), + expression.NewGetFieldWithTable(6, sql.Int64, "t3", "a", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(3, sql.Int64, "t1", "f", false), + expression.NewGetFieldWithTable(8, sql.Int64, "t3", "f", false), + ), + ), + ), + ), + ) + + require.Equal(expected, result) +} + +func TestResolveNaturalJoinsEqual(t *testing.T) { + require := require.New(t) + + left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t2"}, + {Name: "b", Type: sql.Int64, Source: "t2"}, + {Name: "c", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + ) + rule := getRule("resolve_natural_joins") + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + }, + plan.NewInnerJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + expression.JoinAnd( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "a", false), + expression.NewGetFieldWithTable(3, sql.Int64, "t2", "a", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(1, sql.Int64, "t1", "b", false), + expression.NewGetFieldWithTable(4, sql.Int64, "t2", "b", false), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(2, sql.Int64, "t1", "c", false), + expression.NewGetFieldWithTable(5, sql.Int64, "t2", "c", false), + ), + ), + ), + ) + + require.Equal(expected, result) +} + +func TestResolveNaturalJoinsDisjoint(t *testing.T) { + require := require.New(t) + + 
left := memory.NewTable("t1", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t1"}, + {Name: "b", Type: sql.Int64, Source: "t1"}, + {Name: "c", Type: sql.Int64, Source: "t1"}, + }) + + right := memory.NewTable("t2", sql.Schema{ + {Name: "d", Type: sql.Int64, Source: "t2"}, + {Name: "e", Type: sql.Int64, Source: "t2"}, + }) + + node := plan.NewNaturalJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + ) + rule := getRule("resolve_natural_joins") + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected := plan.NewCrossJoin( + plan.NewResolvedTable(left), + plan.NewResolvedTable(right), + ) + require.Equal(expected, result) +} diff --git a/sql/analyzer/resolve_orderby.go b/sql/analyzer/resolve_orderby.go new file mode 100644 index 000000000..b7e7dfc65 --- /dev/null +++ b/sql/analyzer/resolve_orderby.go @@ -0,0 +1,266 @@ +package analyzer + +import ( + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" +) + +func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_orderby") + defer span.Finish() + + a.Log("resolving order bys, node of type: %T", n) + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + sort, ok := n.(*plan.Sort) + if !ok { + return n, nil + } + + if !sort.Child.Resolved() { + a.Log("child of type %T is not resolved yet, skipping", sort.Child) + return n, nil + } + + childNewCols := columnsDefinedInNode(sort.Child) + var schemaCols []string + for _, col := range sort.Child.Schema() { + schemaCols = append(schemaCols, strings.ToLower(col.Name)) + } + + var colsFromChild []string + var missingCols []string + for _, f := range sort.SortFields { + ns := findExprNameables(f.Column) + + for _, n := range ns { + name := 
strings.ToLower(n.Name()) + if stringContains(childNewCols, name) { + colsFromChild = append(colsFromChild, n.Name()) + } else if !stringContains(schemaCols, name) { + missingCols = append(missingCols, n.Name()) + } + } + } + + // If all the columns required by the order by are available, do nothing about it. + if len(missingCols) == 0 { + a.Log("no missing columns, skipping") + return n, nil + } + + // If there are no columns required by the order by available, then move the order by + // below its child. + if len(colsFromChild) == 0 && len(missingCols) > 0 { + a.Log("pushing down sort, missing columns: %s", strings.Join(missingCols, ", ")) + return pushSortDown(sort) + } + + a.Log("fixing sort dependencies, missing columns: %s", strings.Join(missingCols, ", ")) + + // If there are some columns required by the order by on the child but some are missing + // we have to do some more complex logic and split the projection in two. + return fixSortDependencies(sort, missingCols) + }) +} + +// fixSortDependencies replaces the sort node by a node with the child projection +// followed by the sort, an intermediate projection or group by with all the missing +// columns required for the sort and then the child of the child projection or group by. +func fixSortDependencies(sort *plan.Sort, missingCols []string) (sql.Node, error) { + var expressions []sql.Expression + switch child := sort.Child.(type) { + case *plan.Project: + expressions = child.Projections + case *plan.GroupBy: + expressions = child.Aggregate + default: + return nil, errSortPushdown.New(child) + } + + var newExpressions = append([]sql.Expression{}, expressions...) 
+ for _, col := range missingCols { + newExpressions = append(newExpressions, expression.NewUnresolvedColumn(col)) + } + + for i, e := range expressions { + var name string + if n, ok := e.(sql.Nameable); ok { + name = n.Name() + } else { + name = e.String() + } + + var table string + if t, ok := e.(sql.Tableable); ok { + table = t.Table() + } + expressions[i] = expression.NewGetFieldWithTable( + i, e.Type(), table, name, e.IsNullable(), + ) + } + + switch child := sort.Child.(type) { + case *plan.Project: + return plan.NewProject( + expressions, + plan.NewSort( + sort.SortFields, + plan.NewProject(newExpressions, child.Child), + ), + ), nil + case *plan.GroupBy: + return plan.NewProject( + expressions, + plan.NewSort( + sort.SortFields, + plan.NewGroupBy(newExpressions, child.Grouping, child.Child), + ), + ), nil + default: + return nil, errSortPushdown.New(child) + } +} + +// columnsDefinedInNode returns the columns that were defined in this node, +// which, by definition, can only be plan.Project or plan.GroupBy. 
+func columnsDefinedInNode(n sql.Node) []string { + var exprs []sql.Expression + switch n := n.(type) { + case *plan.Project: + exprs = n.Projections + case *plan.GroupBy: + exprs = n.Aggregate + } + + var cols []string + for _, e := range exprs { + alias, ok := e.(*expression.Alias) + if ok { + cols = append(cols, strings.ToLower(alias.Name())) + } + } + + return cols +} + +var errSortPushdown = errors.NewKind("unable to push plan.Sort node below %T") + +func pushSortDown(sort *plan.Sort) (sql.Node, error) { + switch child := sort.Child.(type) { + case *plan.Project: + return plan.NewProject( + child.Projections, + plan.NewSort(sort.SortFields, child.Child), + ), nil + case *plan.GroupBy: + return plan.NewGroupBy( + child.Aggregate, + child.Grouping, + plan.NewSort(sort.SortFields, child.Child), + ), nil + case *plan.ResolvedTable: + return sort, nil + default: + children := child.Children() + if len(children) == 1 { + newChild, err := pushSortDown(plan.NewSort(sort.SortFields, children[0])) + if err != nil { + return nil, err + } + + return child.WithChildren(newChild) + } + + // If the child has more than one children we don't know to which side + // the sort must be pushed down. 
+ return nil, errSortPushdown.New(child) + } +} + +func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + a.Log("resolve order by literals") + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + sort, ok := n.(*plan.Sort) + if !ok { + return n, nil + } + + // wait for the child to be resolved + if !sort.Child.Resolved() { + return n, nil + } + + schema := sort.Child.Schema() + var ( + fields = make([]plan.SortField, len(sort.SortFields)) + schemaCols = make([]string, len(schema)) + ) + for i, col := range sort.Child.Schema() { + schemaCols[i] = col.Name + } + for i, f := range sort.SortFields { + if lit, ok := f.Column.(*expression.Literal); ok && sql.IsNumber(f.Column.Type()) { + // it is safe to eval literals with no context and/or row + v, err := lit.Eval(nil, nil) + if err != nil { + return nil, err + } + + v, err = sql.Int64.Convert(v) + if err != nil { + return nil, err + } + + // column access is 1-indexed + idx := int(v.(int64)) - 1 + if idx >= len(schema) || idx < 0 { + return nil, ErrOrderByColumnIndex.New(idx + 1) + } + + fields[i] = plan.SortField{ + Column: expression.NewUnresolvedColumn(schemaCols[idx]), + Order: f.Order, + NullOrdering: f.NullOrdering, + } + + a.Log("replaced order by column %d with %s", idx+1, schemaCols[idx]) + } else { + if agg, ok := f.Column.(sql.Aggregation); ok { + name := agg.String() + if nameable, ok := f.Column.(sql.Nameable); ok { + name = nameable.Name() + } + + fields[i] = plan.SortField{ + Column: expression.NewUnresolvedColumn(name), + Order: f.Order, + NullOrdering: f.NullOrdering, + } + } else { + fields[i] = f + } + } + } + + return plan.NewSort(fields, sort.Child), nil + }) +} + +func findExprNameables(e sql.Expression) []sql.Nameable { + var result []sql.Nameable + expression.Inspect(e, func(e sql.Expression) bool { + n, ok := e.(sql.Nameable) + if ok { + result = append(result, n) + return false + } + return true + }) + return result +} diff --git 
a/sql/analyzer/resolve_orderby_test.go b/sql/analyzer/resolve_orderby_test.go new file mode 100644 index 000000000..119c197a4 --- /dev/null +++ b/sql/analyzer/resolve_orderby_test.go @@ -0,0 +1,283 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestResolveOrderBy(t *testing.T) { + rule := getRule("resolve_orderby") + a := NewDefault(nil) + ctx := sql.NewEmptyContext() + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + {Name: "b", Type: sql.Int64, Source: "foo"}, + }) + + t.Run("with project", func(t *testing.T) { + require := require.New(t) + node := plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + plan.NewResolvedTable(table), + ), + ) + + result, err := rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(node, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + plan.NewResolvedTable(table), + ), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewResolvedTable(table), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: 
expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + plan.NewResolvedTable(table), + ), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "", "x", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + expression.NewUnresolvedColumn("a"), + }, + plan.NewResolvedTable(table), + ), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + }) + + t.Run("with group by", func(t *testing.T) { + require := require.New(t) + node := plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ), + ) + + result, err := rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(node, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ), + ) + + var expected sql.Node = plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, 
sql.Int64, "foo", "a", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewResolvedTable(table), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "", "x", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + expression.NewUnresolvedColumn("a"), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewResolvedTable(table), + ), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + }) +} + +func TestResolveOrderByLiterals(t *testing.T) { + require := require.New(t) + f := getRule("resolve_orderby_literals") + + table := memory.NewTable("t", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "t"}, + {Name: "b", Type: sql.Int64, Source: "t"}, + }) + + node := plan.NewSort( + []plan.SortField{ + {Column: expression.NewLiteral(int64(2), sql.Int64)}, + {Column: expression.NewLiteral(int64(1), sql.Int64)}, + }, + plan.NewResolvedTable(table), + ) + + result, err := f.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + require.Equal( + plan.NewSort( + 
[]plan.SortField{ + {Column: expression.NewUnresolvedColumn("b")}, + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewResolvedTable(table), + ), + result, + ) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewLiteral(int64(3), sql.Int64)}, + {Column: expression.NewLiteral(int64(1), sql.Int64)}, + }, + plan.NewResolvedTable(table), + ) + + _, err = f.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.Error(err) + require.True(ErrOrderByColumnIndex.Is(err)) +} diff --git a/sql/analyzer/resolve_stars.go b/sql/analyzer/resolve_stars.go new file mode 100644 index 000000000..2261752f3 --- /dev/null +++ b/sql/analyzer/resolve_stars.go @@ -0,0 +1,73 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_star") + defer span.Finish() + + a.Log("resolving star, node of type: %T", n) + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + if n.Resolved() { + return n, nil + } + + switch n := n.(type) { + case *plan.Project: + if !n.Child.Resolved() { + return n, nil + } + + expressions, err := expandStars(n.Projections, n.Child.Schema()) + if err != nil { + return nil, err + } + + return plan.NewProject(expressions, n.Child), nil + case *plan.GroupBy: + if !n.Child.Resolved() { + return n, nil + } + + aggregate, err := expandStars(n.Aggregate, n.Child.Schema()) + if err != nil { + return nil, err + } + + return plan.NewGroupBy(aggregate, n.Grouping, n.Child), nil + default: + return n, nil + } + }) +} + +func expandStars(exprs []sql.Expression, schema sql.Schema) ([]sql.Expression, error) { + var expressions []sql.Expression + for _, e := range exprs { + if s, ok := e.(*expression.Star); ok { + var exprs []sql.Expression + for i, col := range schema 
{ + if s.Table == "" || s.Table == col.Source { + exprs = append(exprs, expression.NewGetFieldWithTable( + i, col.Type, col.Source, col.Name, col.Nullable, + )) + } + } + + if len(exprs) == 0 && s.Table != "" { + return nil, sql.ErrTableNotFound.New(s.Table) + } + + expressions = append(expressions, exprs...) + } else { + expressions = append(expressions, e) + } + } + + return expressions, nil +} diff --git a/sql/analyzer/resolve_stars_test.go b/sql/analyzer/resolve_stars_test.go new file mode 100644 index 000000000..b6006271f --- /dev/null +++ b/sql/analyzer/resolve_stars_test.go @@ -0,0 +1,185 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestResolveStar(t *testing.T) { + f := getRule("resolve_star") + + table := memory.NewTable("mytable", sql.Schema{ + {Name: "a", Type: sql.Int32, Source: "mytable"}, + {Name: "b", Type: sql.Int32, Source: "mytable"}, + }) + + table2 := memory.NewTable("mytable2", sql.Schema{ + {Name: "c", Type: sql.Int32, Source: "mytable2"}, + {Name: "d", Type: sql.Int32, Source: "mytable2"}, + }) + + testCases := []struct { + name string + node sql.Node + expected sql.Node + }{ + { + "unqualified star", + plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewResolvedTable(table), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + }, + plan.NewResolvedTable(table), + ), + }, + { + "qualified star", + plan.NewProject( + []sql.Expression{expression.NewQualifiedStar("mytable2")}, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(2, sql.Int32, 
"mytable2", "c", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + }, + { + "qualified star and unqualified star", + plan.NewProject( + []sql.Expression{ + expression.NewStar(), + expression.NewQualifiedStar("mytable2"), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), + expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + }, + { + "stars mixed with other expressions", + plan.NewProject( + []sql.Expression{ + expression.NewStar(), + expression.NewUnresolvedColumn("foo"), + expression.NewQualifiedStar("mytable2"), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), + expression.NewUnresolvedColumn("foo"), + expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), + expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), + }, + plan.NewCrossJoin( + plan.NewResolvedTable(table), + plan.NewResolvedTable(table2), + ), + ), + }, + { + "star in groupby", + plan.NewGroupBy( + 
[]sql.Expression{ + expression.NewStar(), + }, + nil, + plan.NewResolvedTable(table), + ), + plan.NewGroupBy( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + }, + nil, + plan.NewResolvedTable(table), + ), + }, + { // note that this behaviour deviates from MySQL + "star after some expressions", + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewStar(), + }, + plan.NewResolvedTable(table), + ), + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + }, + plan.NewResolvedTable(table), + ), + }, + { // note that this behaviour deviates from MySQL + "unqualified star used multiple times", + plan.NewProject( + []sql.Expression{ + expression.NewStar(), + expression.NewStar(), + }, + plan.NewResolvedTable(table), + ), + plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), + expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), + }, + plan.NewResolvedTable(table), + ), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + result, err := f.Apply(sql.NewEmptyContext(), nil, tt.node) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go new file mode 100644 index 000000000..20b97df3f --- /dev/null +++ b/sql/analyzer/resolve_subqueries.go @@ -0,0 +1,49 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + 
"github.com/src-d/go-mysql-server/sql/plan" +) + +func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, ctx := ctx.Span("resolve_subqueries") + defer span.Finish() + + a.Log("resolving subqueries") + n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.SubqueryAlias: + a.Log("found subquery %q with child of type %T", n.Name(), n.Child) + child, err := a.Analyze(ctx, n.Child) + if err != nil { + return nil, err + } + + return plan.NewSubqueryAlias(n.Name(), child), nil + default: + return n, nil + } + }) + if err != nil { + return nil, err + } + + return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) { + s, ok := e.(*expression.Subquery) + if !ok || s.Resolved() { + return e, nil + } + + q, err := a.Analyze(ctx, s.Query) + if err != nil { + return nil, err + } + + if qp, ok := q.(*plan.QueryProcess); ok { + q = qp.Child + } + + return s.WithQuery(q), nil + }) +} diff --git a/sql/analyzer/resolve_subqueries_test.go b/sql/analyzer/resolve_subqueries_test.go new file mode 100644 index 000000000..04e97832a --- /dev/null +++ b/sql/analyzer/resolve_subqueries_test.go @@ -0,0 +1,90 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestResolveSubqueries(t *testing.T) { + require := require.New(t) + + table1 := memory.NewTable("foo", sql.Schema{{Name: "a", Type: sql.Int64, Source: "foo"}}) + table2 := memory.NewTable("bar", sql.Schema{ + {Name: "b", Type: sql.Int64, Source: "bar"}, + {Name: "k", Type: sql.Int64, Source: "bar"}, + }) + table3 := memory.NewTable("baz", sql.Schema{{Name: "c", Type: sql.Int64, Source: "baz"}}) + db := memory.NewDatabase("mydb") + db.AddTable("foo", table1) + db.AddTable("bar", table2) + 
db.AddTable("baz", table3) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + a := withoutProcessTracking(NewDefault(catalog)) + + // SELECT * FROM + // (SELECT a FROM foo) t1, + // (SELECT b FROM (SELECT b FROM bar) t2alias) t2, + // baz + node := plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewCrossJoin( + plan.NewCrossJoin( + plan.NewSubqueryAlias( + "t1", + plan.NewProject( + []sql.Expression{expression.NewUnresolvedColumn("a")}, + plan.NewUnresolvedTable("foo", ""), + ), + ), + plan.NewSubqueryAlias( + "t2", + plan.NewProject( + []sql.Expression{expression.NewUnresolvedColumn("b")}, + plan.NewSubqueryAlias( + "t2alias", + plan.NewProject( + []sql.Expression{expression.NewUnresolvedColumn("b")}, + plan.NewUnresolvedTable("bar", ""), + ), + ), + ), + ), + ), + plan.NewUnresolvedTable("baz", ""), + ), + ) + + subquery := plan.NewSubqueryAlias( + "t2alias", + plan.NewResolvedTable(table2.WithProjection([]string{"b"})), + ) + _ = subquery.Schema() + + expected := plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewCrossJoin( + plan.NewCrossJoin( + plan.NewSubqueryAlias( + "t1", + plan.NewResolvedTable(table1.WithProjection([]string{"a"})), + ), + plan.NewSubqueryAlias( + "t2", + subquery, + ), + ), + plan.NewUnresolvedTable("baz", ""), + ), + ) + + result, err := resolveSubqueries(sql.NewEmptyContext(), a, node) + require.NoError(err) + require.Equal(expected, result) +} diff --git a/sql/analyzer/resolve_tables.go b/sql/analyzer/resolve_tables.go new file mode 100644 index 000000000..ec495cd36 --- /dev/null +++ b/sql/analyzer/resolve_tables.go @@ -0,0 +1,55 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +const dualTableName = "dual" + +var dualTable = func() sql.Table { + t := memory.NewTable(dualTableName, sql.Schema{ + {Name: "dummy", Source: dualTableName, Type: sql.Text, Nullable: false}, + }) + 
_ = t.Insert(sql.NewEmptyContext(), sql.NewRow("x")) + return t +}() + +func resolveTables(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("resolve_tables") + defer span.Finish() + + a.Log("resolve table, node of type: %T", n) + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + if n.Resolved() { + return n, nil + } + + t, ok := n.(*plan.UnresolvedTable) + if !ok { + return n, nil + } + + name := t.Name() + db := t.Database + if db == "" { + db = a.Catalog.CurrentDatabase() + } + + rt, err := a.Catalog.Table(db, name) + if err != nil { + if sql.ErrTableNotFound.Is(err) && name == dualTableName { + rt = dualTable + name = dualTableName + } else { + return nil, err + } + } + + a.Log("table resolved: %q", t.Name()) + + return plan.NewResolvedTable(rt), nil + }) +} diff --git a/sql/analyzer/resolve_tables_test.go b/sql/analyzer/resolve_tables_test.go new file mode 100644 index 000000000..fbc4fcca0 --- /dev/null +++ b/sql/analyzer/resolve_tables_test.go @@ -0,0 +1,93 @@ +package analyzer + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestResolveTables(t *testing.T) { + require := require.New(t) + + f := getRule("resolve_tables") + + table := memory.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + a := NewBuilder(catalog).AddPostAnalyzeRule(f.Name, f.Apply).Build() + + var notAnalyzed sql.Node = plan.NewUnresolvedTable("mytable", "") + analyzed, err := f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + require.Equal(plan.NewResolvedTable(table), analyzed) + + notAnalyzed = plan.NewUnresolvedTable("MyTable", 
"") + analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + require.Equal(plan.NewResolvedTable(table), analyzed) + + notAnalyzed = plan.NewUnresolvedTable("nonexistant", "") + analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.Error(err) + require.Nil(analyzed) + + analyzed, err = f.Apply(sql.NewEmptyContext(), a, plan.NewResolvedTable(table)) + require.NoError(err) + require.Equal(plan.NewResolvedTable(table), analyzed) + + notAnalyzed = plan.NewUnresolvedTable("dual", "") + analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + require.Equal(plan.NewResolvedTable(dualTable), analyzed) +} + +func TestResolveTablesNested(t *testing.T) { + require := require.New(t) + + f := getRule("resolve_tables") + + table := memory.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) + table2 := memory.NewTable("my_other_table", sql.Schema{{Name: "i", Type: sql.Int32}}) + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + db2 := memory.NewDatabase("my_other_db") + db2.AddTable("my_other_table", table2) + catalog.AddDatabase(db2) + + a := NewBuilder(catalog).AddPostAnalyzeRule(f.Name, f.Apply).Build() + + notAnalyzed := plan.NewProject( + []sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, + plan.NewUnresolvedTable("mytable", ""), + ) + analyzed, err := f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + expected := plan.NewProject( + []sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, + plan.NewResolvedTable(table), + ) + require.Equal(expected, analyzed) + + notAnalyzed = plan.NewProject( + []sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, + plan.NewUnresolvedTable("my_other_table", "my_other_db"), + ) + analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + expected = plan.NewProject( + 
[]sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, + plan.NewResolvedTable(table2), + ) + require.Equal(expected, analyzed) +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 55e5df124..c2b8daf00 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -1,30 +1,54 @@ package analyzer import ( - "strings" - errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) // DefaultRules to apply when analyzing nodes. var DefaultRules = []Rule{ - {"resolve_subqueries", resolveSubqueries}, - {"resolve_tables", resolveTables}, + {"resolve_natural_joins", resolveNaturalJoins}, {"resolve_orderby_literals", resolveOrderByLiterals}, + {"resolve_orderby", resolveOrderBy}, + {"resolve_grouping_columns", resolveGroupingColumns}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, {"resolve_database", resolveDatabase}, {"resolve_star", resolveStar}, {"resolve_functions", resolveFunctions}, + {"resolve_having", resolveHaving}, + {"reorder_aggregations", reorderAggregations}, {"reorder_projection", reorderProjection}, - {"pushdown", pushdown}, + {"move_join_conds_to_filter", moveJoinConditionsToFilter}, + {"eval_filter", evalFilter}, {"optimize_distinct", optimizeDistinct}, +} + +// OnceBeforeDefault contains the rules to be applied just once before the +// DefaultRules. +var OnceBeforeDefault = []Rule{ + {"resolve_subqueries", resolveSubqueries}, + {"resolve_tables", resolveTables}, + {"check_aliases", checkAliases}, +} + +// OnceAfterDefault contains the rules to be applied just once after the +// DefaultRules. 
+var OnceAfterDefault = []Rule{ + {"resolve_generators", resolveGenerators}, + {"remove_unnecessary_converts", removeUnnecessaryConverts}, + {"assign_catalog", assignCatalog}, + {"prune_columns", pruneColumns}, + {"convert_dates", convertDates}, + {"pushdown", pushdown}, {"erase_projection", eraseProjection}, - {"index_catalog", indexCatalog}, +} + +// OnceAfterAll contains the rules to be applied just once after all other +// rules have been applied. +var OnceAfterAll = []Rule{ + {"track_process", trackProcess}, + {"parallelize", parallelize}, + {"clear_warnings", clearWarnings}, } var ( @@ -42,792 +66,7 @@ var ( // ErrOrderByColumnIndex is returned when in an order clause there is a // column that is unknown. ErrOrderByColumnIndex = errors.NewKind("unknown column %d in order by clause") + // ErrMisusedAlias is returned when an alias is defined and used in the same projection. + ErrMisusedAlias = errors.NewKind("column %q does not exist in scope, but there is an alias defined in" + + " this projection with that name. 
Aliases cannot be used in the same projection they're defined in") ) - -func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_subqueries") - defer span.Finish() - - a.Log("resolving subqueries") - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - switch n := n.(type) { - case *plan.SubqueryAlias: - a.Log("found subquery %q with child of type %T", n.Name(), n.Child) - child, err := a.Analyze(ctx, n.Child) - if err != nil { - return nil, err - } - return plan.NewSubqueryAlias(n.Name(), child), nil - default: - return n, nil - } - }) -} - -func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - a.Log("resolve order by literals") - - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - sort, ok := n.(*plan.Sort) - if !ok { - return n, nil - } - - var fields = make([]plan.SortField, len(sort.SortFields)) - for i, f := range sort.SortFields { - if lit, ok := f.Column.(*expression.Literal); ok && sql.IsNumber(f.Column.Type()) { - // it is safe to eval literals with no context and/or row - v, err := lit.Eval(nil, nil) - if err != nil { - return nil, err - } - - v, err = sql.Int64.Convert(v) - if err != nil { - return nil, err - } - - // column access is 1-indexed - idx := int(v.(int64)) - 1 - - schema := sort.Child.Schema() - if idx >= len(schema) || idx < 0 { - return nil, ErrOrderByColumnIndex.New(idx + 1) - } - - fields[i] = plan.SortField{ - Column: expression.NewUnresolvedColumn(schema[idx].Name), - Order: f.Order, - NullOrdering: f.NullOrdering, - } - - a.Log("replaced order by column %d with %s", idx+1, schema[idx].Name) - } else { - fields[i] = f - } - } - - return plan.NewSort(fields, sort.Child), nil - }) -} - -func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("qualify_columns") - defer span.Finish() - - a.Log("qualify columns") - tables := make(map[string]sql.Node) - tableAliases := 
make(map[string]string) - colIndex := make(map[string][]string) - - indexCols := func(table string, schema sql.Schema) { - for _, col := range schema { - colIndex[col.Name] = append(colIndex[col.Name], table) - } - } - - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", n) - switch n := n.(type) { - case *plan.TableAlias: - switch t := n.Child.(type) { - case sql.Table: - tableAliases[n.Name()] = t.Name() - default: - tables[n.Name()] = n.Child - indexCols(n.Name(), n.Schema()) - } - case sql.Table: - tables[n.Name()] = n - indexCols(n.Name(), n.Schema()) - } - - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { - a.Log("transforming expression of type: %T", e) - switch col := e.(type) { - case *expression.UnresolvedColumn: - col = expression.NewUnresolvedQualifiedColumn(col.Table(), col.Name()) - - if col.Table() == "" { - tables := dedupStrings(colIndex[col.Name()]) - switch len(tables) { - case 0: - // If there are no tables that have any column with the column - // name let's just return it as it is. 
This may be an alias, so - // we'll wait for the reorder of the - return col, nil - case 1: - col = expression.NewUnresolvedQualifiedColumn( - tables[0], - col.Name(), - ) - default: - return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(tables, ", ")) - } - } else { - if real, ok := tableAliases[col.Table()]; ok { - col = expression.NewUnresolvedQualifiedColumn( - real, - col.Name(), - ) - } - - if _, ok := tables[col.Table()]; !ok { - return nil, sql.ErrTableNotFound.New(col.Table()) - } - } - - a.Log("column %q was qualified with table %q", col.Name(), col.Table()) - return col, nil - case *expression.Star: - if col.Table != "" { - if real, ok := tableAliases[col.Table]; ok { - col = expression.NewQualifiedStar(real) - } - - if _, ok := tables[col.Table]; !ok { - return nil, sql.ErrTableNotFound.New(col.Table) - } - - return col, nil - } - } - return e, nil - }) - }) -} - -func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_database") - defer span.Finish() - - a.Log("resolve database, node of type: %T", n) - - // TODO Database should implement node, - // and ShowTables and CreateTable nodes should be binaryNodes - switch v := n.(type) { - case *plan.ShowTables: - db, err := a.Catalog.Database(a.CurrentDatabase) - if err != nil { - return n, err - } - - v.Database = db - case *plan.CreateTable: - db, err := a.Catalog.Database(a.CurrentDatabase) - if err != nil { - return n, err - } - - v.Database = db - } - - return n, nil -} - -var dualTable = func() sql.Table { - t := mem.NewTable("dual", sql.Schema{ - {Name: "dummy", Source: "dual", Type: sql.Text, Nullable: false}, - }) - _ = t.Insert(sql.NewRow("x")) - return t -}() - -func resolveTables(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_tables") - defer span.Finish() - - a.Log("resolve table, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - 
a.Log("transforming node of type: %T", n) - if n.Resolved() { - return n, nil - } - - t, ok := n.(*plan.UnresolvedTable) - if !ok { - return n, nil - } - - rt, err := a.Catalog.Table(a.CurrentDatabase, t.Name) - if err != nil { - if sql.ErrTableNotFound.Is(err) && t.Name == dualTable.Name() { - rt = dualTable - } else { - return nil, err - } - } - - a.Log("table resolved: %q", rt.Name()) - - return rt, nil - }) -} - -func resolveStar(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_star") - defer span.Finish() - - a.Log("resolving star, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", n) - if n.Resolved() { - return n, nil - } - - p, ok := n.(*plan.Project) - if !ok { - return n, nil - } - - var expressions []sql.Expression - schema := p.Child.Schema() - for _, e := range p.Projections { - if s, ok := e.(*expression.Star); ok { - var exprs []sql.Expression - for i, col := range schema { - if s.Table == "" || s.Table == col.Source { - exprs = append(exprs, expression.NewGetFieldWithTable( - i, col.Type, col.Source, col.Name, col.Nullable, - )) - } - } - - if len(exprs) == 0 && s.Table != "" { - return nil, sql.ErrTableNotFound.New(s.Table) - } - - a.Log("%s replaced with %d fields", e, len(exprs)) - expressions = append(expressions, exprs...) - } else { - expressions = append(expressions, e) - } - } - - return plan.NewProject(expressions, p.Child), nil - }) -} - -type columnInfo struct { - idx int - col *sql.Column -} - -// maybeAlias is a wrapper on UnresolvedColumn used only to defer the -// resolution of the column because it could be an alias and that -// phase of the analyzer has not run yet. -type maybeAlias struct { - *expression.UnresolvedColumn -} - -func (e maybeAlias) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - return fn(e) -} - -// column is the common interface that groups UnresolvedColumn and maybeAlias. 
-type column interface { - sql.Nameable - sql.Tableable - sql.Expression -} - -func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_columns") - defer span.Finish() - - a.Log("resolve columns, node of type: %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", n) - if n.Resolved() { - return n, nil - } - - colMap := make(map[string][]columnInfo) - idx := 0 - for _, child := range n.Children() { - if !child.Resolved() { - return n, nil - } - - for _, col := range child.Schema() { - colMap[col.Name] = append(colMap[col.Name], columnInfo{idx, col}) - idx++ - } - } - - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { - a.Log("transforming expression of type: %T", e) - if n.Resolved() { - return e, nil - } - - uc, ok := e.(column) - if !ok { - return e, nil - } - - columnsInfo, ok := colMap[uc.Name()] - if !ok { - if uc.Table() != "" { - return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) - } - - switch uc := uc.(type) { - case *expression.UnresolvedColumn: - return &maybeAlias{uc}, nil - default: - return nil, ErrColumnNotFound.New(uc.Name()) - } - } - - var ci columnInfo - var found bool - for _, c := range columnsInfo { - if c.col.Source == uc.Table() { - ci = c - found = true - break - } - } - - if !found { - if uc.Table() != "" { - return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) - } - - switch uc := uc.(type) { - case *expression.UnresolvedColumn: - return &maybeAlias{uc}, nil - default: - return nil, ErrColumnNotFound.New(uc.Name()) - } - } - - a.Log("column resolved to %q.%q", ci.col.Source, ci.col.Name) - - return expression.NewGetFieldWithTable( - ci.idx, - ci.col.Type, - ci.col.Source, - ci.col.Name, - ci.col.Nullable, - ), nil - }) - }) -} - -func resolveFunctions(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("resolve_functions") - defer 
span.Finish() - - a.Log("resolve functions, node of type %T", n) - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", n) - if n.Resolved() { - return n, nil - } - - return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { - a.Log("transforming expression of type: %T", e) - if e.Resolved() { - return e, nil - } - - uf, ok := e.(*expression.UnresolvedFunction) - if !ok { - return e, nil - } - - n := uf.Name() - f, err := a.Catalog.Function(n) - if err != nil { - return nil, err - } - - rf, err := f.Call(uf.Arguments...) - if err != nil { - return nil, err - } - - a.Log("resolved function %q", n) - - return rf, nil - }) - }) -} - -func optimizeDistinct(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("optimize_distinct") - defer span.Finish() - - a.Log("optimize distinct, node of type: %T", node) - if node, ok := node.(*plan.Distinct); ok { - var isSorted bool - _, _ = node.TransformUp(func(node sql.Node) (sql.Node, error) { - a.Log("checking for optimization in node of type: %T", node) - if _, ok := node.(*plan.Sort); ok { - isSorted = true - } - return node, nil - }) - - if isSorted { - a.Log("distinct optimized for ordered output") - return plan.NewOrderedDistinct(node.Child), nil - } - } - - return node, nil -} - -var errInvalidNodeType = errors.NewKind("reorder projection: invalid node of type: %T") - -func reorderProjection(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("reorder_projection") - defer span.Finish() - - if n.Resolved() { - return n, nil - } - - a.Log("reorder projection, node of type: %T", n) - - // Then we transform the projection - return n.TransformUp(func(node sql.Node) (sql.Node, error) { - project, ok := node.(*plan.Project) - if !ok { - return node, nil - } - - // We must find all columns that may need to be moved inside the - // projection. 
- //var movedColumns = make(map[string]sql.Expression) - var newColumns = make(map[string]sql.Expression) - for _, col := range project.Projections { - alias, ok := col.(*expression.Alias) - if ok { - newColumns[alias.Name()] = col - } - } - - // And add projection nodes where needed in the child tree. - var didNeedReorder bool - child, err := project.Child.TransformUp(func(node sql.Node) (sql.Node, error) { - var requiredColumns []string - switch node := node.(type) { - case *plan.Sort, *plan.Filter: - for _, expr := range node.(sql.Expressioner).Expressions() { - expression.Inspect(expr, func(e sql.Expression) bool { - if e != nil && e.Resolved() { - return true - } - - uc, ok := e.(column) - if ok && uc.Table() == "" { - if _, ok := newColumns[uc.Name()]; ok { - requiredColumns = append(requiredColumns, uc.Name()) - } - } - - return true - }) - } - default: - return node, nil - } - - didNeedReorder = true - - // Only add the required columns for that node in the projection. - child := node.Children()[0] - schema := child.Schema() - var projections = make([]sql.Expression, 0, len(schema)+len(requiredColumns)) - for i, col := range schema { - projections = append(projections, expression.NewGetFieldWithTable( - i, col.Type, col.Source, col.Name, col.Nullable, - )) - } - - for _, col := range requiredColumns { - projections = append(projections, newColumns[col]) - delete(newColumns, col) - } - - child = plan.NewProject(projections, child) - switch node := node.(type) { - case *plan.Filter: - return plan.NewFilter(node.Expression, child), nil - case *plan.Sort: - return plan.NewSort(node.SortFields, child), nil - default: - return nil, errInvalidNodeType.New(node) - } - }) - - if err != nil { - return nil, err - } - - if !didNeedReorder { - return project, nil - } - - child, err = resolveColumns(ctx, a, child) - if err != nil { - return nil, err - } - - childSchema := child.Schema() - // Finally, replace the columns we moved with GetFields since they - // have 
already been projected. - var projections = make([]sql.Expression, len(project.Projections)) - for i, p := range project.Projections { - if alias, ok := p.(*expression.Alias); ok { - var found bool - for idx, col := range childSchema { - if col.Name == alias.Name() { - projections[i] = expression.NewGetField( - idx, col.Type, col.Name, col.Nullable, - ) - found = true - break - } - } - - if !found { - projections[i] = p - } - } else { - projections[i] = p - } - } - - return plan.NewProject(projections, child), nil - }) -} - -func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("erase_projection") - defer span.Finish() - - if !node.Resolved() { - return node, nil - } - - a.Log("erase projection, node of type: %T", node) - - return node.TransformUp(func(node sql.Node) (sql.Node, error) { - project, ok := node.(*plan.Project) - if ok && project.Schema().Equals(project.Child.Schema()) { - a.Log("project erased") - return project.Child, nil - } - - return node, nil - }) -} - -func dedupStrings(in []string) []string { - var seen = make(map[string]struct{}) - var result []string - for _, s := range in { - if _, ok := seen[s]; !ok { - seen[s] = struct{}{} - result = append(result, s) - } - } - return result -} - -// indexCatalog sets the catalog in the CreateIndex nodes. 
-func indexCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - if !n.Resolved() { - return n, nil - } - - ci, ok := n.(*plan.CreateIndex) - if !ok { - return n, nil - } - - span, ctx := ctx.Span("index_catalog") - defer span.Finish() - - nc := *ci - ci.Catalog = a.Catalog - ci.CurrentDatabase = a.CurrentDatabase - - return &nc, nil -} - -func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { - span, ctx := ctx.Span("pushdown") - defer span.Finish() - - a.Log("pushdown, node of type: %T", n) - if !n.Resolved() { - return n, nil - } - - var fieldsByTable = make(map[string][]string) - var exprsByTable = make(map[string][]sql.Expression) - type tableField struct { - table string - field string - } - var tableFields = make(map[tableField]struct{}) - - a.Log("finding used columns in node") - - colSpan, _ := ctx.Span("find_pushdown_columns") - - // First step is to find all col exprs and group them by the table they mention. - // Even if they appear multiple times, only the first one will be used. - plan.InspectExpressions(n, func(e sql.Expression) bool { - if e, ok := e.(*expression.GetField); ok { - tf := tableField{e.Table(), e.Name()} - if _, ok := tableFields[tf]; !ok { - a.Log("found used column %s.%s", e.Table(), e.Name()) - tableFields[tf] = struct{}{} - fieldsByTable[e.Table()] = append(fieldsByTable[e.Table()], e.Name()) - exprsByTable[e.Table()] = append(exprsByTable[e.Table()], e) - } - } - return true - }) - - colSpan.Finish() - - a.Log("finding filters in node") - - filterSpan, _ := ctx.Span("find_pushdown_filters") - - // then find all filters, also by table. Note that filters that mention - // more than one table will not be passed to neither. 
- filters := make(filters) - plan.Inspect(n, func(node sql.Node) bool { - a.Log("inspecting node of type: %T", node) - switch node := node.(type) { - case *plan.Filter: - fs := exprToTableFilters(node.Expression) - a.Log("found filters for %d tables %s", len(fs), node.Expression) - filters.merge(fs) - } - return true - }) - - filterSpan.Finish() - - a.Log("transforming nodes with pushdown of filters and projections") - - // Now all nodes can be transformed. Since traversal of the tree is done - // from inner to outer the filters have to be processed first so they get - // to the tables. - var handledFilters []sql.Expression - return n.TransformUp(func(node sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", node) - switch node := node.(type) { - case *plan.Filter: - if len(handledFilters) == 0 { - a.Log("no handled filters, leaving filter untouched") - return node, nil - } - - unhandled := getUnhandledFilters( - splitExpression(node.Expression), - handledFilters, - ) - - if len(unhandled) == 0 { - a.Log("filter node has no unhandled filters, so it will be removed") - return node.Child, nil - } - - a.Log( - "%d handled filters removed from filter node, filter has now %d filters", - len(handledFilters), - len(unhandled), - ) - - return plan.NewFilter(expression.JoinAnd(unhandled...), node.Child), nil - case *plan.PushdownProjectionAndFiltersTable, *plan.PushdownProjectionTable: - // they also implement the interfaces for pushdown, so we better return - // or there will be a very nice infinite loop - return node, nil - case sql.PushdownProjectionAndFiltersTable: - cols := exprsByTable[node.Name()] - tableFilters := filters[node.Name()] - handled := node.HandledFilters(tableFilters) - handledFilters = append(handledFilters, handled...) 
- - a.Log( - "table %q transformed with pushdown of projection and filters, %d filters handled of %d", - node.Name(), - len(handled), - len(tableFilters), - ) - - schema := node.Schema() - cols, err := fixFieldIndexesOnExpressions(schema, cols...) - if err != nil { - return nil, err - } - - handled, err = fixFieldIndexesOnExpressions(schema, handled...) - if err != nil { - return nil, err - } - - return plan.NewPushdownProjectionAndFiltersTable( - cols, - handled, - node, - ), nil - case sql.PushdownProjectionTable: - cols := fieldsByTable[node.Name()] - a.Log("table %q transformed with pushdown of projection", node.Name()) - return plan.NewPushdownProjectionTable(cols, node), nil - } - return node, nil - }) -} - -// fixFieldIndexesOnExpressions executes fixFieldIndexes on a list of exprs. -func fixFieldIndexesOnExpressions(schema sql.Schema, expressions ...sql.Expression) ([]sql.Expression, error) { - var result = make([]sql.Expression, len(expressions)) - for i, e := range expressions { - var err error - result[i], err = fixFieldIndexes(schema, e) - if err != nil { - return nil, err - } - } - return result, nil -} - -// fixFieldIndexes transforms the given expression setting correct indexes -// for GetField expressions according to the schema of the row in the table -// and not the one where the filter came from. 
-func fixFieldIndexes(schema sql.Schema, exp sql.Expression) (sql.Expression, error) { - return exp.TransformUp(func(e sql.Expression) (sql.Expression, error) { - switch e := e.(type) { - case *expression.GetField: - // we need to rewrite the indexes for the table row - for i, col := range schema { - if e.Name() == col.Name { - return expression.NewGetFieldWithTable( - i, - e.Type(), - e.Table(), - e.Name(), - e.IsNullable(), - ), nil - } - } - - return nil, ErrFieldMissing.New(e.Name()) - } - - return e, nil - }) -} diff --git a/sql/analyzer/rules_test.go b/sql/analyzer/rules_test.go deleted file mode 100644 index d4748c9d3..000000000 --- a/sql/analyzer/rules_test.go +++ /dev/null @@ -1,762 +0,0 @@ -package analyzer - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" -) - -func TestResolveSubqueries(t *testing.T) { - require := require.New(t) - - table1 := mem.NewTable("foo", sql.Schema{{Name: "a", Type: sql.Int64, Source: "foo"}}) - table2 := mem.NewTable("bar", sql.Schema{ - {Name: "b", Type: sql.Int64, Source: "bar"}, - {Name: "k", Type: sql.Int64, Source: "bar"}, - }) - table3 := mem.NewTable("baz", sql.Schema{{Name: "c", Type: sql.Int64, Source: "baz"}}) - db := mem.NewDatabase("mydb") - db.AddTable("foo", table1) - db.AddTable("bar", table2) - db.AddTable("baz", table3) - - catalog := &sql.Catalog{Databases: []sql.Database{db}} - a := New(catalog) - a.CurrentDatabase = "mydb" - - // SELECT * FROM - // (SELECT a FROM foo) t1, - // (SELECT b FROM (SELECT b FROM bar) t2alias) t2, - // baz - node := plan.NewProject( - []sql.Expression{expression.NewStar()}, - plan.NewCrossJoin( - plan.NewCrossJoin( - plan.NewSubqueryAlias( - "t1", - plan.NewProject( - []sql.Expression{expression.NewUnresolvedColumn("a")}, - plan.NewUnresolvedTable("foo"), - ), - ), - 
plan.NewSubqueryAlias( - "t2", - plan.NewProject( - []sql.Expression{expression.NewUnresolvedColumn("b")}, - plan.NewSubqueryAlias( - "t2alias", - plan.NewProject( - []sql.Expression{expression.NewUnresolvedColumn("b")}, - plan.NewUnresolvedTable("bar"), - ), - ), - ), - ), - ), - plan.NewUnresolvedTable("baz"), - ), - ) - - subquery := plan.NewSubqueryAlias( - "t2alias", - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "bar", "b", false), - }, - table2, - ), - ) - _ = subquery.Schema() - - expected := plan.NewProject( - []sql.Expression{expression.NewStar()}, - plan.NewCrossJoin( - plan.NewCrossJoin( - plan.NewSubqueryAlias( - "t1", - table1, - ), - plan.NewSubqueryAlias( - "t2", - subquery, - ), - ), - plan.NewUnresolvedTable("baz"), - ), - ) - - result, err := resolveSubqueries(sql.NewEmptyContext(), a, node) - require.NoError(err) - - require.Equal(expected, result) -} - -func TestResolveTables(t *testing.T) { - require := require.New(t) - - f := getRule("resolve_tables") - - table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) - db := mem.NewDatabase("mydb") - db.AddTable("mytable", table) - - catalog := &sql.Catalog{Databases: []sql.Database{db}} - - a := New(catalog) - a.Rules = []Rule{f} - - a.CurrentDatabase = "mydb" - var notAnalyzed sql.Node = plan.NewUnresolvedTable("mytable") - analyzed, err := f.Apply(sql.NewEmptyContext(), a, notAnalyzed) - require.NoError(err) - require.Equal(table, analyzed) - - notAnalyzed = plan.NewUnresolvedTable("nonexistant") - analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) - require.Error(err) - require.Nil(analyzed) - - analyzed, err = f.Apply(sql.NewEmptyContext(), a, table) - require.NoError(err) - require.Equal(table, analyzed) - - notAnalyzed = plan.NewUnresolvedTable("dual") - analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) - require.NoError(err) - require.Equal(dualTable, analyzed) -} - -func TestResolveTablesNested(t *testing.T) 
{ - require := require.New(t) - - f := getRule("resolve_tables") - - table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) - db := mem.NewDatabase("mydb") - db.AddTable("mytable", table) - - catalog := &sql.Catalog{Databases: []sql.Database{db}} - - a := New(catalog) - a.Rules = []Rule{f} - a.CurrentDatabase = "mydb" - - notAnalyzed := plan.NewProject( - []sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, - plan.NewUnresolvedTable("mytable"), - ) - analyzed, err := f.Apply(sql.NewEmptyContext(), a, notAnalyzed) - require.NoError(err) - expected := plan.NewProject( - []sql.Expression{expression.NewGetField(0, sql.Int32, "i", true)}, - table, - ) - require.Equal(expected, analyzed) -} - -func TestResolveOrderByLiterals(t *testing.T) { - require := require.New(t) - f := getRule("resolve_orderby_literals") - - table := mem.NewTable("t", sql.Schema{ - {Name: "a", Type: sql.Int64, Source: "t"}, - {Name: "b", Type: sql.Int64, Source: "t"}, - }) - - node := plan.NewSort( - []plan.SortField{ - {Column: expression.NewLiteral(int64(2), sql.Int64)}, - {Column: expression.NewLiteral(int64(1), sql.Int64)}, - }, - table, - ) - - result, err := f.Apply(sql.NewEmptyContext(), New(nil), node) - require.NoError(err) - - require.Equal( - plan.NewSort( - []plan.SortField{ - {Column: expression.NewUnresolvedColumn("b")}, - {Column: expression.NewUnresolvedColumn("a")}, - }, - table, - ), - result, - ) - - node = plan.NewSort( - []plan.SortField{ - {Column: expression.NewLiteral(int64(3), sql.Int64)}, - {Column: expression.NewLiteral(int64(1), sql.Int64)}, - }, - table, - ) - - _, err = f.Apply(sql.NewEmptyContext(), New(nil), node) - require.Error(err) - require.True(ErrOrderByColumnIndex.Is(err)) -} - -func TestResolveStar(t *testing.T) { - f := getRule("resolve_star") - - table := mem.NewTable("mytable", sql.Schema{ - {Name: "a", Type: sql.Int32, Source: "mytable"}, - {Name: "b", Type: sql.Int32, Source: "mytable"}, - }) - - table2 := 
mem.NewTable("mytable2", sql.Schema{ - {Name: "c", Type: sql.Int32, Source: "mytable2"}, - {Name: "d", Type: sql.Int32, Source: "mytable2"}, - }) - - testCases := []struct { - name string - node sql.Node - expected sql.Node - }{ - { - "unqualified star", - plan.NewProject( - []sql.Expression{expression.NewStar()}, - table, - ), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), - expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), - }, - table, - ), - }, - { - "qualified star", - plan.NewProject( - []sql.Expression{expression.NewQualifiedStar("mytable2")}, - plan.NewCrossJoin(table, table2), - ), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), - }, - plan.NewCrossJoin(table, table2), - ), - }, - { - "qualified star and unqualified star", - plan.NewProject( - []sql.Expression{ - expression.NewStar(), - expression.NewQualifiedStar("mytable2"), - }, - plan.NewCrossJoin(table, table2), - ), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), - expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), - }, - plan.NewCrossJoin(table, table2), - ), - }, - { - "stars mixed with other expressions", - plan.NewProject( - []sql.Expression{ - expression.NewStar(), - expression.NewUnresolvedColumn("foo"), - expression.NewQualifiedStar("mytable2"), - }, - plan.NewCrossJoin(table, table2), - ), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false), - 
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false), - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), - expression.NewUnresolvedColumn("foo"), - expression.NewGetFieldWithTable(2, sql.Int32, "mytable2", "c", false), - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "d", false), - }, - plan.NewCrossJoin(table, table2), - ), - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - result, err := f.Apply(sql.NewEmptyContext(), nil, tt.node) - require.NoError(t, err) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestQualifyColumns(t *testing.T) { - require := require.New(t) - f := getRule("qualify_columns") - - table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) - table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32}}) - - node := plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("i"), - }, - table, - ) - - expected := plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("mytable", "i"), - }, - table, - ) - - result, err := f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(expected, result) - - node = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("mytable", "i"), - }, - table, - ) - - result, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(expected, result) - - node = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("a", "i"), - }, - plan.NewTableAlias("a", table), - ) - - expected = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("mytable", "i"), - }, - plan.NewTableAlias("a", table), - ) - - result, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(expected, result) - - node = plan.NewProject( - []sql.Expression{ - 
expression.NewUnresolvedColumn("z"), - }, - plan.NewTableAlias("a", table), - ) - - result, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(node, result) - - node = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("foo", "i"), - }, - plan.NewTableAlias("a", table), - ) - - result, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.Error(err) - require.True(sql.ErrTableNotFound.Is(err)) - - node = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("i"), - }, - plan.NewCrossJoin(table, table2), - ) - - _, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.Error(err) - require.True(ErrAmbiguousColumnName.Is(err)) - - subquery := plan.NewSubqueryAlias( - "b", - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - }, - table, - ), - ) - // preload schema - _ = subquery.Schema() - - node = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("a", "i"), - }, - plan.NewCrossJoin( - plan.NewTableAlias("a", table), - subquery, - ), - ) - - expected = plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("mytable", "i"), - }, - plan.NewCrossJoin( - plan.NewTableAlias("a", table), - subquery, - ), - ) - - result, err = f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(expected, result) -} - -func TestReorderProjection(t *testing.T) { - require := require.New(t) - f := getRule("reorder_projection") - - table := mem.NewTable("mytable", sql.Schema{{ - Name: "i", Source: "mytable", Type: sql.Int64, - }}) - - node := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), - expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), - }, - plan.NewSort( - []plan.SortField{ - {Column: 
expression.NewUnresolvedColumn("foo")}, - }, - plan.NewFilter( - expression.NewEquals( - expression.NewLiteral(1, sql.Int64), - expression.NewUnresolvedColumn("bar"), - ), - table, - ), - ), - ) - - expected := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewGetField(2, sql.Int64, "foo", false), - expression.NewGetField(1, sql.Int64, "bar", false), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewGetField(2, sql.Int64, "foo", false)}}, - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewGetField(1, sql.Int64, "bar", false), - expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), - }, - plan.NewFilter( - expression.NewEquals( - expression.NewLiteral(1, sql.Int64), - expression.NewGetField(1, sql.Int64, "bar", false), - ), - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), - }, - table, - ), - ), - ), - ), - ) - - result, err := f.Apply(sql.NewEmptyContext(), New(nil), node) - require.NoError(err) - - require.Equal(expected, result) -} - -func TestEraseProjection(t *testing.T) { - require := require.New(t) - f := getRule("erase_projection") - - table := mem.NewTable("mytable", sql.Schema{{ - Name: "i", Source: "mytable", Type: sql.Int64, - }}) - - expected := plan.NewSort( - []plan.SortField{{Column: expression.NewGetField(2, sql.Int64, "foo", false)}}, - plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewGetField(1, sql.Int64, "bar", false), - expression.NewAlias(expression.NewLiteral(1, sql.Int64), "foo"), - }, - plan.NewFilter( - expression.NewEquals( - expression.NewLiteral(1, sql.Int64), - expression.NewGetField(1, sql.Int64, "bar", false), - ), - plan.NewProject( - []sql.Expression{ - 
expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewAlias(expression.NewLiteral(2, sql.Int64), "bar"), - }, - table, - ), - ), - ), - ) - - node := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), - expression.NewGetField(1, sql.Int64, "bar", false), - expression.NewGetField(2, sql.Int64, "foo", false), - }, - expected, - ) - - result, err := f.Apply(sql.NewEmptyContext(), New(nil), node) - require.NoError(err) - - require.Equal(expected, result) - - result, err = f.Apply(sql.NewEmptyContext(), New(nil), expected) - require.NoError(err) - - require.Equal(expected, result) -} - -func TestOptimizeDistinct(t *testing.T) { - require := require.New(t) - notSorted := plan.NewDistinct(mem.NewTable("foo", nil)) - sorted := plan.NewDistinct(plan.NewSort(nil, mem.NewTable("foo", nil))) - - rule := getRule("optimize_distinct") - - analyzedNotSorted, err := rule.Apply(sql.NewEmptyContext(), nil, notSorted) - require.NoError(err) - - analyzedSorted, err := rule.Apply(sql.NewEmptyContext(), nil, sorted) - require.NoError(err) - - require.Equal(notSorted, analyzedNotSorted) - require.Equal(plan.NewOrderedDistinct(sorted.Child), analyzedSorted) -} - -func TestPushdownProjection(t *testing.T) { - require := require.New(t) - f := getRule("pushdown") - - table := &pushdownProjectionTable{mem.NewTable("mytable", sql.Schema{ - {Name: "i", Type: sql.Int32}, - {Name: "f", Type: sql.Float64}, - {Name: "t", Type: sql.Text}, - })} - - table2 := &pushdownProjectionTable{mem.NewTable("mytable2", sql.Schema{ - {Name: "i2", Type: sql.Int32}, - {Name: "f2", Type: sql.Float64}, - {Name: "t2", Type: sql.Text}, - })} - - node := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - }, - plan.NewFilter( - expression.NewAnd( - expression.NewEquals( - expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), - expression.NewLiteral(3.14, 
sql.Float64), - ), - expression.NewIsNull( - expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), - ), - ), - plan.NewCrossJoin(table, table2), - ), - ) - - expected := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - }, - plan.NewFilter( - expression.NewAnd( - expression.NewEquals( - expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), - expression.NewLiteral(3.14, sql.Float64), - ), - expression.NewIsNull( - expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), - ), - ), - plan.NewCrossJoin( - plan.NewPushdownProjectionTable([]string{"i", "f"}, table), - plan.NewPushdownProjectionTable([]string{"i2"}, table2), - ), - ), - ) - - result, err := f.Apply(sql.NewEmptyContext(), nil, node) - require.NoError(err) - require.Equal(expected, result) -} - -func TestPushdownProjectionAndFilters(t *testing.T) { - require := require.New(t) - a := New(nil) - - table := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable", sql.Schema{ - {Name: "i", Type: sql.Int32, Source: "mytable"}, - {Name: "f", Type: sql.Float64, Source: "mytable"}, - {Name: "t", Type: sql.Text, Source: "mytable"}, - })} - - table2 := &pushdownProjectionAndFiltersTable{mem.NewTable("mytable2", sql.Schema{ - {Name: "i2", Type: sql.Int32, Source: "mytable2"}, - {Name: "f2", Type: sql.Float64, Source: "mytable2"}, - {Name: "t2", Type: sql.Text, Source: "mytable2"}, - })} - - node := plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedQualifiedColumn("mytable", "i"), - }, - plan.NewFilter( - expression.NewAnd( - expression.NewAnd( - expression.NewEquals( - expression.NewUnresolvedQualifiedColumn("mytable", "f"), - expression.NewLiteral(3.14, sql.Float64), - ), - expression.NewGreaterThan( - expression.NewUnresolvedQualifiedColumn("mytable", "f"), - expression.NewLiteral(3., sql.Float64), - ), - ), - expression.NewIsNull( - expression.NewUnresolvedQualifiedColumn("mytable2", "i2"), 
- ), - ), - plan.NewCrossJoin(table, table2), - ), - ) - - expected := plan.NewProject( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - }, - plan.NewFilter( - expression.NewAnd( - expression.NewGreaterThan( - expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), - expression.NewLiteral(3., sql.Float64), - ), - expression.NewIsNull( - expression.NewGetFieldWithTable(3, sql.Int32, "mytable2", "i2", false), - ), - ), - plan.NewCrossJoin( - plan.NewPushdownProjectionAndFiltersTable( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), - expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), - }, - []sql.Expression{ - expression.NewEquals( - expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), - expression.NewLiteral(3.14, sql.Float64), - ), - }, - table, - ), - plan.NewPushdownProjectionAndFiltersTable( - []sql.Expression{ - expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), - }, - nil, - table2, - ), - ), - ), - ) - - result, err := a.Analyze(sql.NewEmptyContext(), node) - require.NoError(err) - require.Equal(expected, result) -} - -type pushdownProjectionTable struct { - sql.Table -} - -var _ sql.PushdownProjectionTable = (*pushdownProjectionTable)(nil) - -func (pushdownProjectionTable) WithProject(*sql.Context, []string) (sql.RowIter, error) { - panic("not implemented") -} - -func (t *pushdownProjectionTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(t) -} - -func (t *pushdownProjectionTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return t, nil -} - -type pushdownProjectionAndFiltersTable struct { - sql.Table -} - -var _ sql.PushdownProjectionAndFiltersTable = (*pushdownProjectionAndFiltersTable)(nil) - -func (pushdownProjectionAndFiltersTable) HandledFilters(filters []sql.Expression) []sql.Expression { - var handled []sql.Expression - for _, f := range 
filters { - if eq, ok := f.(*expression.Equals); ok { - handled = append(handled, eq) - } - } - return handled -} - -func (pushdownProjectionAndFiltersTable) WithProjectAndFilters(_ *sql.Context, cols, filters []sql.Expression) (sql.RowIter, error) { - panic("not implemented") -} - -func (t *pushdownProjectionAndFiltersTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(t) -} - -func (t *pushdownProjectionAndFiltersTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - return t, nil -} - -func getRule(name string) Rule { - for _, rule := range DefaultRules { - if rule.Name == name { - return rule - } - } - panic("missing rule") -} diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index c1815a94a..7d35ee6be 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -3,19 +3,24 @@ package analyzer import ( "strings" - errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" ) const ( - validateResolvedRule = "validate_resolved" - validateOrderByRule = "validate_order_by" - validateGroupByRule = "validate_group_by" - validateSchemaSourceRule = "validate_schema_source" - validateProjectTuplesRule = "validate_project_tuples" - validateIndexCreationRule = "validate_index_creation" + validateResolvedRule = "validate_resolved" + validateOrderByRule = "validate_order_by" + validateGroupByRule = "validate_group_by" + validateSchemaSourceRule = "validate_schema_source" + validateProjectTuplesRule = "validate_project_tuples" + validateIndexCreationRule = "validate_index_creation" + validateCaseResultTypesRule = 
"validate_case_result_types" + validateIntervalUsageRule = "validate_interval_usage" + validateExplodeUsageRule = "validate_explode_usage" + validateSubqueryColumnsRule = "validate_subquery_columns" ) var ( @@ -29,38 +34,65 @@ var ( ErrValidationGroupBy = errors.NewKind("GroupBy aggregate expression '%v' doesn't appear in the grouping columns") // ErrValidationSchemaSource is returned when there is any column source // that does not match the table name. - ErrValidationSchemaSource = errors.NewKind("all schema column sources don't match table name, expecting %q, but found: %s") + ErrValidationSchemaSource = errors.NewKind("one or more schema sources are empty") // ErrProjectTuple is returned when there is a tuple of more than 1 column // inside a projection. ErrProjectTuple = errors.NewKind("selected field %d should have 1 column, but has %d") // ErrUnknownIndexColumns is returned when there are columns in the expr // to index that are unknown in the table. ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s") + // ErrCaseResultType is returned when one or more of the types of the values in + // a case expression don't match. + ErrCaseResultType = errors.NewKind( + "expecting all case branches to return values of type %s, " + + "but found value %q of type %s on %s", + ) + // ErrIntervalInvalidUse is returned when an interval expression is not + // correctly used. + ErrIntervalInvalidUse = errors.NewKind( + "invalid use of an interval, which can only be used with DATE_ADD, " + + "DATE_SUB and +/- operators to subtract from or add to a date", + ) + // ErrExplodeInvalidUse is returned when an EXPLODE function is used + // outside a Project node. + ErrExplodeInvalidUse = errors.NewKind( + "using EXPLODE is not supported outside a Project node", + ) + + // ErrSubqueryColumns is returned when an expression subquery returns + // more than a single column. 
+ ErrSubqueryColumns = errors.NewKind( + "subquery expressions can only return a single column", + ) ) // DefaultValidationRules to apply while analyzing nodes. -var DefaultValidationRules = []ValidationRule{ +var DefaultValidationRules = []Rule{ {validateResolvedRule, validateIsResolved}, {validateOrderByRule, validateOrderBy}, {validateGroupByRule, validateGroupBy}, {validateSchemaSourceRule, validateSchemaSource}, {validateProjectTuplesRule, validateProjectTuples}, {validateIndexCreationRule, validateIndexCreation}, + {validateCaseResultTypesRule, validateCaseResultTypes}, + {validateIntervalUsageRule, validateIntervalUsage}, + {validateExplodeUsageRule, validateExplodeUsage}, + {validateSubqueryColumnsRule, validateSubqueryColumns}, } -func validateIsResolved(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_is_resolved") +func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_is_resolved") defer span.Finish() if !n.Resolved() { - return ErrValidationResolved.New(n) + return nil, ErrValidationResolved.New(n) } - return nil + return n, nil } -func validateOrderBy(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_order_by") +func validateOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_order_by") defer span.Finish() switch n := n.(type) { @@ -68,16 +100,16 @@ func validateOrderBy(ctx *sql.Context, n sql.Node) error { for _, field := range n.SortFields { switch field.Column.(type) { case sql.Aggregation: - return ErrValidationOrderBy.New() + return nil, ErrValidationOrderBy.New() } } } - return nil + return n, nil } -func validateGroupBy(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_order_by") +func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_group_by") defer span.Finish() switch n := n.(type) { @@ -85,7 +117,7 @@ 
func validateGroupBy(ctx *sql.Context, n sql.Node) error { // Allow the parser use the GroupBy node to eval the aggregation functions // for sql statementes that don't make use of the GROUP BY expression. if len(n.Grouping) == 0 { - return nil + return n, nil } var validAggs []string @@ -99,15 +131,15 @@ func validateGroupBy(ctx *sql.Context, n sql.Node) error { for _, expr := range n.Aggregate { if _, ok := expr.(sql.Aggregation); !ok { if !isValidAgg(validAggs, expr) { - return ErrValidationGroupBy.New(expr.String()) + return nil, ErrValidationGroupBy.New(expr.String()) } } } - return nil + return n, nil } - return nil + return n, nil } func isValidAgg(validAggs []string, expr sql.Expression) bool { @@ -121,29 +153,29 @@ func isValidAgg(validAggs []string, expr sql.Expression) bool { } } -func validateSchemaSource(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_schema_source") +func validateSchemaSource(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_schema_source") defer span.Finish() switch n := n.(type) { case *plan.TableAlias: // table aliases should not be validated - if child, ok := n.Child.(sql.Table); ok { - return validateSchema(child) + if child, ok := n.Child.(*plan.ResolvedTable); ok { + return n, validateSchema(child) } - case sql.Table: - return validateSchema(n) + case *plan.ResolvedTable: + return n, validateSchema(n) } - return nil + return n, nil } -func validateIndexCreation(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_index_creation") +func validateIndexCreation(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_index_creation") defer span.Finish() ci, ok := n.(*plan.CreateIndex) if !ok { - return nil + return n, nil } schema := ci.Table.Schema() @@ -154,7 +186,7 @@ func validateIndexCreation(ctx *sql.Context, n sql.Node) error { expression.Inspect(expr, func(e sql.Expression) bool { gf, ok := 
e.(*expression.GetField) if ok { - if gf.Table() != table || !schema.Contains(gf.Name()) { + if gf.Table() != table || !schema.Contains(gf.Name(), gf.Table()) { unknownColumns = append(unknownColumns, gf.Name()) } } @@ -163,41 +195,158 @@ func validateIndexCreation(ctx *sql.Context, n sql.Node) error { } if len(unknownColumns) > 0 { - return ErrUnknownIndexColumns.New(table, strings.Join(unknownColumns, ", ")) + return nil, ErrUnknownIndexColumns.New(table, strings.Join(unknownColumns, ", ")) } - return nil + return n, nil } -func validateSchema(t sql.Table) error { - name := t.Name() +func validateSchema(t *plan.ResolvedTable) error { for _, col := range t.Schema() { - if col.Source != name { - return ErrValidationSchemaSource.New(name, col.Source) + if col.Source == "" { + return ErrValidationSchemaSource.New() } } return nil } -func validateProjectTuples(ctx *sql.Context, n sql.Node) error { - span, ctx := ctx.Span("validate_project_tuples") - defer span.Finish() +func findProjectTuples(n sql.Node) (sql.Node, error) { + if n == nil { + return n, nil + } switch n := n.(type) { - case *plan.Project: - for i, e := range n.Projections { + case *plan.Project, *plan.GroupBy: + for i, e := range n.(sql.Expressioner).Expressions() { if sql.IsTuple(e.Type()) { - return ErrProjectTuple.New(i+1, sql.NumColumns(e.Type())) + return nil, ErrProjectTuple.New(i+1, sql.NumColumns(e.Type())) } } - case *plan.GroupBy: - for i, e := range n.Aggregate { - if sql.IsTuple(e.Type()) { - return ErrProjectTuple.New(i+1, sql.NumColumns(e.Type())) + default: + for _, ch := range n.Children() { + _, err := findProjectTuples(ch) + if err != nil { + return nil, err } } } - return nil + + return n, nil +} + +func validateProjectTuples(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("validate_project_tuples") + defer span.Finish() + return findProjectTuples(n) +} + +func validateCaseResultTypes(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + 
span, ctx := ctx.Span("validate_case_result_types") + defer span.Finish() + + var err error + plan.InspectExpressions(n, func(e sql.Expression) bool { + switch e := e.(type) { + case *expression.Case: + typ := e.Type() + for _, b := range e.Branches { + if b.Value.Type() != typ && b.Value.Type() != sql.Null { + err = ErrCaseResultType.New(typ, b.Value, b.Value.Type(), e) + return false + } + } + + if e.Else != nil { + if e.Else.Type() != typ && e.Else.Type() != sql.Null { + err = ErrCaseResultType.New(typ, e.Else, e.Else.Type(), e) + return false + } + } + + return false + default: + return true + } + }) + + if err != nil { + return nil, err + } + + return n, nil +} + +func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + var invalid bool + plan.InspectExpressions(n, func(e sql.Expression) bool { + // If it's already invalid just skip everything else. + if invalid { + return false + } + + switch e := e.(type) { + case *function.DateAdd, *function.DateSub: + return false + case *expression.Arithmetic: + if e.Op == "+" || e.Op == "-" { + return false + } + case *expression.Interval: + invalid = true + } + + return true + }) + + if invalid { + return nil, ErrIntervalInvalidUse.New() + } + + return n, nil +} + +func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + var invalid bool + plan.InspectExpressions(n, func(e sql.Expression) bool { + // If it's already invalid just skip everything else. + if invalid { + return false + } + + // All usage of Explode will be incorrect because the ones in projects + // would have already been converted to Generate, so we only have to + // look for those. 
+ if _, ok := e.(*function.Explode); ok { + invalid = true + } + + return true + }) + + if invalid { + return nil, ErrExplodeInvalidUse.New() + } + + return n, nil +} + +func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + valid := true + plan.InspectExpressions(n, func(e sql.Expression) bool { + s, ok := e.(*expression.Subquery) + if ok && len(s.Query.Schema()) != 1 { + valid = false + return false + } + + return true + }) + + if !valid { + return nil, ErrSubqueryColumns.New() + } + + return n, nil } func stringContains(strs []string, target string) bool { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 054e27e98..543fbc688 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -3,11 +3,12 @@ package analyzer import ( "testing" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" "github.com/stretchr/testify/require" ) @@ -17,10 +18,10 @@ func TestValidateResolved(t *testing.T) { vr := getValidationRule(validateResolvedRule) - err := vr.Apply(sql.NewEmptyContext(), dummyNode{true}) + _, err := vr.Apply(sql.NewEmptyContext(), nil, dummyNode{true}) require.NoError(err) - err = vr.Apply(sql.NewEmptyContext(), dummyNode{false}) + _, err = vr.Apply(sql.NewEmptyContext(), nil, dummyNode{false}) require.Error(err) } @@ -29,12 +30,12 @@ func TestValidateOrderBy(t *testing.T) { vr := getValidationRule(validateOrderByRule) - err := 
vr.Apply(sql.NewEmptyContext(), dummyNode{true}) + _, err := vr.Apply(sql.NewEmptyContext(), nil, dummyNode{true}) require.NoError(err) - err = vr.Apply(sql.NewEmptyContext(), dummyNode{false}) + _, err = vr.Apply(sql.NewEmptyContext(), nil, dummyNode{false}) require.NoError(err) - err = vr.Apply(sql.NewEmptyContext(), plan.NewSort( + _, err = vr.Apply(sql.NewEmptyContext(), nil, plan.NewSort( []plan.SortField{{Column: aggregation.NewCount(nil), Order: plan.Descending}}, nil, )) @@ -46,9 +47,9 @@ func TestValidateGroupBy(t *testing.T) { vr := getValidationRule(validateGroupByRule) - err := vr.Apply(sql.NewEmptyContext(), dummyNode{true}) + _, err := vr.Apply(sql.NewEmptyContext(), nil, dummyNode{true}) require.NoError(err) - err = vr.Apply(sql.NewEmptyContext(), dummyNode{false}) + _, err = vr.Apply(sql.NewEmptyContext(), nil, dummyNode{false}) require.NoError(err) childSchema := sql.Schema{ @@ -56,12 +57,19 @@ func TestValidateGroupBy(t *testing.T) { {Name: "col2", Type: sql.Int64}, } - child := mem.NewTable("test", childSchema) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_1", int64(2222))) - child.Insert(sql.NewRow("col1_2", int64(4444))) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_2", int64(4444))) + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_1", int64(2222)), + sql.NewRow("col1_2", int64(4444)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := plan.NewGroupBy( []sql.Expression{ @@ -72,10 +80,10 @@ func TestValidateGroupBy(t *testing.T) { []sql.Expression{ expression.NewGetField(0, sql.Text, "col1", true), }, - child, + plan.NewResolvedTable(child), ) - err = vr.Apply(sql.NewEmptyContext(), p) + _, err = vr.Apply(sql.NewEmptyContext(), nil, p) require.NoError(err) } @@ -84,9 +92,9 @@ 
func TestValidateGroupByErr(t *testing.T) { vr := getValidationRule(validateGroupByRule) - err := vr.Apply(sql.NewEmptyContext(), dummyNode{true}) + _, err := vr.Apply(sql.NewEmptyContext(), nil, dummyNode{true}) require.NoError(err) - err = vr.Apply(sql.NewEmptyContext(), dummyNode{false}) + _, err = vr.Apply(sql.NewEmptyContext(), nil, dummyNode{false}) require.NoError(err) childSchema := sql.Schema{ @@ -94,12 +102,19 @@ func TestValidateGroupByErr(t *testing.T) { {Name: "col2", Type: sql.Int64}, } - child := mem.NewTable("test", childSchema) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_1", int64(2222))) - child.Insert(sql.NewRow("col1_2", int64(4444))) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_2", int64(4444))) + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_1", int64(2222)), + sql.NewRow("col1_2", int64(4444)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := plan.NewGroupBy( []sql.Expression{ @@ -109,10 +124,10 @@ func TestValidateGroupByErr(t *testing.T) { []sql.Expression{ expression.NewGetField(0, sql.Text, "col1", true), }, - child, + plan.NewResolvedTable(child), ) - err = vr.Apply(sql.NewEmptyContext(), p) + _, err = vr.Apply(sql.NewEmptyContext(), nil, p) require.Error(err) } @@ -129,25 +144,37 @@ func TestValidateSchemaSource(t *testing.T) { }, { "table with valid schema", - mem.NewTable("mytable", sql.Schema{ - {Name: "foo", Source: "mytable"}, - {Name: "bar", Source: "mytable"}, - }), + plan.NewResolvedTable( + memory.NewTable( + "mytable", + sql.Schema{ + {Name: "foo", Source: "mytable"}, + {Name: "bar", Source: "mytable"}, + }, + ), + ), true, }, { "table with invalid schema", - mem.NewTable("mytable", sql.Schema{ - {Name: "foo", Source: "mytable"}, - {Name: "bar", Source: 
"something"}, - }), + plan.NewResolvedTable( + memory.NewTable( + "mytable", + sql.Schema{ + {Name: "foo", Source: ""}, + {Name: "bar", Source: "something"}, + }, + ), + ), false, }, { "table alias with table", - plan.NewTableAlias("foo", mem.NewTable("mytable", sql.Schema{ - {Name: "foo", Source: "mytable"}, - })), + plan.NewTableAlias("foo", plan.NewResolvedTable( + memory.NewTable("mytable", sql.Schema{ + {Name: "foo", Source: "mytable"}, + }), + )), true, }, { @@ -170,7 +197,7 @@ func TestValidateSchemaSource(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - err := rule.Apply(sql.NewEmptyContext(), tt.node) + _, err := rule.Apply(sql.NewEmptyContext(), nil, tt.node) if tt.ok { require.NoError(err) } else { @@ -208,11 +235,38 @@ func TestValidateProjectTuples(t *testing.T) { plan.NewProject([]sql.Expression{ expression.NewTuple( expression.NewLiteral(1, sql.Int64), - expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(2, sql.Int64), ), }, nil), false, }, + { + "distinct with a 2 elem tuple inside the project", + plan.NewDistinct( + plan.NewProject([]sql.Expression{ + expression.NewTuple( + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(2, sql.Int64), + ), + }, nil)), + false, + }, + { + "alias with a tuple", + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewTuple( + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(2, sql.Int64), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + false, + }, { "groupby with no tuple", plan.NewGroupBy([]sql.Expression{ @@ -245,7 +299,7 @@ func TestValidateProjectTuples(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - err := rule.Apply(sql.NewEmptyContext(), tt.node) + _, err := rule.Apply(sql.NewEmptyContext(), nil, tt.node) if tt.ok { require.NoError(err) } else { @@ -257,7 +311,7 @@ func TestValidateProjectTuples(t *testing.T) 
{ } func TestValidateIndexCreation(t *testing.T) { - table := mem.NewTable("foo", sql.Schema{ + table := memory.NewTable("foo", sql.Schema{ {Name: "a", Source: "foo"}, {Name: "b", Source: "foo"}, }) @@ -270,7 +324,7 @@ func TestValidateIndexCreation(t *testing.T) { { "columns from another table", plan.NewCreateIndex( - "idx", table, + "idx", plan.NewResolvedTable(table), []sql.Expression{expression.NewEquals( expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), expression.NewGetFieldWithTable(0, sql.Int64, "bar", "b", false), @@ -283,7 +337,7 @@ func TestValidateIndexCreation(t *testing.T) { { "columns that don't exist", plan.NewCreateIndex( - "idx", table, + "idx", plan.NewResolvedTable(table), []sql.Expression{expression.NewEquals( expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", false), @@ -296,7 +350,7 @@ func TestValidateIndexCreation(t *testing.T) { { "columns only from table", plan.NewCreateIndex( - "idx", table, + "idx", plan.NewResolvedTable(table), []sql.Expression{expression.NewEquals( expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), expression.NewGetFieldWithTable(0, sql.Int64, "foo", "b", false), @@ -312,7 +366,7 @@ func TestValidateIndexCreation(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - err := rule.Apply(sql.NewEmptyContext(), tt.node) + _, err := rule.Apply(sql.NewEmptyContext(), nil, tt.node) if tt.ok { require.NoError(err) } else { @@ -323,21 +377,344 @@ func TestValidateIndexCreation(t *testing.T) { } } -type dummyNode struct{ resolved bool } +func TestValidateCaseResultTypes(t *testing.T) { + rule := getValidationRule(validateCaseResultTypesRule) + + testCases := []struct { + name string + expr *expression.Case + ok bool + }{ + { + "one of the branches does not match", + expression.NewCase( + expression.NewGetField(0, sql.Int64, "foo", false), + []expression.CaseBranch{ + { + 
Cond: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewLiteral(int64(2), sql.Int64), + Value: expression.NewLiteral(int64(1), sql.Int64), + }, + }, + expression.NewLiteral("foo", sql.Text), + ), + false, + }, + { + "else does not match", + expression.NewCase( + expression.NewGetField(0, sql.Int64, "foo", false), + []expression.CaseBranch{ + { + Cond: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewLiteral(int64(2), sql.Int64), + Value: expression.NewLiteral("bar", sql.Text), + }, + }, + expression.NewLiteral(int64(1), sql.Int64), + ), + false, + }, + { + "all ok", + expression.NewCase( + expression.NewGetField(0, sql.Int64, "foo", false), + []expression.CaseBranch{ + { + Cond: expression.NewLiteral(int64(1), sql.Int64), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewLiteral(int64(2), sql.Int64), + Value: expression.NewLiteral("bar", sql.Text), + }, + }, + expression.NewLiteral("baz", sql.Text), + ), + true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + _, err := rule.Apply(sql.NewEmptyContext(), nil, plan.NewProject( + []sql.Expression{tt.expr}, + plan.NewResolvedTable(dualTable), + )) + + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrCaseResultType.Is(err)) + } + }) + } +} + +func mustFunc(e sql.Expression, err error) sql.Expression { + if err != nil { + panic(err) + } + return e +} + +func TestValidateIntervalUsage(t *testing.T) { + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "date add", + plan.NewProject( + []sql.Expression{ + mustFunc(function.NewDateAdd( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + )), + }, + plan.NewUnresolvedTable("dual", ""), + ), + 
true, + }, + { + "date sub", + plan.NewProject( + []sql.Expression{ + mustFunc(function.NewDateSub( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + )), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "+ op", + plan.NewProject( + []sql.Expression{ + expression.NewPlus( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "- op", + plan.NewProject( + []sql.Expression{ + expression.NewMinus( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + true, + }, + { + "invalid", + plan.NewProject( + []sql.Expression{ + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + false, + }, + { + "alias", + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + _, err := validateIntervalUsage(sql.NewEmptyContext(), nil, tt.node) + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrIntervalInvalidUse.Is(err)) + } + }) + } +} + +func TestValidateExplodeUsage(t *testing.T) { + testCases := []struct { + name string + node sql.Node + ok bool + }{ + { + "valid", + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + 
expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + true, + }, + { + "where", + plan.NewFilter( + function.NewArrayLength( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + plan.NewGenerate( + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + function.NewGenerate( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + ), + false, + }, + { + "group by", + plan.NewGenerate( + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + []sql.Expression{ + expression.NewAlias( + function.NewExplode( + expression.NewGetField(0, sql.Array(sql.Int64), "f", false), + ), + "foo", + ), + }, + plan.NewUnresolvedTable("dual", ""), + ), + expression.NewGetField(0, sql.Array(sql.Int64), "foo", false), + ), + false, + }, + } -func (n dummyNode) String() string { return "dummynode" } -func (n dummyNode) Resolved() bool { return n.resolved } -func (dummyNode) Schema() sql.Schema { return sql.Schema{} } -func (dummyNode) Children() []sql.Node { return nil } -func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } -func (dummyNode) TransformUp(sql.TransformNodeFunc) (sql.Node, error) { return nil, nil } -func (dummyNode) TransformExpressionsUp( - sql.TransformExprFunc, -) (sql.Node, error) { - return nil, nil + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + _, err := validateExplodeUsage(sql.NewEmptyContext(), nil, tt.node) + if tt.ok { + require.NoError(err) + } else { + require.Error(err) + require.True(ErrExplodeInvalidUse.Is(err)) + } + }) + } } -func getValidationRule(name string) ValidationRule { +func TestValidateSubqueryColumns(t *testing.T) { + require := 
require.New(t) + ctx := sql.NewEmptyContext() + + node := plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + lit(2), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err := validateSubqueryColumns(ctx, nil, node) + require.Error(err) + require.True(ErrSubqueryColumns.Is(err)) + + node = plan.NewProject([]sql.Expression{ + expression.NewSubquery(plan.NewProject( + []sql.Expression{ + lit(1), + }, + dummyNode{true}, + )), + }, dummyNode{true}) + + _, err = validateSubqueryColumns(ctx, nil, node) + require.NoError(err) +} + +type dummyNode struct{ resolved bool } + +func (n dummyNode) String() string { return "dummynode" } +func (n dummyNode) Resolved() bool { return n.resolved } +func (dummyNode) Schema() sql.Schema { return nil } +func (dummyNode) Children() []sql.Node { return nil } +func (dummyNode) RowIter(*sql.Context) (sql.RowIter, error) { return nil, nil } +func (dummyNode) WithChildren(...sql.Node) (sql.Node, error) { return nil, nil } + +func getValidationRule(name string) Rule { for _, rule := range DefaultValidationRules { if rule.Name == name { return rule diff --git a/sql/analyzer/warnings.go b/sql/analyzer/warnings.go new file mode 100644 index 000000000..a2b7dfbd7 --- /dev/null +++ b/sql/analyzer/warnings.go @@ -0,0 +1,27 @@ +package analyzer + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func clearWarnings(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, error) { + children := node.Children() + if len(children) == 0 { + return node, nil + } + + switch ch := children[0].(type) { + case plan.ShowWarnings: + return node, nil + case *plan.Offset: + clearWarnings(ctx, a, ch) + return node, nil + case *plan.Limit: + clearWarnings(ctx, a, ch) + return node, nil + } + + ctx.ClearWarnings() + return node, nil +} diff --git a/sql/cache.go b/sql/cache.go new file mode 100644 index 000000000..01e28b858 --- /dev/null +++ 
b/sql/cache.go @@ -0,0 +1,131 @@ +package sql + +import ( + "fmt" + "hash/crc64" + "runtime" + + lru "github.com/hashicorp/golang-lru" + errors "gopkg.in/src-d/go-errors.v1" +) + +var table = crc64.MakeTable(crc64.ISO) + +// CacheKey returns a hash of the given value to be used as key in +// a cache. +func CacheKey(v interface{}) uint64 { + return crc64.Checksum([]byte(fmt.Sprintf("%#v", v)), table) +} + +// ErrKeyNotFound is returned when the key could not be found in the cache. +var ErrKeyNotFound = errors.NewKind("memory: key %d not found in cache") + +type lruCache struct { + memory Freeable + reporter Reporter + size int + cache *lru.Cache +} + +func newLRUCache(memory Freeable, r Reporter, size uint) *lruCache { + lru, _ := lru.New(int(size)) + return &lruCache{memory, r, int(size), lru} +} + +func (l *lruCache) Put(k uint64, v interface{}) error { + if releaseMemoryIfNeeded(l.reporter, l.Free, l.memory.Free) { + l.cache.Add(k, v) + } + return nil +} + +func (l *lruCache) Get(k uint64) (interface{}, error) { + v, ok := l.cache.Get(k) + if !ok { + return nil, ErrKeyNotFound.New(k) + } + + return v, nil +} + +func (l *lruCache) Free() { + l.cache, _ = lru.New(l.size) +} + +func (l *lruCache) Dispose() { + l.memory = nil + l.cache = nil +} + +type rowsCache struct { + memory Freeable + reporter Reporter + rows []Row +} + +func newRowsCache(memory Freeable, r Reporter) *rowsCache { + return &rowsCache{memory, r, nil} +} + +func (c *rowsCache) Add(row Row) error { + if !releaseMemoryIfNeeded(c.reporter, c.memory.Free) { + return ErrNoMemoryAvailable.New() + } + + c.rows = append(c.rows, row) + return nil +} + +func (c *rowsCache) Get() []Row { return c.rows } + +func (c *rowsCache) Dispose() { + c.memory = nil + c.rows = nil +} + +type historyCache struct { + memory Freeable + reporter Reporter + cache map[uint64]interface{} +} + +func newHistoryCache(memory Freeable, r Reporter) *historyCache { + return &historyCache{memory, r, make(map[uint64]interface{})} +} + 
+func (h *historyCache) Put(k uint64, v interface{}) error { + if !releaseMemoryIfNeeded(h.reporter, h.memory.Free) { + return ErrNoMemoryAvailable.New() + } + h.cache[k] = v + return nil +} + +func (h *historyCache) Get(k uint64) (interface{}, error) { + v, ok := h.cache[k] + if !ok { + return nil, ErrKeyNotFound.New(k) + } + return v, nil +} + +func (h *historyCache) Dispose() { + h.memory = nil + h.cache = nil +} + +// releaseMemoryIfNeeded releases memory if needed using the following steps +// until there is available memory. It returns whether or not there was +// available memory after all the steps. +func releaseMemoryIfNeeded(r Reporter, steps ...func()) bool { + for _, s := range steps { + if HasAvailableMemory(r) { + return true + } + + s() + runtime.GC() + } + + return HasAvailableMemory(r) +} diff --git a/sql/cache_test.go b/sql/cache_test.go new file mode 100644 index 000000000..7984f7986 --- /dev/null +++ b/sql/cache_test.go @@ -0,0 +1,169 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCacheKey(t *testing.T) { + k := CacheKey(1) + require.Equal(t, uint64(0x4320000000000000), k) +} + +func TestLRUCache(t *testing.T) { + t.Run("basic methods", func(t *testing.T) { + require := require.New(t) + + cache := newLRUCache(mockMemory{}, fixedReporter(5, 50), 10) + + require.NoError(cache.Put(1, "foo")) + v, err := cache.Get(1) + require.NoError(err) + require.Equal("foo", v) + + _, err = cache.Get(2) + require.Error(err) + require.True(ErrKeyNotFound.Is(err)) + + // Free the cache and check previous entry disappeared. 
+ cache.Free() + + _, err = cache.Get(1) + require.Error(err) + require.True(ErrKeyNotFound.Is(err)) + + cache.Dispose() + require.Panics(func() { + _, _ = cache.Get(1) + }) + }) + + t.Run("no memory available", func(t *testing.T) { + require := require.New(t) + cache := newLRUCache(mockMemory{}, fixedReporter(51, 50), 5) + + require.NoError(cache.Put(1, "foo")) + _, err := cache.Get(1) + require.Error(err) + require.True(ErrKeyNotFound.Is(err)) + }) + + t.Run("free required to add entry", func(t *testing.T) { + require := require.New(t) + var freed bool + cache := newLRUCache( + mockMemory{func() { + freed = true + }}, + mockReporter{func() uint64 { + if freed { + return 0 + } + return 51 + }, 50}, + 5, + ) + require.NoError(cache.Put(1, "foo")) + v, err := cache.Get(1) + require.NoError(err) + require.Equal("foo", v) + require.True(freed) + }) +} + +func TestHistoryCache(t *testing.T) { + t.Run("basic methods", func(t *testing.T) { + require := require.New(t) + + cache := newHistoryCache(mockMemory{}, fixedReporter(5, 50)) + + require.NoError(cache.Put(1, "foo")) + v, err := cache.Get(1) + require.NoError(err) + require.Equal("foo", v) + + _, err = cache.Get(2) + require.Error(err) + require.True(ErrKeyNotFound.Is(err)) + + cache.Dispose() + require.Panics(func() { + _ = cache.Put(2, "foo") + }) + }) + + t.Run("no memory available", func(t *testing.T) { + require := require.New(t) + cache := newHistoryCache(mockMemory{}, fixedReporter(51, 50)) + + err := cache.Put(1, "foo") + require.Error(err) + require.True(ErrNoMemoryAvailable.Is(err)) + }) + + t.Run("free required to add entry", func(t *testing.T) { + require := require.New(t) + var freed bool + cache := newHistoryCache( + mockMemory{func() { + freed = true + }}, + mockReporter{func() uint64 { + if freed { + return 0 + } + return 51 + }, 50}, + ) + require.NoError(cache.Put(1, "foo")) + v, err := cache.Get(1) + require.NoError(err) + require.Equal("foo", v) + require.True(freed) + }) +} + +func 
TestRowsCache(t *testing.T) { + t.Run("basic methods", func(t *testing.T) { + require := require.New(t) + + cache := newRowsCache(mockMemory{}, fixedReporter(5, 50)) + + require.NoError(cache.Add(Row{1})) + require.Len(cache.Get(), 1) + + cache.Dispose() + require.Panics(func() { + _ = cache.Add(Row{2}) + }) + }) + + t.Run("no memory available", func(t *testing.T) { + require := require.New(t) + cache := newRowsCache(mockMemory{}, fixedReporter(51, 50)) + + err := cache.Add(Row{1, "foo"}) + require.Error(err) + require.True(ErrNoMemoryAvailable.Is(err)) + }) + + t.Run("free required to add entry", func(t *testing.T) { + require := require.New(t) + var freed bool + cache := newRowsCache( + mockMemory{func() { + freed = true + }}, + mockReporter{func() uint64 { + if freed { + return 0 + } + return 51 + }, 50}, + ) + require.NoError(cache.Add(Row{1, "foo"})) + require.Len(cache.Get(), 1) + require.True(freed) + }) +} diff --git a/sql/catalog.go b/sql/catalog.go index 7ae07245d..ee3a48ade 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -1,6 +1,12 @@ package sql import ( + "fmt" + "strings" + "sync" + + "github.com/src-d/go-mysql-server/internal/similartext" + "gopkg.in/src-d/go-errors.v1" ) @@ -9,18 +15,81 @@ var ErrDatabaseNotFound = errors.NewKind("database not found: %s") // Catalog holds databases, tables and functions. type Catalog struct { - Databases FunctionRegistry *IndexRegistry + *ProcessList + *MemoryManager + + mu sync.RWMutex + currentDatabase string + dbs Databases + locks sessionLocks } +type ( + sessionLocks map[uint32]dbLocks + dbLocks map[string]tableLocks + tableLocks map[string]struct{} +) + // NewCatalog returns a new empty Catalog. 
func NewCatalog() *Catalog { return &Catalog{ - Databases: Databases{}, FunctionRegistry: NewFunctionRegistry(), IndexRegistry: NewIndexRegistry(), + MemoryManager: NewMemoryManager(ProcessMemory), + ProcessList: NewProcessList(), + locks: make(sessionLocks), + } +} + +// CurrentDatabase returns the current database. +func (c *Catalog) CurrentDatabase() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.currentDatabase +} + +// SetCurrentDatabase changes the current database. +func (c *Catalog) SetCurrentDatabase(db string) { + c.mu.Lock() + c.currentDatabase = db + c.mu.Unlock() +} + +// AllDatabases returns all databases in the catalog. +func (c *Catalog) AllDatabases() Databases { + c.mu.RLock() + defer c.mu.RUnlock() + + var result = make(Databases, len(c.dbs)) + copy(result, c.dbs) + return result +} + +// AddDatabase adds a new database to the catalog. +func (c *Catalog) AddDatabase(db Database) { + c.mu.Lock() + if c.currentDatabase == "" { + c.currentDatabase = db.Name() } + + c.dbs.Add(db) + c.mu.Unlock() +} + +// Database returns the database with the given name. +func (c *Catalog) Database(db string) (Database, error) { + c.mu.RLock() + defer c.mu.RUnlock() + return c.dbs.Database(db) +} + +// Table returns the table in the given database with the given name. +func (c *Catalog) Table(db, table string) (Table, error) { + c.mu.RLock() + defer c.mu.RUnlock() + return c.dbs.Table(db, table) } // Databases is a collection of Database. @@ -28,13 +97,26 @@ type Databases []Database // Database returns the Database with the given name if it exists. 
func (d Databases) Database(name string) (Database, error) { + + if len(d) == 0 { + return nil, ErrDatabaseNotFound.New(name) + } + + name = strings.ToLower(name) + var dbNames []string for _, db := range d { - if db.Name() == name { + if strings.ToLower(db.Name()) == name { return db, nil } + dbNames = append(dbNames, db.Name()) } + similar := similartext.Find(dbNames, name) + return nil, ErrDatabaseNotFound.New(name + similar) +} - return nil, ErrDatabaseNotFound.New(name) +// Add adds a new database. +func (d *Databases) Add(db Database) { + *d = append(*d, db) } // Table returns the Table with the given name if it exists. @@ -44,11 +126,74 @@ func (d Databases) Table(dbName string, tableName string) (Table, error) { return nil, err } + tableName = strings.ToLower(tableName) + tables := db.Tables() - table, found := tables[tableName] - if !found { + if len(tables) == 0 { return nil, ErrTableNotFound.New(tableName) } + // Try to get the table by key, but if the name is not the same, + // then use the slow path and iterate over all tables comparing + // the name. + table, ok := tables[tableName] + if !ok { + for name, table := range tables { + if strings.ToLower(name) == tableName { + return table, nil + } + } + + similar := similartext.FindFromMap(tables, tableName) + return nil, ErrTableNotFound.New(tableName + similar) + } + return table, nil } + +// LockTable adds a lock for the given table and session client. It is assumed +// the database is the current database in use. +func (c *Catalog) LockTable(id uint32, table string) { + db := c.CurrentDatabase() + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.locks[id]; !ok { + c.locks[id] = make(dbLocks) + } + + if _, ok := c.locks[id][db]; !ok { + c.locks[id][db] = make(tableLocks) + } + + c.locks[id][db][table] = struct{}{} +} + +// UnlockTables unlocks all tables for which the given session client has a +// lock. 
+func (c *Catalog) UnlockTables(ctx *Context, id uint32) error { + c.mu.Lock() + defer c.mu.Unlock() + + var errors []string + for db, tables := range c.locks[id] { + for t := range tables { + table, err := c.dbs.Table(db, t) + if err == nil { + if lockable, ok := table.(Lockable); ok { + if e := lockable.Unlock(ctx, id); e != nil { + errors = append(errors, e.Error()) + } + } + } else { + errors = append(errors, err.Error()) + } + } + } + + delete(c.locks, id) + if len(errors) > 0 { + return fmt.Errorf("error unlocking tables for %d: %s", id, strings.Join(errors, ", ")) + } + + return nil +} diff --git a/sql/catalog_locks_test.go b/sql/catalog_locks_test.go new file mode 100644 index 000000000..bdd40cd99 --- /dev/null +++ b/sql/catalog_locks_test.go @@ -0,0 +1,37 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCatalogLockTable(t *testing.T) { + require := require.New(t) + c := NewCatalog() + c.SetCurrentDatabase("db1") + c.LockTable(1, "foo") + c.LockTable(2, "bar") + c.LockTable(1, "baz") + c.SetCurrentDatabase("db2") + c.LockTable(1, "qux") + + expected := sessionLocks{ + 1: dbLocks{ + "db1": tableLocks{ + "foo": struct{}{}, + "baz": struct{}{}, + }, + "db2": tableLocks{ + "qux": struct{}{}, + }, + }, + 2: dbLocks{ + "db1": tableLocks{ + "bar": struct{}{}, + }, + }, + } + + require.Equal(expected, c.locks) +} diff --git a/sql/catalog_test.go b/sql/catalog_test.go index 6e770782e..3aef282e4 100644 --- a/sql/catalog_test.go +++ b/sql/catalog_test.go @@ -3,12 +3,42 @@ package sql_test import ( "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) -func TestCatalog_Database(t *testing.T) { +func TestCatalogCurrentDatabase(t *testing.T) { + require := require.New(t) + + c := sql.NewCatalog() + require.Equal("", c.CurrentDatabase()) + + 
c.AddDatabase(memory.NewDatabase("foo")) + require.Equal("foo", c.CurrentDatabase()) + + c.SetCurrentDatabase("bar") + require.Equal("bar", c.CurrentDatabase()) +} + +func TestAllDatabases(t *testing.T) { + require := require.New(t) + + var dbs = sql.Databases{ + memory.NewDatabase("a"), + memory.NewDatabase("b"), + memory.NewDatabase("c"), + } + + c := sql.NewCatalog() + for _, db := range dbs { + c.AddDatabase(db) + } + + require.Equal(dbs, c.AllDatabases()) +} + +func TestCatalogDatabase(t *testing.T) { require := require.New(t) c := sql.NewCatalog() @@ -16,15 +46,19 @@ func TestCatalog_Database(t *testing.T) { require.EqualError(err, "database not found: foo") require.Nil(db) - mydb := mem.NewDatabase("foo") - c.Databases = append(c.Databases, mydb) + mydb := memory.NewDatabase("foo") + c.AddDatabase(mydb) + + db, err = c.Database("flo") + require.EqualError(err, "database not found: flo, maybe you mean foo?") + require.Nil(db) db, err = c.Database("foo") require.NoError(err) require.Equal(mydb, db) } -func TestCatalog_Table(t *testing.T) { +func TestCatalogTable(t *testing.T) { require := require.New(t) c := sql.NewCatalog() @@ -33,17 +67,66 @@ func TestCatalog_Table(t *testing.T) { require.EqualError(err, "database not found: foo") require.Nil(table) - db := mem.NewDatabase("foo") - c.Databases = append(c.Databases, db) + db := memory.NewDatabase("foo") + c.AddDatabase(db) table, err = c.Table("foo", "bar") require.EqualError(err, "table not found: bar") require.Nil(table) - mytable := mem.NewTable("bar", sql.Schema{}) + mytable := memory.NewTable("bar", nil) db.AddTable("bar", mytable) + table, err = c.Table("foo", "baz") + require.EqualError(err, "table not found: baz, maybe you mean bar?") + require.Nil(table) + table, err = c.Table("foo", "bar") require.NoError(err) require.Equal(mytable, table) + + table, err = c.Table("foo", "BAR") + require.NoError(err) + require.Equal(mytable, table) +} + +func TestCatalogUnlockTables(t *testing.T) { + require := 
require.New(t) + + db := memory.NewDatabase("db") + t1 := newLockableTable(memory.NewTable("t1", nil)) + t2 := newLockableTable(memory.NewTable("t2", nil)) + db.AddTable("t1", t1) + db.AddTable("t2", t2) + + c := sql.NewCatalog() + c.AddDatabase(db) + + c.LockTable(1, "t1") + c.LockTable(1, "t2") + + require.NoError(c.UnlockTables(nil, 1)) + + require.Equal(1, t1.unlocks) + require.Equal(1, t2.unlocks) +} + +type lockableTable struct { + sql.Table + unlocks int +} + +func newLockableTable(t sql.Table) *lockableTable { + return &lockableTable{Table: t} +} + +var _ sql.Lockable = (*lockableTable)(nil) + +func (l *lockableTable) Lock(ctx *sql.Context, write bool) error { + return nil +} + +func (l *lockableTable) Unlock(ctx *sql.Context, id uint32) error { + l.unlocks++ + return nil } diff --git a/sql/core.go b/sql/core.go index 77615bfdc..16ef6fa8d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -2,6 +2,10 @@ package sql import ( "fmt" + "io" + "math" + "strconv" + "time" "gopkg.in/src-d/go-errors.v1" ) @@ -21,6 +25,13 @@ var ( //ErrUnexpectedRowLength is thrown when the obtained row has more columns than the schema ErrUnexpectedRowLength = errors.NewKind("expected %d values, got %d") + + // ErrInvalidChildrenNumber is returned when the WithChildren method of a + // node or expression is called with an invalid number of arguments. + ErrInvalidChildrenNumber = errors.NewKind("%T: invalid children number, got %d, expected %d") + + // ErrDeleteRowNotFound + ErrDeleteRowNotFound = errors.NewKind("row was not found when attempting to delete").New() ) // Nameable is something that has a name. @@ -41,17 +52,6 @@ type Resolvable interface { Resolved() bool } -// Transformable is a node which can be transformed. -type Transformable interface { - // TransformUp transforms all nodes and returns the result of this transformation. - // Transformation is not propagated to subqueries. 
- TransformUp(TransformNodeFunc) (Node, error) - // TransformExpressionsUp transforms all expressions inside the node and all its - // children and returns a node with the result of the transformations. - // Transformation is not propagated to subqueries. - TransformExpressionsUp(TransformExprFunc) (Node, error) -} - // TransformNodeFunc is a function that given a node will return that node // as is or transformed along with an error, if any. type TransformNodeFunc func(Node) (Node, error) @@ -70,11 +70,13 @@ type Expression interface { IsNullable() bool // Eval evaluates the given row and returns a result. Eval(*Context, Row) (interface{}, error) - // TransformUp transforms the expression and all its children with the - // given transform function. - TransformUp(TransformExprFunc) (Expression, error) // Children returns the children expressions of this expression. Children() []Expression + // WithChildren returns a copy of the expression with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. + WithChildren(...Expression) (Expression, error) } // Aggregation implements an aggregation expression, where an @@ -96,7 +98,6 @@ type Aggregation interface { // Node is a node in the execution plan tree. type Node interface { Resolvable - Transformable fmt.Stringer // Schema of the node. Schema() Schema @@ -104,60 +105,128 @@ type Node interface { Children() []Node // RowIter produces a row iterator from this node. RowIter(*Context) (RowIter, error) + // WithChildren returns a copy of the node with children replaced. + // It will return an error if the number of children is different than + // the current number of children. They must be given in the same order + // as they are returned by Children. 
+ WithChildren(...Node) (Node, error) +} + +// OpaqueNode is a node that doesn't allow transformations to its children and +// acts a a black box. +type OpaqueNode interface { + Node + // Opaque reports whether the node is opaque or not. + Opaque() bool +} + +// AsyncNode is a node that can be executed asynchronously. +type AsyncNode interface { + // IsAsync reports whether the node is async or not. + IsAsync() bool } // Expressioner is a node that contains expressions. type Expressioner interface { // Expressions returns the list of expressions contained by the node. Expressions() []Expression + // WithExpressions returns a copy of the node with expressions replaced. + // It will return an error if the number of expressions is different than + // the current number of expressions. They must be given in the same order + // as they are returned by Expressions. + WithExpressions(...Expression) (Node, error) } -// Table represents a SQL table. +// Databaser is a node that contains a reference to a database. +type Databaser interface { + // Database the current database. + Database() Database + // WithDatabase returns a new node instance with the database replaced with + // the one given as parameter. + WithDatabase(Database) (Node, error) +} + +// Partition represents a partition from a SQL table. +type Partition interface { + Key() []byte +} + +// PartitionIter is an iterator that retrieves partitions. +type PartitionIter interface { + io.Closer + Next() (Partition, error) +} + +// Table represents the backend of a SQL table. type Table interface { Nameable - Node + String() string + Schema() Schema + Partitions(*Context) (PartitionIter, error) + PartitionRows(*Context, Partition) (RowIter, error) +} + +// TableWrapper is a node that wraps the real table. This is needed because +// wrappers cannot implement some methods the table may implement. +type TableWrapper interface { + // Underlying returns the underlying table. 
+ Underlying() Table } -// Indexable represents a table that supports being indexed and receiving -// indexes to be able to speed up its execution. -type Indexable interface { - // IndexKeyValueIter returns an iterator with the values of each row in - // the table for the given column names. - IndexKeyValueIter(colNames []string) (IndexKeyValueIter, error) - // WithProjectFiltersAndIndex is meant to be called instead of RowIter - // method of the table. Returns a new iterator given the columns, - // filters and the index so the table can improve its speed instead of - // making a full scan. - WithProjectFiltersAndIndex(columns, filters []Expression, index IndexValueIter) (RowIter, error) +// PartitionCounter can return the number of partitions. +type PartitionCounter interface { + // PartitionCount returns the number of partitions. + PartitionCount(*Context) (int64, error) } -// PushdownProjectionTable is a table that can produce a specific RowIter +//FilteredTable is a table that can produce a specific RowIter +// that's more optimized given the filters. +type FilteredTable interface { + Table + HandledFilters(filters []Expression) []Expression + WithFilters(filters []Expression) Table + Filters() []Expression +} + +// ProjectedTable is a table that can produce a specific RowIter // that's more optimized given the columns that are projected. -type PushdownProjectionTable interface { +type ProjectedTable interface { Table - // WithProject replaces the RowIter method of the table and returns a new - // row iterator given the column names that are projected. - WithProject(ctx *Context, colNames []string) (RowIter, error) + WithProjection(colNames []string) Table + Projection() []string } -// PushdownProjectionAndFiltersTable is a table that can produce a specific -// RowIter that's more optimized given the columns that are projected and -// the filters for this table. 
-type PushdownProjectionAndFiltersTable interface { +// IndexableTable represents a table that supports being indexed and +// receiving indexes to be able to speed up its execution. +type IndexableTable interface { Table - // HandledFilters returns the subset of filters that can be handled by this - // table. - HandledFilters(filters []Expression) []Expression - // WithProjectAndFilters replaces the RowIter method of the table and - // return a new row iterator given the column names that are projected - // and the filters applied to this table. - WithProjectAndFilters(ctx *Context, columns, filters []Expression) (RowIter, error) + WithIndexLookup(IndexLookup) Table + IndexLookup() IndexLookup + IndexKeyValues(*Context, []string) (PartitionIndexKeyValueIter, error) } // Inserter allow rows to be inserted in them. type Inserter interface { // Insert the given row. - Insert(row Row) error + Insert(*Context, Row) error +} + +// Deleter allow rows to be deleted from tables. +type Deleter interface { + // Delete the given row. Returns ErrDeleteRowNotFound if the row was not found. + Delete(*Context, Row) error +} + +// Replacer allows rows to be replaced through a Delete (if applicable) then Insert. +type Replacer interface { + Deleter + Inserter +} + +// Updater allows rows to be updated. +type Updater interface { + // Update the given row. Provides both the old and new rows. + Update(ctx *Context, old Row, new Row) error } // Database represents the database. @@ -167,7 +236,75 @@ type Database interface { Tables() map[string]Table } -// Alterable should be implemented by databases that can handle DDL statements -type Alterable interface { - Create(name string, schema Schema) error +// TableCreator should be implemented by databases that can create new tables. +type TableCreator interface { + CreateTable(ctx *Context, name string, schema Schema) error +} + +// TableDropper should be implemented by databases that can drop tables. 
+type TableDropper interface { + DropTable(ctx *Context, name string) error +} + +// Lockable should be implemented by tables that can be locked and unlocked. +type Lockable interface { + Nameable + // Lock locks the table either for reads or writes. Any session clients can + // read while the table is locked for read, but not write. + // When the table is locked for write, nobody can write except for the + // session client that requested the lock. + Lock(ctx *Context, write bool) error + // Unlock releases the lock for the current session client. It blocks until + // all reads or writes started during the lock are finished. + // Context may be nil if the unlock it's because the connection was closed. + // The id will always be provided, since in some cases context is not + // available. + Unlock(ctx *Context, id uint32) error +} + +// EvaluateCondition evaluates a condition, which is an expression whose value +// will be coerced to boolean. +func EvaluateCondition(ctx *Context, cond Expression, row Row) (bool, error) { + v, err := cond.Eval(ctx, row) + if err != nil { + return false, err + } + + switch b := v.(type) { + case bool: + return b, nil + case int: + return b != int(0), nil + case int64: + return b != int64(0), nil + case int32: + return b != int32(0), nil + case int16: + return b != int16(0), nil + case int8: + return b != int8(0), nil + case uint: + return b != uint(0), nil + case uint64: + return b != uint64(0), nil + case uint32: + return b != uint32(0), nil + case uint16: + return b != uint16(0), nil + case uint8: + return b != uint8(0), nil + case time.Duration: + return int64(b) != 0, nil + case time.Time: + return b.UnixNano() != 0, nil + case float64: + return int(math.Round(v.(float64))) != 0, nil + case float32: + return int(math.Round(float64(v.(float32)))) != 0, nil + case string: + parsed, err := strconv.ParseFloat(v.(string), 64) + return err == nil && int(parsed) != 0, nil + default: + return false, nil + } } diff --git 
a/sql/core_test.go b/sql/core_test.go new file mode 100644 index 000000000..cf3e23acd --- /dev/null +++ b/sql/core_test.go @@ -0,0 +1,49 @@ +package sql_test + +import ( + "fmt" + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +var conditions = []struct { + evaluated bool + value interface{} + t sql.Type +}{ + {true, int16(1), sql.Int16}, + {false, int16(0), sql.Int16}, + {true, int32(1), sql.Int32}, + {false, int32(0), sql.Int32}, + {true, int(1), sql.Int64}, + {false, int(0), sql.Int64}, + {true, float32(1), sql.Float32}, + {true, float64(1), sql.Float64}, + {false, float32(0), sql.Float32}, + {false, float64(0), sql.Float64}, + {true, float32(0.5), sql.Float32}, + {true, float64(0.5), sql.Float64}, + {true, "1", sql.Text}, + {false, "0", sql.Text}, + {false, "foo", sql.Text}, + {false, "0.5", sql.Text}, + {false, time.Duration(0), sql.Timestamp}, + {true, time.Duration(1), sql.Timestamp}, + {false, false, sql.Boolean}, + {true, true, sql.Boolean}, +} + +func TestEvaluateCondition(t *testing.T) { + for _, v := range conditions { + t.Run(fmt.Sprint(v.value, " evaluated to ", v.evaluated, " type ", v.t), func(t *testing.T) { + require := require.New(t) + b, err := sql.EvaluateCondition(sql.NewEmptyContext(), expression.NewLiteral(v.value, v.t), sql.NewRow()) + require.NoError(err) + require.Equal(v.evaluated, b) + }) + } +} diff --git a/sql/expression/alias.go b/sql/expression/alias.go index ee382b6d3..c7485dfd9 100644 --- a/sql/expression/alias.go +++ b/sql/expression/alias.go @@ -3,7 +3,7 @@ package expression import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Alias is a node that gives a name to an expression. @@ -31,13 +31,12 @@ func (e *Alias) String() string { return fmt.Sprintf("%s as %s", e.Child, e.name) } -// TransformUp implements the Expression interface. 
-func (e *Alias) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Alias) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewAlias(child, e.name)) + return NewAlias(children[0], e.name), nil } // Name implements the Nameable interface. diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 38214f755..d7044e06e 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -2,11 +2,13 @@ package expression import ( "fmt" + "reflect" + "time" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-vitess.v0/vt/sqlparser" + "vitess.io/vitess/go/vt/sqlparser" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) var ( @@ -20,7 +22,7 @@ var ( // Arithmetic expressions (+, -, *, /, ...) type Arithmetic struct { BinaryExpression - op string + Op string } // NewArithmetic creates a new Arithmetic sql.Expression. @@ -84,13 +86,30 @@ func NewMod(left, right sql.Expression) *Arithmetic { } func (a *Arithmetic) String() string { - return fmt.Sprintf("%s %s %s", a.Left, a.op, a.Right) + return fmt.Sprintf("%s %s %s", a.Left, a.Op, a.Right) +} + +// IsNullable implements the sql.Expression interface. +func (a *Arithmetic) IsNullable() bool { + if a.Type() == sql.Timestamp { + return true + } + + return a.BinaryExpression.IsNullable() } // Type returns the greatest type for given operation. 
func (a *Arithmetic) Type() sql.Type { - switch a.op { + switch a.Op { case sqlparser.PlusStr, sqlparser.MinusStr, sqlparser.MultStr, sqlparser.DivStr: + if isInterval(a.Left) || isInterval(a.Right) { + return sql.Timestamp + } + + if sql.IsTime(a.Left.Type()) && sql.IsTime(a.Right.Type()) { + return sql.Int64 + } + if sql.IsInteger(a.Left.Type()) && sql.IsInteger(a.Right.Type()) { if sql.IsUnsigned(a.Left.Type()) && sql.IsUnsigned(a.Right.Type()) { return sql.Uint64 @@ -113,37 +132,36 @@ func (a *Arithmetic) Type() sql.Type { return sql.Float64 } -// TransformUp implements the Expression interface. -func (a *Arithmetic) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - l, err := a.Left.TransformUp(f) - if err != nil { - return nil, err - } +func isInterval(expr sql.Expression) bool { + _, ok := expr.(*Interval) + return ok +} - r, err := a.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Arithmetic) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - return f(NewArithmetic(l, r, a.op)) + return NewArithmetic(children[0], children[1], a.Op), nil } // Eval implements the Expression interface. 
func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.(" + a.op + ")") - defer span.Finish() - lval, rval, err := a.evalLeftRight(ctx, row) if err != nil { return nil, err } + if lval == nil || rval == nil { + return nil, nil + } + lval, rval, err = a.convertLeftRight(lval, rval) if err != nil { return nil, err } - switch a.op { + switch a.Op { case sqlparser.PlusStr: return plus(lval, rval) case sqlparser.MinusStr: @@ -168,37 +186,63 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return mod(lval, rval) } - return nil, errUnableToEval.New(lval, a.op, rval) + return nil, errUnableToEval.New(lval, a.Op, rval) } func (a *Arithmetic) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { - lval, err := a.Left.Eval(ctx, row) - if err != nil { - return nil, nil, err + var lval, rval interface{} + var err error + + if i, ok := a.Left.(*Interval); ok { + lval, err = i.EvalDelta(ctx, row) + if err != nil { + return nil, nil, err + } + } else { + lval, err = a.Left.Eval(ctx, row) + if err != nil { + return nil, nil, err + } } - rval, err := a.Right.Eval(ctx, row) - if err != nil { - return nil, nil, err + if i, ok := a.Right.(*Interval); ok { + rval, err = i.EvalDelta(ctx, row) + if err != nil { + return nil, nil, err + } + } else { + rval, err = a.Right.Eval(ctx, row) + if err != nil { + return nil, nil, err + } } return lval, rval, nil } -func (a *Arithmetic) convertLeftRight(lval interface{}, rval interface{}) (interface{}, interface{}, error) { +func (a *Arithmetic) convertLeftRight(left interface{}, right interface{}) (interface{}, interface{}, error) { + var err error typ := a.Type() - lval64, err := typ.Convert(lval) - if err != nil { - return nil, nil, err + if i, ok := left.(*TimeDelta); ok { + left = i + } else { + left, err = typ.Convert(left) + if err != nil { + return nil, nil, err + } } - rval64, err := typ.Convert(rval) - if err != nil { 
- return nil, nil, err + if i, ok := right.(*TimeDelta); ok { + right = i + } else { + right, err = typ.Convert(right) + if err != nil { + return nil, nil, err + } } - return lval64, rval64, nil + return left, right, nil } func plus(lval, rval interface{}) (interface{}, error) { @@ -220,6 +264,18 @@ func plus(lval, rval interface{}) (interface{}, error) { case float64: return l + r, nil } + case time.Time: + switch r := rval.(type) { + case *TimeDelta: + return sql.ValidateTime(r.Add(l)), nil + case time.Time: + return l.Unix() + r.Unix(), nil + } + case *TimeDelta: + switch r := rval.(type) { + case time.Time: + return sql.ValidateTime(l.Add(r)), nil + } } return nil, errUnableToCast.New(lval, rval) @@ -244,6 +300,13 @@ func minus(lval, rval interface{}) (interface{}, error) { case float64: return l - r, nil } + case time.Time: + switch r := rval.(type) { + case *TimeDelta: + return sql.ValidateTime(r.Sub(l)), nil + case time.Time: + return l.Unix() - r.Unix(), nil + } } return nil, errUnableToCast.New(lval, rval) @@ -410,3 +473,79 @@ func mod(lval, rval interface{}) (interface{}, error) { return nil, errUnableToCast.New(lval, rval) } + +// UnaryMinus is an unary minus operator. +type UnaryMinus struct { + UnaryExpression +} + +// NewUnaryMinus creates a new UnaryMinus expression node. +func NewUnaryMinus(child sql.Expression) *UnaryMinus { + return &UnaryMinus{UnaryExpression{Child: child}} +} + +// Eval implements the sql.Expression interface. 
+func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := e.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + if !sql.IsNumber(e.Child.Type()) { + child, err = sql.Float64.Convert(child) + if err != nil { + child = 0.0 + } + } + + switch n := child.(type) { + case float64: + return -n, nil + case float32: + return -n, nil + case int64: + return -n, nil + case uint64: + return -int64(n), nil + case int32: + return -n, nil + case uint32: + return -int32(n), nil + default: + return nil, sql.ErrInvalidType.New(reflect.TypeOf(n)) + } +} + +// Type implements the sql.Expression interface. +func (e *UnaryMinus) Type() sql.Type { + typ := e.Child.Type() + if !sql.IsNumber(typ) { + return sql.Float64 + } + + if typ == sql.Uint32 { + return sql.Int32 + } + + if typ == sql.Uint64 { + return sql.Int64 + } + + return e.Child.Type() +} + +func (e *UnaryMinus) String() string { + return fmt.Sprintf("-%s", e.Child) +} + +// WithChildren implements the Expression interface. 
+func (e *UnaryMinus) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + return NewUnaryMinus(children[0]), nil +} diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 02809687d..811039734 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -2,9 +2,10 @@ package expression import ( "testing" + "time" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestPlus(t *testing.T) { @@ -38,6 +39,29 @@ func TestPlus(t *testing.T) { require.Equal(float64(5), result) } +func TestPlusInterval(t *testing.T) { + require := require.New(t) + + expected := time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC) + op := NewPlus( + NewLiteral("2018-05-01", sql.Text), + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + ) + + result, err := op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) + + op = NewPlus( + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + NewLiteral("2018-05-01", sql.Text), + ) + + result, err = op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) +} + func TestMinus(t *testing.T) { var testCases = []struct { name string @@ -69,6 +93,20 @@ func TestMinus(t *testing.T) { require.Equal(float64(0), result) } +func TestMinusInterval(t *testing.T) { + require := require.New(t) + + expected := time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC) + op := NewMinus( + NewLiteral("2018-05-02", sql.Text), + NewInterval(NewLiteral(int64(1), sql.Int64), "DAY"), + ) + + result, err := op.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(expected, result) +} + func TestMult(t *testing.T) { var testCases = []struct { name string @@ -360,3 +398,31 @@ func TestAllInt64(t *testing.T) { }) } } + +func TestUnaryMinus(t 
*testing.T) { + testCases := []struct { + name string + input interface{} + typ sql.Type + expected interface{} + }{ + {"int32", int32(1), sql.Int32, int32(-1)}, + {"uint32", uint32(1), sql.Uint32, int32(-1)}, + {"int64", int64(1), sql.Int64, int64(-1)}, + {"uint64", uint64(1), sql.Uint64, int64(-1)}, + {"float32", float32(1), sql.Float32, float32(-1)}, + {"float64", float64(1), sql.Float64, float64(-1)}, + {"int text", "1", sql.Text, float64(-1)}, + {"float text", "1.2", sql.Text, float64(-1.2)}, + {"nil", nil, sql.Text, nil}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + f := NewUnaryMinus(NewLiteral(tt.input, tt.typ)) + result, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/expression/between.go b/sql/expression/between.go index aac2c1284..15114890b 100644 --- a/sql/expression/between.go +++ b/sql/expression/between.go @@ -3,7 +3,7 @@ package expression import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Between checks a value is between two given values. @@ -19,7 +19,7 @@ func NewBetween(val, lower, upper sql.Expression) *Between { } func (b *Between) String() string { - return fmt.Sprintf("BETWEEN(%s, %s, %s)", b.Val, b.Lower, b.Upper) + return fmt.Sprintf("%s BETWEEN %s AND %s", b.Val, b.Lower, b.Upper) } // Children implements the Expression interface. @@ -42,9 +42,6 @@ func (b *Between) Resolved() bool { // Eval implements the Expression interface. func (b *Between) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Between") - defer span.Finish() - typ := b.Val.Type() val, err := b.Val.Eval(ctx, row) if err != nil { @@ -101,22 +98,10 @@ func (b *Between) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return cmpLower >= 0 && cmpUpper <= 0, nil } -// TransformUp implements the Expression interface. 
-func (b *Between) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - val, err := b.Val.TransformUp(f) - if err != nil { - return nil, err - } - - lower, err := b.Lower.TransformUp(f) - if err != nil { - return nil, err - } - - upper, err := b.Upper.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (b *Between) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 3) } - - return f(NewBetween(val, lower, upper)) + return NewBetween(children[0], children[1], children[2]), nil } diff --git a/sql/expression/between_test.go b/sql/expression/between_test.go index eb0ced088..58fd7bb1a 100644 --- a/sql/expression/between_test.go +++ b/sql/expression/between_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestBetween(t *testing.T) { diff --git a/sql/expression/boolean.go b/sql/expression/boolean.go index 726eb6c7b..73815fb7d 100644 --- a/sql/expression/boolean.go +++ b/sql/expression/boolean.go @@ -2,8 +2,9 @@ package expression import ( "fmt" + "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Not is a node that negates an expression. @@ -23,9 +24,6 @@ func (e *Not) Type() sql.Type { // Eval implements the Expression interface. 
func (e *Not) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Not") - defer span.Finish() - v, err := e.Child.Eval(ctx, row) if err != nil { return nil, err @@ -35,18 +33,29 @@ func (e *Not) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - return !v.(bool), nil + b, ok := v.(bool) + if !ok { + v, _ = e.Type().Convert(v) + if v == nil { + return nil, nil + } + + if b, ok = v.(bool); !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(v).String()) + } + } + + return !b, nil } func (e *Not) String() string { return fmt.Sprintf("NOT(%s)", e.Child) } -// TransformUp implements the Expression interface. -func (e *Not) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Not) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewNot(child)) + return NewNot(children[0]), nil } diff --git a/sql/expression/boolean_test.go b/sql/expression/boolean_test.go index cb69b0d2c..aff618485 100644 --- a/sql/expression/boolean_test.go +++ b/sql/expression/boolean_test.go @@ -2,9 +2,10 @@ package expression import ( "testing" + "time" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestNot(t *testing.T) { @@ -14,4 +15,9 @@ func TestNot(t *testing.T) { require.False(eval(t, e, sql.NewRow(true)).(bool)) require.True(eval(t, e, sql.NewRow(false)).(bool)) require.Nil(eval(t, e, sql.NewRow(nil))) + require.False(eval(t, e, sql.NewRow(1)).(bool)) + require.True(eval(t, e, sql.NewRow(0)).(bool)) + require.False(eval(t, e, sql.NewRow(time.Now())).(bool)) + require.False(eval(t, e, sql.NewRow(time.Second)).(bool)) + require.True(eval(t, e, sql.NewRow("any string always 
false")).(bool)) } diff --git a/sql/expression/case.go b/sql/expression/case.go new file mode 100644 index 000000000..feb5f15f3 --- /dev/null +++ b/sql/expression/case.go @@ -0,0 +1,187 @@ +package expression + +import ( + "bytes" + + "github.com/src-d/go-mysql-server/sql" +) + +// CaseBranch is a single branch of a case expression. +type CaseBranch struct { + Cond sql.Expression + Value sql.Expression +} + +// Case is an expression that returns the value of one of its branches when a +// condition is met. +type Case struct { + Expr sql.Expression + Branches []CaseBranch + Else sql.Expression +} + +// NewCase returns an new Case expression. +func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression) *Case { + return &Case{expr, branches, elseExpr} +} + +// Type implements the sql.Expression interface. +func (c *Case) Type() sql.Type { + for _, b := range c.Branches { + if b.Value.Type() != sql.Null { + return b.Value.Type() + } + } + + if c.Else.Type() != sql.Null { + return c.Else.Type() + } + + return sql.Null +} + +// IsNullable implements the sql.Expression interface. +func (c *Case) IsNullable() bool { + for _, b := range c.Branches { + if b.Value.IsNullable() { + return true + } + } + + return c.Else == nil || c.Else.IsNullable() +} + +// Resolved implements the sql.Expression interface. +func (c *Case) Resolved() bool { + if (c.Expr != nil && !c.Expr.Resolved()) || + (c.Else != nil && !c.Else.Resolved()) { + return false + } + + for _, b := range c.Branches { + if !b.Cond.Resolved() || !b.Value.Resolved() { + return false + } + } + + return true +} + +// Children implements the sql.Expression interface. 
+func (c *Case) Children() []sql.Expression { + var children []sql.Expression + + if c.Expr != nil { + children = append(children, c.Expr) + } + + for _, b := range c.Branches { + children = append(children, b.Cond, b.Value) + } + + if c.Else != nil { + children = append(children, c.Else) + } + + return children +} + +// Eval implements the sql.Expression interface. +func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + span, ctx := ctx.Span("expression.Case") + defer span.Finish() + + var expr interface{} + var err error + if c.Expr != nil { + expr, err = c.Expr.Eval(ctx, row) + if err != nil { + return nil, err + } + } + + for _, b := range c.Branches { + var cond sql.Expression + if expr != nil { + cond = NewEquals(NewLiteral(expr, c.Expr.Type()), b.Cond) + } else { + cond = b.Cond + } + + ok, err := sql.EvaluateCondition(ctx, cond, row) + if err != nil { + return nil, err + } + + if ok { + return b.Value.Eval(ctx, row) + } + } + + if c.Else != nil { + return c.Else.Eval(ctx, row) + } + + return nil, nil +} + +// WithChildren implements the Expression interface. 
+func (c *Case) WithChildren(children ...sql.Expression) (sql.Expression, error) { + var expected = len(c.Branches) * 2 + if c.Expr != nil { + expected++ + } + + if c.Else != nil { + expected++ + } + + if len(children) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), expected) + } + + var expr, elseExpr sql.Expression + if c.Expr != nil { + expr = children[0] + children = children[1:] + } + + if c.Else != nil { + elseExpr = children[len(children)-1] + children = children[:len(children)-1] + } + + var branches []CaseBranch + for i := 0; i < len(children); i += 2 { + branches = append(branches, CaseBranch{ + Cond: children[i], + Value: children[i+1], + }) + } + + return NewCase(expr, branches, elseExpr), nil +} + +func (c *Case) String() string { + var buf bytes.Buffer + + buf.WriteString("CASE ") + if c.Expr != nil { + buf.WriteString(c.Expr.String()) + } + + for _, b := range c.Branches { + buf.WriteString(" WHEN ") + buf.WriteString(b.Cond.String()) + buf.WriteString(" THEN ") + buf.WriteString(b.Value.String()) + } + + if c.Else != nil { + buf.WriteString(" ELSE ") + buf.WriteString(c.Else.String()) + } + + buf.WriteString(" END") + return buf.String() +} diff --git a/sql/expression/case_test.go b/sql/expression/case_test.go new file mode 100644 index 000000000..80a24aa47 --- /dev/null +++ b/sql/expression/case_test.go @@ -0,0 +1,146 @@ +package expression + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestCase(t *testing.T) { + f1 := NewCase( + NewGetField(0, sql.Int64, "foo", false), + []CaseBranch{ + {Cond: NewLiteral(int64(1), sql.Int64), Value: NewLiteral(int64(2), sql.Int64)}, + {Cond: NewLiteral(int64(3), sql.Int64), Value: NewLiteral(int64(4), sql.Int64)}, + {Cond: NewLiteral(int64(5), sql.Int64), Value: NewLiteral(int64(6), sql.Int64)}, + }, + NewLiteral(int64(7), sql.Int64), + ) + + f2 := NewCase( + nil, + []CaseBranch{ + { + Cond: NewEquals( + 
NewGetField(0, sql.Int64, "foo", false), + NewLiteral(int64(1), sql.Int64), + ), + Value: NewLiteral(int64(2), sql.Int64), + }, + { + Cond: NewEquals( + NewGetField(0, sql.Int64, "foo", false), + NewLiteral(int64(3), sql.Int64), + ), + Value: NewLiteral(int64(4), sql.Int64), + }, + { + Cond: NewEquals( + NewGetField(0, sql.Int64, "foo", false), + NewLiteral(int64(5), sql.Int64), + ), + Value: NewLiteral(int64(6), sql.Int64), + }, + }, + NewLiteral(int64(7), sql.Int64), + ) + + f3 := NewCase( + NewGetField(0, sql.Int64, "foo", false), + []CaseBranch{ + {Cond: NewLiteral(int64(1), sql.Int64), Value: NewLiteral(int64(2), sql.Int64)}, + {Cond: NewLiteral(int64(3), sql.Int64), Value: NewLiteral(int64(4), sql.Int64)}, + {Cond: NewLiteral(int64(5), sql.Int64), Value: NewLiteral(int64(6), sql.Int64)}, + }, + nil, + ) + + testCases := []struct { + name string + f *Case + row sql.Row + expected interface{} + }{ + { + "with expr and else branch 1", + f1, + sql.Row{int64(1)}, + int64(2), + }, + { + "with expr and else branch 2", + f1, + sql.Row{int64(3)}, + int64(4), + }, + { + "with expr and else branch 3", + f1, + sql.Row{int64(5)}, + int64(6), + }, + { + "with expr and else, else branch", + f1, + sql.Row{int64(9)}, + int64(7), + }, + { + "without expr and else branch 1", + f2, + sql.Row{int64(1)}, + int64(2), + }, + { + "without expr and else branch 2", + f2, + sql.Row{int64(3)}, + int64(4), + }, + { + "without expr and else branch 3", + f2, + sql.Row{int64(5)}, + int64(6), + }, + { + "without expr and else, else branch", + f2, + sql.Row{int64(9)}, + int64(7), + }, + { + "without else, else branch", + f3, + sql.Row{int64(9)}, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} + +func TestCaseNullBranch(t *testing.T) { + require := require.New(t) + f := NewCase( + NewGetField(0, 
sql.Int64, "x", false), + []CaseBranch{ + { + Cond: NewLiteral(int64(1), sql.Int64), + Value: NewLiteral(nil, sql.Null), + }, + }, + nil, + ) + result, err := f.Eval(sql.NewEmptyContext(), sql.Row{int64(1)}) + require.NoError(err) + require.Nil(result) +} diff --git a/sql/expression/common.go b/sql/expression/common.go index 62ea4c714..b124f269e 100644 --- a/sql/expression/common.go +++ b/sql/expression/common.go @@ -1,9 +1,19 @@ package expression import ( - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) +// IsUnary returns whether the expression is unary or not. +func IsUnary(e sql.Expression) bool { + return len(e.Children()) == 1 +} + +// IsBinary returns whether the expression is binary or not. +func IsBinary(e sql.Expression) bool { + return len(e.Children()) == 2 +} + // UnaryExpression is an expression that has only one children. type UnaryExpression struct { Child sql.Expression diff --git a/sql/expression/common_test.go b/sql/expression/common_test.go index 71edd8de8..e00d3270f 100644 --- a/sql/expression/common_test.go +++ b/sql/expression/common_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { @@ -13,3 +13,15 @@ func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { require.NoError(t, err) return v } + +func TestIsUnary(t *testing.T) { + require := require.New(t) + require.True(IsUnary(NewNot(nil))) + require.False(IsUnary(NewAnd(nil, nil))) +} + +func TestIsBinary(t *testing.T) { + require := require.New(t) + require.False(IsBinary(NewNot(nil))) + require.True(IsBinary(NewAnd(nil, nil))) +} diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index c41bf3a25..87601579b 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -2,10 +2,11 @@ package expression 
import ( "fmt" - "regexp" + "sync" + "github.com/src-d/go-mysql-server/internal/regex" + "github.com/src-d/go-mysql-server/sql" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // Comparer implements a comparison expression. @@ -70,32 +71,32 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, interface{}, error) { if sql.IsNumber(c.Left().Type()) || sql.IsNumber(c.Right().Type()) { if sql.IsDecimal(c.Left().Type()) || sql.IsDecimal(c.Right().Type()) { - left, right, err := convertLeftAndRight(left, right, ConvertToDecimal) + l, r, err := convertLeftAndRight(left, right, ConvertToDecimal) if err != nil { return nil, nil, err } c.compareType = sql.Float64 - return left, right, nil + return l, r, nil } if sql.IsSigned(c.Left().Type()) || sql.IsSigned(c.Right().Type()) { - left, right, err := convertLeftAndRight(left, right, ConvertToSigned) + l, r, err := convertLeftAndRight(left, right, ConvertToSigned) if err != nil { return nil, nil, err } c.compareType = sql.Int64 - return left, right, nil + return l, r, nil } - left, right, err := convertLeftAndRight(left, right, ConvertToUnsigned) + l, r, err := convertLeftAndRight(left, right, ConvertToUnsigned) if err != nil { return nil, nil, err } c.compareType = sql.Uint64 - return left, right, nil + return l, r, nil } left, right, err := convertLeftAndRight(left, right, ConvertToChar) @@ -144,9 +145,6 @@ func NewEquals(left sql.Expression, right sql.Expression) *Equals { // Eval implements the Expression interface. 
func (e *Equals) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Equals") - defer span.Finish() - result, err := e.Compare(ctx, row) if err != nil { if ErrNilOperand.Is(err) { @@ -159,19 +157,12 @@ func (e *Equals) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == 0, nil } -// TransformUp implements the Expression interface. -func (e *Equals) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := e.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := e.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *Equals) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 2) } - - return f(NewEquals(left, right)) + return NewEquals(children[0], children[1]), nil } func (e *Equals) String() string { @@ -181,18 +172,29 @@ func (e *Equals) String() string { // Regexp is a comparison that checks an expression matches a regexp. type Regexp struct { comparison + pool *sync.Pool + cached bool } // NewRegexp creates a new Regexp expression. func NewRegexp(left sql.Expression, right sql.Expression) *Regexp { - return &Regexp{newComparison(left, right)} + var cached = true + Inspect(right, func(e sql.Expression) bool { + if _, ok := e.(*GetField); ok { + cached = false + } + return true + }) + + return &Regexp{ + comparison: newComparison(left, right), + pool: nil, + cached: cached, + } } // Eval implements the Expression interface. 
func (re *Regexp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Regexp") - defer span.Finish() - if sql.IsText(re.Left().Type()) && sql.IsText(re.Right().Type()) { return re.compareRegexp(ctx, row) } @@ -210,46 +212,71 @@ func (re *Regexp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } func (re *Regexp) compareRegexp(ctx *sql.Context, row sql.Row) (interface{}, error) { - left, right, err := re.evalLeftAndRight(ctx, row) - if err != nil { + left, err := re.Left().Eval(ctx, row) + if err != nil || left == nil { return nil, err } - - if left == nil || right == nil { - return nil, nil - } - left, err = sql.Text.Convert(left) if err != nil { return nil, err } - right, err = sql.Text.Convert(right) - if err != nil { - return nil, err + var ( + matcher regex.Matcher + disposer regex.Disposer + right interface{} + ) + // eval right and convert to text + if !re.cached || re.pool == nil { + right, err = re.Right().Eval(ctx, row) + if err != nil || right == nil { + return nil, err + } + right, err = sql.Text.Convert(right) + if err != nil { + return nil, err + } } - - reg, err := regexp.Compile(right.(string)) - if err != nil { - return false, err + // for non-cached regex every time create a new matcher + if !re.cached { + matcher, disposer, err = regex.New(regex.Default(), right.(string)) + } else { + if re.pool == nil { + re.pool = &sync.Pool{ + New: func() interface{} { + r, _, e := regex.New(regex.Default(), right.(string)) + if e != nil { + err = e + return nil + } + return r + }, + } + } + if obj := re.pool.Get(); obj != nil { + matcher = obj.(regex.Matcher) + } } - - return reg.MatchString(left.(string)), nil -} - -// TransformUp implements the Expression interface. 
-func (re *Regexp) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := re.Left().TransformUp(f) - if err != nil { + if matcher == nil { return nil, err } - right, err := re.Right().TransformUp(f) - if err != nil { - return nil, err + ok := matcher.Match(left.(string)) + + if !re.cached { + disposer.Dispose() + } else if re.pool != nil { + re.pool.Put(matcher) } + return ok, nil +} - return f(NewRegexp(left, right)) +// WithChildren implements the Expression interface. +func (re *Regexp) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(re, len(children), 2) + } + return NewRegexp(children[0], children[1]), nil } func (re *Regexp) String() string { @@ -268,9 +295,6 @@ func NewGreaterThan(left sql.Expression, right sql.Expression) *GreaterThan { // Eval implements the Expression interface. func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.GreaterThan") - defer span.Finish() - result, err := gt.Compare(ctx, row) if err != nil { if ErrNilOperand.Is(err) { @@ -283,19 +307,12 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -// TransformUp implements the Expression interface. -func (gt *GreaterThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gt.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. 
+func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gt, len(children), 2) } - - right, err := gt.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewGreaterThan(left, right)) + return NewGreaterThan(children[0], children[1]), nil } func (gt *GreaterThan) String() string { @@ -314,9 +331,6 @@ func NewLessThan(left sql.Expression, right sql.Expression) *LessThan { // Eval implements the expression interface. func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.LessThan") - defer span.Finish() - result, err := lt.Compare(ctx, row) if err != nil { if ErrNilOperand.Is(err) { @@ -329,19 +343,12 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return result == -1, nil } -// TransformUp implements the Expression interface. -func (lt *LessThan) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lt.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := lt.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (lt *LessThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lt, len(children), 2) } - - return f(NewLessThan(left, right)) + return NewLessThan(children[0], children[1]), nil } func (lt *LessThan) String() string { @@ -361,9 +368,6 @@ func NewGreaterThanOrEqual(left sql.Expression, right sql.Expression) *GreaterTh // Eval implements the Expression interface. 
func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.GreaterThanOrEqual") - defer span.Finish() - result, err := gte.Compare(ctx, row) if err != nil { if ErrNilOperand.Is(err) { @@ -376,19 +380,12 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, return result > -1, nil } -// TransformUp implements the Expression interface. -func (gte *GreaterThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := gte.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (gte *GreaterThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(gte, len(children), 2) } - - right, err := gte.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewGreaterThanOrEqual(left, right)) + return NewGreaterThanOrEqual(children[0], children[1]), nil } func (gte *GreaterThanOrEqual) String() string { @@ -408,9 +405,6 @@ func NewLessThanOrEqual(left sql.Expression, right sql.Expression) *LessThanOrEq // Eval implements the Expression interface. func (lte *LessThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.LessThanOrEqual") - defer span.Finish() - result, err := lte.Compare(ctx, row) if err != nil { if ErrNilOperand.Is(err) { @@ -423,19 +417,12 @@ func (lte *LessThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, er return result < 1, nil } -// TransformUp implements the Expression interface. -func (lte *LessThanOrEqual) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := lte.Left().TransformUp(f) - if err != nil { - return nil, err - } - - right, err := lte.Right().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. 
+func (lte *LessThanOrEqual) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(lte, len(children), 2) } - - return f(NewLessThanOrEqual(left, right)) + return NewLessThanOrEqual(children[0], children[1]), nil } func (lte *LessThanOrEqual) String() string { @@ -463,9 +450,6 @@ func NewIn(left sql.Expression, right sql.Expression) *In { // Eval implements the Expression interface. func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.In") - defer span.Finish() - typ := in.Left().Type() leftElems := sql.NumColumns(typ) left, err := in.Left().Eval(ctx, row) @@ -482,7 +466,6 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -512,25 +495,46 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return false, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return true, nil + } + } + return false, nil default: return nil, ErrUnsupportedInOperand.New(right) } } -// TransformUp implements the Expression interface. -func (in *In) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. 
+func (in *In) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewIn(left, right)) + return NewIn(children[0], children[1]), nil } func (in *In) String() string { @@ -554,9 +558,6 @@ func NewNotIn(left sql.Expression, right sql.Expression) *NotIn { // Eval implements the Expression interface. func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.NotIn") - defer span.Finish() - typ := in.Left().Type() leftElems := sql.NumColumns(typ) left, err := in.Left().Eval(ctx, row) @@ -573,7 +574,6 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // TODO: support subqueries switch right := in.Right().(type) { case Tuple: for _, el := range right { @@ -603,25 +603,46 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + return true, nil + case *Subquery: + if leftElems > 1 { + return nil, ErrInvalidOperandColumns.New(leftElems, 1) + } + + typ := right.Type() + values, err := right.EvalMultiple(ctx) + if err != nil { + return nil, err + } + + for _, val := range values { + val, err = typ.Convert(val) + if err != nil { + return nil, err + } + + cmp, err := typ.Compare(left, val) + if err != nil { + return nil, err + } + + if cmp == 0 { + return false, nil + } + } + return true, nil default: return nil, ErrUnsupportedInOperand.New(right) } } -// TransformUp implements the Expression interface. -func (in *NotIn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := in.Left().TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. 
+func (in *NotIn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) } - - right, err := in.Right().TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewNotIn(left, right)) + return NewNotIn(children[0], children[1]), nil } func (in *NotIn) String() string { diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index aee436446..802c3676a 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -1,10 +1,14 @@ -package expression +package expression_test import ( "testing" + "github.com/src-d/go-mysql-server/internal/regex" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" "github.com/stretchr/testify/require" ) @@ -94,11 +98,11 @@ var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ func TestEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewEquals(get0, get1) + eq := expression.NewEquals(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -121,11 +125,11 @@ func TestEquals(t *testing.T) { func TestLessThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) 
+ get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewLessThan(get0, get1) + eq := expression.NewLessThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -148,11 +152,11 @@ func TestLessThan(t *testing.T) { func TestGreaterThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewGreaterThan(get0, get1) + eq := expression.NewGreaterThan(get0, get1) require.NotNil(eq) require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -173,17 +177,27 @@ func TestGreaterThan(t *testing.T) { } func TestRegexp(t *testing.T) { + for _, engine := range regex.Engines() { + regex.SetDefault(engine) + t.Run(engine, testRegexpCases) + } +} + +func testRegexpCases(t *testing.T) { + t.Helper() require := require.New(t) + for resultType, cmpCase := range likeComparisonCases { - get0 := NewGetField(0, resultType, "col1", true) + get0 := expression.NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := NewGetField(1, resultType, "col2", true) + get1 := expression.NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := NewRegexp(get0, get1) - require.NotNil(eq) - require.Equal(sql.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { for _, pair := range cases { + eq := expression.NewRegexp(get0, get1) + require.NotNil(eq) + require.Equal(sql.Boolean, eq.Type()) + row := sql.NewRow(pair[0], pair[1]) require.NotNil(row) cmp := eval(t, eq, row) @@ -199,6 +213,19 @@ func TestRegexp(t *testing.T) { } } +func TestInvalidRegexp(t *testing.T) { + t.Helper() + require := require.New(t) + + col1 := expression.NewGetField(0, 
sql.Text, "col1", true) + invalid := expression.NewLiteral("*col1", sql.Text) + r := expression.NewRegexp(col1, invalid) + row := sql.NewRow("col1") + + _, err := r.Eval(sql.NewEmptyContext(), row) + require.Error(err) +} + func TestIn(t *testing.T) { testCases := []struct { name string @@ -210,10 +237,10 @@ func TestIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -221,32 +248,32 @@ func TestIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), true, @@ -254,10 +281,10 @@ func TestIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", 
false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), false, @@ -269,7 +296,96 @@ func TestIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func TestInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + true, + nil, + }, + { + "left is not in right", + 
expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + false, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -292,10 +408,10 @@ func TestNotIn(t *testing.T) { }{ { "left is nil", - NewLiteral(nil, sql.Null), - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(nil, sql.Null), + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, @@ -303,32 +419,32 @@ func TestNotIn(t *testing.T) { }, { "left and right don't have the same cols", - NewLiteral(1, sql.Int64), - NewTuple( - NewTuple( - NewLiteral(int64(1), sql.Int64), - NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewTuple( + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), ), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), ), nil, nil, - ErrInvalidOperandColumns, + expression.ErrInvalidOperandColumns, }, { "right is an unsupported operand", - NewLiteral(1, sql.Int64), - NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(1, sql.Int64), + expression.NewLiteral(int64(2), sql.Int64), nil, nil, - ErrUnsupportedInOperand, + expression.ErrUnsupportedInOperand, }, { "left is in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(0, sql.Int64, "foo", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1)), 
false, @@ -336,10 +452,10 @@ func TestNotIn(t *testing.T) { }, { "left is not in right", - NewGetField(0, sql.Int64, "foo", false), - NewTuple( - NewGetField(1, sql.Int64, "bar", false), - NewLiteral(int64(2), sql.Int64), + expression.NewGetField(0, sql.Int64, "foo", false), + expression.NewTuple( + expression.NewGetField(1, sql.Int64, "bar", false), + expression.NewLiteral(int64(2), sql.Int64), ), sql.NewRow(int64(1), int64(3)), true, @@ -351,7 +467,8 @@ func TestNotIn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - result, err := NewNotIn(tt.left, tt.right).Eval(sql.NewEmptyContext(), tt.row) + result, err := expression.NewNotIn(tt.left, tt.right). + Eval(sql.NewEmptyContext(), tt.row) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) @@ -362,3 +479,98 @@ func TestNotIn(t *testing.T) { }) } } + +func TestNotInSubquery(t *testing.T) { + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(t, table.Insert(ctx, sql.Row{"one"})) + require.NoError(t, table.Insert(ctx, sql.Row{"two"})) + require.NoError(t, table.Insert(ctx, sql.Row{"three"})) + + project := func(expr sql.Expression) sql.Node { + return plan.NewProject([]sql.Expression{ + expr, + }, plan.NewResolvedTable(table)) + } + + testCases := []struct { + name string + left sql.Expression + right sql.Node + row sql.Row + result interface{} + err *errors.Kind + }{ + { + "left is nil", + expression.NewLiteral(nil, sql.Null), + project( + expression.NewLiteral(int64(1), sql.Int64), + ), + nil, + nil, + nil, + }, + { + "left and right don't have the same cols", + expression.NewTuple( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + ), + project( + expression.NewLiteral(int64(2), sql.Int64), + ), + nil, + nil, + expression.ErrInvalidOperandColumns, + }, + { + "left is in right", + expression.NewGetField(0, sql.Text, "foo", false), + 
project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("two"), + false, + nil, + }, + { + "left is not in right", + expression.NewGetField(0, sql.Text, "foo", false), + project( + expression.NewGetField(0, sql.Text, "foo", false), + ), + sql.NewRow("four"), + true, + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := expression.NewNotIn( + tt.left, + expression.NewSubquery(tt.right), + ).Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} + +func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { + t.Helper() + v, err := e.Eval(sql.NewEmptyContext(), row) + require.NoError(t, err) + return v +} diff --git a/sql/expression/convert.go b/sql/expression/convert.go index cafcf4e1a..bcfe778e0 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -7,10 +7,9 @@ import ( "strings" "time" - opentracing "github.com/opentracing/opentracing-go" "github.com/spf13/cast" + "github.com/src-d/go-mysql-server/sql" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // ErrConvertExpression is returned when a conversion is not possible. @@ -52,6 +51,16 @@ func NewConvert(expr sql.Expression, castToType string) *Convert { } } +// IsNullable implements the Expression interface. +func (c *Convert) IsNullable() bool { + switch c.castToType { + case ConvertToDate, ConvertToDatetime: + return true + default: + return c.Child.IsNullable() + } +} + // Type implements the Expression interface. 
func (c *Convert) Type() sql.Type { switch c.castToType { @@ -59,8 +68,10 @@ func (c *Convert) Type() sql.Type { return sql.Blob case ConvertToChar, ConvertToNChar: return sql.Text - case ConvertToDate, ConvertToDatetime: + case ConvertToDate: return sql.Date + case ConvertToDatetime: + return sql.Timestamp case ConvertToDecimal: return sql.Float64 case ConvertToJSON: @@ -79,21 +90,16 @@ func (c *Convert) String() string { return fmt.Sprintf("convert(%v, %v)", c.Child, c.castToType) } -// TransformUp implements the Expression interface. -func (c *Convert) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Convert) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - - return f(NewConvert(child, c.castToType)) + return NewConvert(children[0], c.castToType), nil } // Eval implements the Expression interface. 
func (c *Convert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Convert", opentracing.Tag{Key: "type", Value: c.castToType}) - defer span.Finish() - val, err := c.Child.Eval(ctx, row) if err != nil { return nil, err @@ -147,7 +153,7 @@ func convertValue(val interface{}, castTo string) (interface{}, error) { } } - return d, nil + return sql.ValidateTime(d.(time.Time)), nil case ConvertToDecimal: d, err := cast.ToFloat64E(val) if err != nil { diff --git a/sql/expression/convert_test.go b/sql/expression/convert_test.go index 45f12d708..726c68f8e 100644 --- a/sql/expression/convert_test.go +++ b/sql/expression/convert_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestConvert(t *testing.T) { diff --git a/sql/expression/default.go b/sql/expression/default.go new file mode 100644 index 000000000..82c5cb9e7 --- /dev/null +++ b/sql/expression/default.go @@ -0,0 +1,62 @@ +package expression + +import ( + "github.com/src-d/go-mysql-server/sql" +) + +// DefaultColumn is an default expression of a column that is not yet resolved. +type DefaultColumn struct { + name string +} + +// NewDefaultColumn creates a new NewDefaultColumn expression. +func NewDefaultColumn(name string) *DefaultColumn { + return &DefaultColumn{name: name} +} + +// Children implements the sql.Expression interface. +// The function returns always nil +func (*DefaultColumn) Children() []sql.Expression { + return nil +} + +// Resolved implements the sql.Expression interface. +// The function returns always false +func (*DefaultColumn) Resolved() bool { + return false +} + +// IsNullable implements the sql.Expression interface. +// The function always panics! +func (*DefaultColumn) IsNullable() bool { + panic("default column is a placeholder node, but IsNullable was called") +} + +// Type implements the sql.Expression interface. 
+// The function always panics! +func (*DefaultColumn) Type() sql.Type { + panic("default column is a placeholder node, but Type was called") +} + +// Name implements the sql.Nameable interface. +func (c *DefaultColumn) Name() string { return c.name } + +// String implements the Stringer +// The function returns column's name (can be an empty string) +func (c *DefaultColumn) String() string { + return c.name +} + +// Eval implements the sql.Expression interface. +// The function always panics! +func (*DefaultColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { + panic("default column is a placeholder node, but Eval was called") +} + +// WithChildren implements the Expression interface. +func (c *DefaultColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil +} diff --git a/sql/expression/function/aggregation/avg.go b/sql/expression/function/aggregation/avg.go index a46814dc3..eae1c8c0e 100644 --- a/sql/expression/function/aggregation/avg.go +++ b/sql/expression/function/aggregation/avg.go @@ -2,10 +2,9 @@ package aggregation import ( "fmt" - "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Avg node to calculate the average from numeric column @@ -39,91 +38,81 @@ func (a *Avg) IsNullable() bool { // Eval implements AggregationExpression interface. 
(AggregationExpression[Expression]]) func (a *Avg) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { - span, ctx := ctx.Span("aggregation.Avg_Eval") - defer span.Finish() - - isNoNum := buffer[2].(bool) - if isNoNum { - return float64(0), nil + nulls := buffer[2].(bool) + if nulls { + return nil, nil } - noNullRows := buffer[1].(float64) - if noNullRows == 0 { - return nil, nil + sum := buffer[0].(float64) + rows := buffer[1].(int64) + + if rows == 0 { + return float64(0), nil } - avg := buffer[0] - span.LogKV("avg", avg) - return avg, nil + return sum / float64(rows), nil } -// TransformUp implements AggregationExpression interface. -func (a *Avg) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := a.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *Avg) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 1) } - return f(NewAvg(child)) + return NewAvg(children[0]), nil } // NewBuffer implements AggregationExpression interface. (AggregationExpression) func (a *Avg) NewBuffer() sql.Row { const ( - currentAvg = float64(0) - rowsCount = float64(0) - noNum = false + sum = float64(0) + rows = int64(0) + nulls = false ) - return sql.NewRow(currentAvg, rowsCount, noNum) + return sql.NewRow(sum, rows, nulls) } // Update implements AggregationExpression interface. 
(AggregationExpression) func (a *Avg) Update(ctx *sql.Context, buffer, row sql.Row) error { + // if there are nulls already skip all the remainiing rows + if buffer[2].(bool) { + return nil + } + v, err := a.Child.Eval(ctx, row) if err != nil { return err } - if reflect.TypeOf(v) == nil { + if v == nil { + buffer[2] = true return nil } - var num float64 - switch n := row[0].(type) { - case int, int16, int32, int64: - num = float64(reflect.ValueOf(n).Int()) - case uint, uint8, uint16, uint32, uint64: - num = float64(reflect.ValueOf(n).Uint()) - case float32, float64: - num = float64(reflect.ValueOf(n).Float()) - default: - buffer[2] = true - return nil + v, err = sql.Float64.Convert(v) + if err != nil { + v = float64(0) } - prevAvg := buffer[0].(float64) - numRows := buffer[1].(float64) - nextAvg := (prevAvg*numRows + num) / (numRows + 1) - buffer[0] = nextAvg - buffer[1] = numRows + 1 + buffer[0] = buffer[0].(float64) + v.(float64) + buffer[1] = buffer[1].(int64) + 1 return nil - } // Merge implements AggregationExpression interface. 
(AggregationExpression) func (a *Avg) Merge(ctx *sql.Context, buffer, partial sql.Row) error { - bufferAvg := buffer[0].(float64) - bufferRows := buffer[1].(float64) - - partialAvg := partial[0].(float64) - partialRows := partial[1].(float64) + bsum := buffer[0].(float64) + brows := buffer[1].(int64) + bnulls := buffer[2].(bool) - totalRows := bufferRows + partialRows - nextAvg := ((bufferAvg * bufferRows) + (partialAvg * partialRows)) / totalRows + psum := partial[0].(float64) + prows := partial[1].(int64) + pnulls := buffer[2].(bool) - buffer[0] = nextAvg - buffer[1] = totalRows + buffer[0] = bsum + psum + buffer[1] = brows + prows + buffer[2] = bnulls || pnulls return nil } diff --git a/sql/expression/function/aggregation/avg_test.go b/sql/expression/function/aggregation/avg_test.go index b1fa2db3e..b63114185 100644 --- a/sql/expression/function/aggregation/avg_test.go +++ b/sql/expression/function/aggregation/avg_test.go @@ -3,9 +3,9 @@ package aggregation import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestAvg_String(t *testing.T) { @@ -15,13 +15,24 @@ func TestAvg_String(t *testing.T) { require.Equal("AVG(col1)", avg.String()) } +func TestAvg_Float64(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + avg := NewAvg(expression.NewGetField(0, sql.Float64, "col1", true)) + buffer := avg.NewBuffer() + avg.Update(ctx, buffer, sql.NewRow(float64(23.2220000))) + + require.Equal(float64(23.222), eval(t, avg, buffer)) +} + func TestAvg_Eval_INT32(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() avgNode := NewAvg(expression.NewGetField(0, sql.Int32, "col1", true)) buffer := avgNode.NewBuffer() - require.Zero(avgNode.Eval(ctx, buffer)) + require.Equal(float64(0), eval(t, avgNode, buffer)) avgNode.Update(ctx, buffer, 
sql.NewRow(int32(1))) require.Equal(float64(1), eval(t, avgNode, buffer)) @@ -36,7 +47,7 @@ func TestAvg_Eval_UINT64(t *testing.T) { avgNode := NewAvg(expression.NewGetField(0, sql.Uint64, "col1", true)) buffer := avgNode.NewBuffer() - require.Zero(avgNode.Eval(ctx, buffer)) + require.Equal(float64(0), eval(t, avgNode, buffer)) err := avgNode.Update(ctx, buffer, sql.NewRow(uint64(1))) require.NoError(err) @@ -47,17 +58,21 @@ func TestAvg_Eval_UINT64(t *testing.T) { require.Equal(float64(1.5), eval(t, avgNode, buffer)) } -func TestAvg_Eval_NoNum(t *testing.T) { +func TestAvg_Eval_String(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() avgNode := NewAvg(expression.NewGetField(0, sql.Text, "col1", true)) buffer := avgNode.NewBuffer() - require.Zero(avgNode.Eval(ctx, buffer)) + require.Equal(float64(0), eval(t, avgNode, buffer)) err := avgNode.Update(ctx, buffer, sql.NewRow("foo")) require.NoError(err) require.Equal(float64(0), eval(t, avgNode, buffer)) + + err = avgNode.Update(ctx, buffer, sql.NewRow("2")) + require.NoError(err) + require.Equal(float64(1), eval(t, avgNode, buffer)) } func TestAvg_Merge(t *testing.T) { diff --git a/sql/expression/function/aggregation/common_test.go b/sql/expression/function/aggregation/common_test.go index 5e22bc3b8..759a81601 100644 --- a/sql/expression/function/aggregation/common_test.go +++ b/sql/expression/function/aggregation/common_test.go @@ -3,8 +3,8 @@ package aggregation import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { @@ -15,3 +15,17 @@ func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { require.NoError(t, err) return v } + +func aggregate(t *testing.T, agg sql.Aggregation, rows ...sql.Row) interface{} { + t.Helper() + + ctx := sql.NewEmptyContext() + buf := agg.NewBuffer() + for _, row := range rows { + require.NoError(t, 
agg.Update(ctx, buf, row)) + } + + v, err := agg.Eval(ctx, buf) + require.NoError(t, err) + return v +} diff --git a/sql/expression/function/aggregation/count.go b/sql/expression/function/aggregation/count.go index 8e5964de4..0c6a8294b 100644 --- a/sql/expression/function/aggregation/count.go +++ b/sql/expression/function/aggregation/count.go @@ -3,8 +3,9 @@ package aggregation import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/mitchellh/hashstructure" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Count node to count how many rows are in the result set. @@ -19,12 +20,12 @@ func NewCount(e sql.Expression) *Count { // NewBuffer creates a new buffer for the aggregation. func (c *Count) NewBuffer() sql.Row { - return sql.NewRow(int32(0)) + return sql.NewRow(int64(0)) } // Type returns the type of the result. func (c *Count) Type() sql.Type { - return sql.Int32 + return sql.Int64 } // IsNullable returns whether the return value can be null. @@ -45,13 +46,12 @@ func (c *Count) String() string { return fmt.Sprintf("COUNT(%s)", c.Child) } -// TransformUp implements the Expression interface. -func (c *Count) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := c.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (c *Count) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } - return f(NewCount(child)) + return NewCount(children[0]), nil } // Update implements the Aggregation interface. 
@@ -71,7 +71,7 @@ func (c *Count) Update(ctx *sql.Context, buffer, row sql.Row) error { } if inc { - buffer[0] = buffer[0].(int32) + int32(1) + buffer[0] = buffer[0].(int64) + int64(1) } return nil @@ -79,16 +79,102 @@ func (c *Count) Update(ctx *sql.Context, buffer, row sql.Row) error { // Merge implements the Aggregation interface. func (c *Count) Merge(ctx *sql.Context, buffer, partial sql.Row) error { - buffer[0] = buffer[0].(int32) + partial[0].(int32) + buffer[0] = buffer[0].(int64) + partial[0].(int64) return nil } // Eval implements the Aggregation interface. func (c *Count) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { - span, ctx := ctx.Span("aggregation.Count_Eval") count := buffer[0] - span.LogKV("count", count) - span.Finish() - return count, nil } + +// CountDistinct node to count how many rows are in the result set. +type CountDistinct struct { + expression.UnaryExpression +} + +// NewCountDistinct creates a new CountDistinct node. +func NewCountDistinct(e sql.Expression) *CountDistinct { + return &CountDistinct{expression.UnaryExpression{Child: e}} +} + +// NewBuffer creates a new buffer for the aggregation. +func (c *CountDistinct) NewBuffer() sql.Row { + return sql.NewRow(make(map[uint64]struct{})) +} + +// Type returns the type of the result. +func (c *CountDistinct) Type() sql.Type { + return sql.Int64 +} + +// IsNullable returns whether the return value can be null. +func (c *CountDistinct) IsNullable() bool { + return false +} + +// Resolved implements the Expression interface. +func (c *CountDistinct) Resolved() bool { + if _, ok := c.Child.(*expression.Star); ok { + return true + } + + return c.Child.Resolved() +} + +func (c *CountDistinct) String() string { + return fmt.Sprintf("COUNT(DISTINCT %s)", c.Child) +} + +// WithChildren implements the Expression interface. 
+func (c *CountDistinct) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) + } + return NewCountDistinct(children[0]), nil +} + +// Update implements the Aggregation interface. +func (c *CountDistinct) Update(ctx *sql.Context, buffer, row sql.Row) error { + seen := buffer[0].(map[uint64]struct{}) + var value interface{} + if _, ok := c.Child.(*expression.Star); ok { + value = row + } else { + v, err := c.Child.Eval(ctx, row) + if v == nil { + return nil + } + + if err != nil { + return err + } + + value = v + } + + hash, err := hashstructure.Hash(value, nil) + if err != nil { + return fmt.Errorf("count distinct unable to hash value: %s", err) + } + + seen[hash] = struct{}{} + + return nil +} + +// Merge implements the Aggregation interface. +func (c *CountDistinct) Merge(ctx *sql.Context, buffer, partial sql.Row) error { + seen := buffer[0].(map[uint64]struct{}) + for k := range partial[0].(map[uint64]struct{}) { + seen[k] = struct{}{} + } + return nil +} + +// Eval implements the Aggregation interface. 
+func (c *CountDistinct) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { + seen := buffer[0].(map[uint64]struct{}) + return int64(len(seen)), nil +} diff --git a/sql/expression/function/aggregation/count_test.go b/sql/expression/function/aggregation/count_test.go index 1459b2fcc..ea27b0504 100644 --- a/sql/expression/function/aggregation/count_test.go +++ b/sql/expression/function/aggregation/count_test.go @@ -3,73 +3,123 @@ package aggregation import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) -func TestCount_String(t *testing.T) { - require := require.New(t) - - c := NewCount(expression.NewLiteral("foo", sql.Text)) - require.Equal(`COUNT("foo")`, c.String()) -} - -func TestCount_Eval_1(t *testing.T) { +func TestCountEval1(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() c := NewCount(expression.NewLiteral(1, sql.Int32)) b := c.NewBuffer() - require.Equal(int32(0), eval(t, c, b)) + require.Equal(int64(0), eval(t, c, b)) require.NoError(c.Update(ctx, b, nil)) require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) require.NoError(c.Update(ctx, b, sql.NewRow(1))) require.NoError(c.Update(ctx, b, sql.NewRow(nil))) require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3))) - require.Equal(int32(5), eval(t, c, b)) + require.Equal(int64(5), eval(t, c, b)) b2 := c.NewBuffer() require.NoError(c.Update(ctx, b2, nil)) require.NoError(c.Update(ctx, b2, sql.NewRow("foo"))) require.NoError(c.Merge(ctx, b, b2)) - require.Equal(int32(7), eval(t, c, b)) + require.Equal(int64(7), eval(t, c, b)) } -func TestCount_Eval_Star(t *testing.T) { +func TestCountEvalStar(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() c := NewCount(expression.NewStar()) b := c.NewBuffer() - require.Equal(int32(0), eval(t, c, b)) + require.Equal(int64(0), eval(t, 
c, b)) - c.Update(ctx, b, nil) - c.Update(ctx, b, sql.NewRow("foo")) - c.Update(ctx, b, sql.NewRow(1)) - c.Update(ctx, b, sql.NewRow(nil)) - c.Update(ctx, b, sql.NewRow(1, 2, 3)) - require.Equal(int32(5), eval(t, c, b)) + require.NoError(c.Update(ctx, b, nil)) + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.NoError(c.Update(ctx, b, sql.NewRow(1))) + require.NoError(c.Update(ctx, b, sql.NewRow(nil))) + require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3))) + require.Equal(int64(5), eval(t, c, b)) b2 := c.NewBuffer() - c.Update(ctx, b2, sql.NewRow()) - c.Update(ctx, b2, sql.NewRow("foo")) - c.Merge(ctx, b, b2) - require.Equal(int32(7), eval(t, c, b)) + require.NoError(c.Update(ctx, b2, sql.NewRow())) + require.NoError(c.Update(ctx, b2, sql.NewRow("foo"))) + require.NoError(c.Merge(ctx, b, b2)) + require.Equal(int64(7), eval(t, c, b)) } -func TestCount_Eval_String(t *testing.T) { +func TestCountEvalString(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() c := NewCount(expression.NewGetField(0, sql.Text, "", true)) b := c.NewBuffer() - require.Equal(int32(0), eval(t, c, b)) + require.Equal(int64(0), eval(t, c, b)) - c.Update(ctx, b, sql.NewRow("foo")) - require.Equal(int32(1), eval(t, c, b)) + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.Equal(int64(1), eval(t, c, b)) + + require.NoError(c.Update(ctx, b, sql.NewRow(nil))) + require.Equal(int64(1), eval(t, c, b)) +} + +func TestCountDistinctEval1(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + c := NewCountDistinct(expression.NewLiteral(1, sql.Int32)) + b := c.NewBuffer() + require.Equal(int64(0), eval(t, c, b)) + + require.NoError(c.Update(ctx, b, nil)) + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.NoError(c.Update(ctx, b, sql.NewRow(1))) + require.NoError(c.Update(ctx, b, sql.NewRow(nil))) + require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3))) + require.Equal(int64(1), eval(t, c, b)) +} - c.Update(ctx, b, 
sql.NewRow(nil)) - require.Equal(int32(1), eval(t, c, b)) +func TestCountDistinctEvalStar(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + c := NewCountDistinct(expression.NewStar()) + b := c.NewBuffer() + require.Equal(int64(0), eval(t, c, b)) + + require.NoError(c.Update(ctx, b, nil)) + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.NoError(c.Update(ctx, b, sql.NewRow(1))) + require.NoError(c.Update(ctx, b, sql.NewRow(nil))) + require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3))) + require.Equal(int64(5), eval(t, c, b)) + + b2 := c.NewBuffer() + require.NoError(c.Update(ctx, b2, sql.NewRow(1))) + require.NoError(c.Update(ctx, b2, sql.NewRow("foo"))) + require.NoError(c.Update(ctx, b2, sql.NewRow(5))) + require.NoError(c.Merge(ctx, b, b2)) + + require.Equal(int64(6), eval(t, c, b)) +} + +func TestCountDistinctEvalString(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + c := NewCountDistinct(expression.NewGetField(0, sql.Text, "", true)) + b := c.NewBuffer() + require.Equal(int64(0), eval(t, c, b)) + + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.Equal(int64(1), eval(t, c, b)) + + require.NoError(c.Update(ctx, b, sql.NewRow(nil))) + require.NoError(c.Update(ctx, b, sql.NewRow("foo"))) + require.NoError(c.Update(ctx, b, sql.NewRow("bar"))) + require.Equal(int64(2), eval(t, c, b)) } diff --git a/sql/expression/function/aggregation/first.go b/sql/expression/function/aggregation/first.go new file mode 100644 index 000000000..20ef8af6c --- /dev/null +++ b/sql/expression/function/aggregation/first.go @@ -0,0 +1,71 @@ +package aggregation + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// First agregation returns the first of all values in the selected column. +// It implements the Aggregation interface. +type First struct { + expression.UnaryExpression +} + +// NewFirst returns a new First node. 
+func NewFirst(e sql.Expression) *First { + return &First{expression.UnaryExpression{Child: e}} +} + +// Type returns the resultant type of the aggregation. +func (f *First) Type() sql.Type { + return f.Child.Type() +} + +func (f *First) String() string { + return fmt.Sprintf("FIRST(%s)", f.Child) +} + +// WithChildren implements the sql.Expression interface. +func (f *First) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) + } + return NewFirst(children[0]), nil +} + +// NewBuffer creates a new buffer to compute the result. +func (f *First) NewBuffer() sql.Row { + return sql.NewRow(nil) +} + +// Update implements the Aggregation interface. +func (f *First) Update(ctx *sql.Context, buffer, row sql.Row) error { + if buffer[0] != nil { + return nil + } + + v, err := f.Child.Eval(ctx, row) + if err != nil { + return err + } + + if v == nil { + return nil + } + + buffer[0] = v + + return nil +} + +// Merge implements the Aggregation interface. +func (f *First) Merge(ctx *sql.Context, buffer, partial sql.Row) error { + return nil +} + +// Eval implements the Aggregation interface. 
+func (f *First) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { + return buffer[0], nil +} diff --git a/sql/expression/function/aggregation/first_test.go b/sql/expression/function/aggregation/first_test.go new file mode 100644 index 000000000..0ec3bdba6 --- /dev/null +++ b/sql/expression/function/aggregation/first_test.go @@ -0,0 +1,29 @@ +package aggregation + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestFirst(t *testing.T) { + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + {"no rows", nil, nil}, + {"one row", []sql.Row{{"first"}}, "first"}, + {"three rows", []sql.Row{{"first"}, {"second"}, {"last"}}, "first"}, + } + + agg := NewFirst(expression.NewGetField(0, sql.Text, "", false)) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + result := aggregate(t, agg, tt.rows...) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/expression/function/aggregation/last.go b/sql/expression/function/aggregation/last.go new file mode 100644 index 000000000..55457a5e5 --- /dev/null +++ b/sql/expression/function/aggregation/last.go @@ -0,0 +1,68 @@ +package aggregation + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Last agregation returns the last of all values in the selected column. +// It implements the Aggregation interface. +type Last struct { + expression.UnaryExpression +} + +// NewLast returns a new Last node. +func NewLast(e sql.Expression) *Last { + return &Last{expression.UnaryExpression{Child: e}} +} + +// Type returns the resultant type of the aggregation. +func (l *Last) Type() sql.Type { + return l.Child.Type() +} + +func (l *Last) String() string { + return fmt.Sprintf("LAST(%s)", l.Child) +} + +// WithChildren implements the sql.Expression interface. 
+func (l *Last) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) + } + return NewLast(children[0]), nil +} + +// NewBuffer creates a new buffer to compute the result. +func (l *Last) NewBuffer() sql.Row { + return sql.NewRow(nil) +} + +// Update implements the Aggregation interface. +func (l *Last) Update(ctx *sql.Context, buffer, row sql.Row) error { + v, err := l.Child.Eval(ctx, row) + if err != nil { + return err + } + + if v == nil { + return nil + } + + buffer[0] = v + + return nil +} + +// Merge implements the Aggregation interface. +func (l *Last) Merge(ctx *sql.Context, buffer, partial sql.Row) error { + buffer[0] = partial[0] + return nil +} + +// Eval implements the Aggregation interface. +func (l *Last) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { + return buffer[0], nil +} diff --git a/sql/expression/function/aggregation/last_test.go b/sql/expression/function/aggregation/last_test.go new file mode 100644 index 000000000..4271eebf4 --- /dev/null +++ b/sql/expression/function/aggregation/last_test.go @@ -0,0 +1,29 @@ +package aggregation + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestLast(t *testing.T) { + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + {"no rows", nil, nil}, + {"one row", []sql.Row{{"first"}}, "first"}, + {"three rows", []sql.Row{{"first"}, {"second"}, {"last"}}, "last"}, + } + + agg := NewLast(expression.NewGetField(0, sql.Text, "", false)) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + result := aggregate(t, agg, tt.rows...) 
+ require.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/expression/function/aggregation/max.go b/sql/expression/function/aggregation/max.go index 9f2cee986..e47211f2c 100644 --- a/sql/expression/function/aggregation/max.go +++ b/sql/expression/function/aggregation/max.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Max agregation returns the greatest value of the selected column. @@ -38,13 +38,12 @@ func (m *Max) IsNullable() bool { return false } -// TransformUp implements the Transformable interface. -func (m *Max) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Max) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMax(child)) + return NewMax(children[0]), nil } // NewBuffer creates a new buffer to compute the result. @@ -85,10 +84,6 @@ func (m *Max) Merge(ctx *sql.Context, buffer, partial sql.Row) error { // Eval implements the Aggregation interface. 
func (m *Max) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { - span, ctx := ctx.Span("aggregation.Max_Eval") max := buffer[0] - span.LogKV("max", max) - span.Finish() - return max, nil } diff --git a/sql/expression/function/aggregation/max_test.go b/sql/expression/function/aggregation/max_test.go index 8f467bc2d..f8da09942 100644 --- a/sql/expression/function/aggregation/max_test.go +++ b/sql/expression/function/aggregation/max_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestMax_String(t *testing.T) { diff --git a/sql/expression/function/aggregation/min.go b/sql/expression/function/aggregation/min.go index dc673a504..8e73e0812 100644 --- a/sql/expression/function/aggregation/min.go +++ b/sql/expression/function/aggregation/min.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Min aggregation returns the smallest value of the selected column. @@ -38,13 +38,12 @@ func (m *Min) IsNullable() bool { return true } -// TransformUp implements the Transformable interface. -func (m *Min) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Min) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - return f(NewMin(child)) + return NewMin(children[0]), nil } // NewBuffer creates a new buffer to compute the result. 
@@ -85,10 +84,6 @@ func (m *Min) Merge(ctx *sql.Context, buffer, partial sql.Row) error { // Eval implements the Aggregation interface func (m *Min) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { - span, ctx := ctx.Span("aggregation.Min_Eval") min := buffer[0] - span.LogKV("min", min) - span.Finish() - return min, nil } diff --git a/sql/expression/function/aggregation/min_test.go b/sql/expression/function/aggregation/min_test.go index c5090aa34..6039219d6 100644 --- a/sql/expression/function/aggregation/min_test.go +++ b/sql/expression/function/aggregation/min_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestMin_Name(t *testing.T) { diff --git a/sql/expression/function/aggregation/sum.go b/sql/expression/function/aggregation/sum.go new file mode 100644 index 000000000..09df362be --- /dev/null +++ b/sql/expression/function/aggregation/sum.go @@ -0,0 +1,78 @@ +package aggregation + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Sum agregation returns the sum of all values in the selected column. +// It implements the Aggregation interface. +type Sum struct { + expression.UnaryExpression +} + +// NewSum returns a new Sum node. +func NewSum(e sql.Expression) *Sum { + return &Sum{expression.UnaryExpression{Child: e}} +} + +// Type returns the resultant type of the aggregation. +func (m *Sum) Type() sql.Type { + return sql.Float64 +} + +func (m *Sum) String() string { + return fmt.Sprintf("SUM(%s)", m.Child) +} + +// WithChildren implements the Expression interface. 
+func (m *Sum) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) + } + return NewSum(children[0]), nil +} + +// NewBuffer creates a new buffer to compute the result. +func (m *Sum) NewBuffer() sql.Row { + return sql.NewRow(nil) +} + +// Update implements the Aggregation interface. +func (m *Sum) Update(ctx *sql.Context, buffer, row sql.Row) error { + v, err := m.Child.Eval(ctx, row) + if err != nil { + return err + } + + if v == nil { + return nil + } + + val, err := sql.Float64.Convert(v) + if err != nil { + val = float64(0) + } + + if buffer[0] == nil { + buffer[0] = float64(0) + } + + buffer[0] = buffer[0].(float64) + val.(float64) + + return nil +} + +// Merge implements the Aggregation interface. +func (m *Sum) Merge(ctx *sql.Context, buffer, partial sql.Row) error { + return m.Update(ctx, buffer, partial) +} + +// Eval implements the Aggregation interface. +func (m *Sum) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { + sum := buffer[0] + + return sum, nil +} diff --git a/sql/expression/function/aggregation/sum_test.go b/sql/expression/function/aggregation/sum_test.go new file mode 100644 index 000000000..37b29324e --- /dev/null +++ b/sql/expression/function/aggregation/sum_test.go @@ -0,0 +1,75 @@ +package aggregation + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestSum(t *testing.T) { + sum := NewSum(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + float64(10), + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + float64(10.5), + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float 
values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + float64(10.5), + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + float64(4), + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + float64(4), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + buf := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(sum.Update(sql.NewEmptyContext(), buf, row)) + } + + result, err := sum.Eval(sql.NewEmptyContext(), buf) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} diff --git a/sql/expression/function/arraylength.go b/sql/expression/function/arraylength.go index 457b33865..61a902cd9 100644 --- a/sql/expression/function/arraylength.go +++ b/sql/expression/function/arraylength.go @@ -2,10 +2,9 @@ package function import ( "fmt" - "reflect" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // ArrayLength returns the length of an array. @@ -27,22 +26,18 @@ func (f *ArrayLength) String() string { return fmt.Sprintf("array_length(%s)", f.Child) } -// TransformUp implements the Expression interface. -func (f *ArrayLength) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - child, err := f.Child.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *ArrayLength) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) } - return fn(NewArrayLength(child)) + return NewArrayLength(children[0]), nil } // Eval implements the Expression interface. 
func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.ArrayLength") - defer span.Finish() - - if !sql.IsArray(f.Child.Type()) { - return nil, sql.ErrInvalidType.New(f.Child.Type().Type().String()) + if t := f.Child.Type(); !sql.IsArray(t) && t != sql.JSON { + return nil, nil } child, err := f.Child.Eval(ctx, row) @@ -56,7 +51,7 @@ func (f *ArrayLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { array, ok := child.([]interface{}) if !ok { - return nil, sql.ErrInvalidType.New(reflect.TypeOf(child)) + return nil, nil } return int32(len(array)), nil diff --git a/sql/expression/function/arraylength_test.go b/sql/expression/function/arraylength_test.go index 034b6afa7..7e9437144 100644 --- a/sql/expression/function/arraylength_test.go +++ b/sql/expression/function/arraylength_test.go @@ -3,10 +3,10 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestArrayLength(t *testing.T) { @@ -19,7 +19,7 @@ func TestArrayLength(t *testing.T) { err *errors.Kind }{ {"array is nil", sql.NewRow(nil), nil, nil}, - {"array is not of right type", sql.NewRow(5), nil, sql.ErrInvalidType}, + {"array is not of right type", sql.NewRow(5), nil, nil}, {"array is ok", sql.NewRow([]interface{}{1, 2, 3}), int32(3), nil}, } @@ -40,7 +40,7 @@ func TestArrayLength(t *testing.T) { f = NewArrayLength(expression.NewGetField(0, sql.Tuple(sql.Int64, sql.Int64), "", false)) require := require.New(t) - _, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)}) - require.Error(err) - require.True(sql.ErrInvalidType.Is(err)) + v, err := f.Eval(sql.NewEmptyContext(), []interface{}{int64(1), int64(2)}) + require.NoError(err) + require.Nil(v) } diff --git 
a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go new file mode 100644 index 000000000..c056f82da --- /dev/null +++ b/sql/expression/function/ceil_round_floor.go @@ -0,0 +1,297 @@ +package function + +import ( + "fmt" + "math" + "reflect" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Ceil returns the smallest integer value not less than X. +type Ceil struct { + expression.UnaryExpression +} + +// NewCeil creates a new Ceil expression. +func NewCeil(num sql.Expression) sql.Expression { + return &Ceil{expression.UnaryExpression{Child: num}} +} + +// Type implements the Expression interface. +func (c *Ceil) Type() sql.Type { + childType := c.Child.Type() + if sql.IsNumber(childType) { + return childType + } + return sql.Int32 +} + +func (c *Ceil) String() string { + return fmt.Sprintf("CEIL(%s)", c.Child) +} + +// WithChildren implements the Expression interface. +func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) + } + return NewCeil(children[0]), nil +} + +// Eval implements the Expression interface. +func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := c.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + if !sql.IsNumber(c.Child.Type()) { + child, err = sql.Float64.Convert(child) + if err != nil { + return int32(0), nil + } + + return int32(math.Ceil(child.(float64))), nil + } + + if !sql.IsDecimal(c.Child.Type()) { + return child, err + } + + switch num := child.(type) { + case float64: + return math.Ceil(num), nil + case float32: + return float32(math.Ceil(float64(num))), nil + default: + return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) + } +} + +// Floor returns the largest integer value not greater than X. 
+type Floor struct { + expression.UnaryExpression +} + +// NewFloor returns a new Floor expression. +func NewFloor(num sql.Expression) sql.Expression { + return &Floor{expression.UnaryExpression{Child: num}} +} + +// Type implements the Expression interface. +func (f *Floor) Type() sql.Type { + childType := f.Child.Type() + if sql.IsNumber(childType) { + return childType + } + return sql.Int32 +} + +func (f *Floor) String() string { + return fmt.Sprintf("FLOOR(%s)", f.Child) +} + +// WithChildren implements the Expression interface. +func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) + } + return NewFloor(children[0]), nil +} + +// Eval implements the Expression interface. +func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := f.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + if !sql.IsNumber(f.Child.Type()) { + child, err = sql.Float64.Convert(child) + if err != nil { + return int32(0), nil + } + + return int32(math.Floor(child.(float64))), nil + } + + if !sql.IsDecimal(f.Child.Type()) { + return child, err + } + + switch num := child.(type) { + case float64: + return math.Floor(num), nil + case float32: + return float32(math.Floor(float64(num))), nil + default: + return nil, sql.ErrInvalidType.New(reflect.TypeOf(num)) + } +} + +// Round returns the number (x) with (d) requested decimal places. +// If d is negative, the number is returned with the (abs(d)) least significant +// digits of it's integer part set to 0. If d is not specified or nil/null +// it defaults to 0. +type Round struct { + expression.BinaryExpression +} + +// NewRound returns a new Round expression. 
+func NewRound(args ...sql.Expression) (sql.Expression, error) { + argLen := len(args) + if argLen == 0 || argLen > 2 { + return nil, sql.ErrInvalidArgumentNumber.New("ROUND", "1 or 2", argLen) + } + + var right sql.Expression + if len(args) == 2 { + right = args[1] + } + + return &Round{expression.BinaryExpression{Left: args[0], Right: right}}, nil +} + +// Children implements the Expression interface. +func (r *Round) Children() []sql.Expression { + if r.Right == nil { + return []sql.Expression{r.Left} + } + + return r.BinaryExpression.Children() +} + +// Eval implements the Expression interface. +func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + xVal, err := r.Left.Eval(ctx, row) + if err != nil { + return nil, err + } + + if xVal == nil { + return nil, nil + } + + dVal := float64(0) + + if r.Right != nil { + var dTemp interface{} + dTemp, err = r.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + + if dTemp != nil { + switch dNum := dTemp.(type) { + case float64: + dVal = float64(int64(dNum)) + case float32: + dVal = float64(int64(dNum)) + case int64: + dVal = float64(dNum) + case int32: + dVal = float64(dNum) + case int16: + dVal = float64(dNum) + case int8: + dVal = float64(dNum) + case uint64: + dVal = float64(dNum) + case uint32: + dVal = float64(dNum) + case uint16: + dVal = float64(dNum) + case uint8: + dVal = float64(dNum) + case int: + dVal = float64(dNum) + default: + dTemp, err = sql.Float64.Convert(dTemp) + if err == nil { + dVal = dTemp.(float64) + } + } + } + } + + if !sql.IsNumber(r.Left.Type()) { + xVal, err = sql.Float64.Convert(xVal) + if err != nil { + return int32(0), nil + } + + xNum := xVal.(float64) + return int32(math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + } + + switch xNum := xVal.(type) { + case float64: + return math.Round(xNum*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal), nil + case float32: + return float32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, 
dVal)), nil + case int64: + return int64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int32: + return int32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int16: + return int16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int8: + return int8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint64: + return uint64(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint32: + return uint32(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint16: + return uint16(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case uint8: + return uint8(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + case int: + return int(math.Round(float64(xNum)*math.Pow(10.0, dVal)) / math.Pow(10.0, dVal)), nil + default: + return nil, sql.ErrInvalidType.New(r.Left.Type().Type().String()) + } +} + +// IsNullable implements the Expression interface. +func (r *Round) IsNullable() bool { + return r.Left.IsNullable() +} + +func (r *Round) String() string { + if r.Right == nil { + return fmt.Sprintf("ROUND(%s, 0)", r.Left.String()) + } + + return fmt.Sprintf("ROUND(%s, %s)", r.Left.String(), r.Right.String()) +} + +// Resolved implements the Expression interface. +func (r *Round) Resolved() bool { + return r.Left.Resolved() && (r.Right == nil || r.Right.Resolved()) +} + +// Type implements the Expression interface. +func (r *Round) Type() sql.Type { + leftChildType := r.Left.Type() + if sql.IsNumber(leftChildType) { + return leftChildType + } + return sql.Int32 +} + +// WithChildren implements the Expression interface. +func (r *Round) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewRound(children...) 
+} diff --git a/sql/expression/function/ceil_round_floor_test.go b/sql/expression/function/ceil_round_floor_test.go new file mode 100644 index 000000000..4af2456ef --- /dev/null +++ b/sql/expression/function/ceil_round_floor_test.go @@ -0,0 +1,268 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +func TestCeil(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"float64 is nil", sql.Float64, sql.NewRow(nil), nil, nil}, + {"float64 is ok", sql.Float64, sql.NewRow(5.8), float64(6), nil}, + {"float32 is nil", sql.Float32, sql.NewRow(nil), nil, nil}, + {"float32 is ok", sql.Float32, sql.NewRow(float32(5.8)), float32(6), nil}, + {"int32 is nil", sql.Int32, sql.NewRow(nil), nil, nil}, + {"int32 is ok", sql.Int32, sql.NewRow(int32(6)), int32(6), nil}, + {"int64 is nil", sql.Int64, sql.NewRow(nil), nil, nil}, + {"int64 is ok", sql.Int64, sql.NewRow(int64(6)), int64(6), nil}, + {"blob is nil", sql.Blob, sql.NewRow(nil), nil, nil}, + {"blob is ok", sql.Blob, sql.NewRow([]byte{1, 2, 3}), int32(0), nil}, + {"string int is ok", sql.Text, sql.NewRow("1"), int32(1), nil}, + {"string float is ok", sql.Text, sql.NewRow("1.2"), int32(2), nil}, + } + + for _, tt := range testCases { + f := NewCeil(expression.NewGetField(0, tt.rowType, "", false)) + + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + exprs := f.Children() + require.True(len(exprs) > 0 && len(exprs) < 3) + require.NotNil(exprs[0]) + + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + + switch { + case sql.IsDecimal(tt.rowType): + require.True(sql.IsDecimal(f.Type())) + require.False(f.IsNullable()) + 
case sql.IsInteger(tt.rowType): + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + default: + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + } + }) + } +} + +func TestFloor(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"float64 is nil", sql.Float64, sql.NewRow(nil), nil, nil}, + {"float64 is ok", sql.Float64, sql.NewRow(5.8), float64(5), nil}, + {"float32 is nil", sql.Float32, sql.NewRow(nil), nil, nil}, + {"float32 is ok", sql.Float32, sql.NewRow(float32(5.8)), float32(5), nil}, + {"int32 is nil", sql.Int32, sql.NewRow(nil), nil, nil}, + {"int32 is ok", sql.Int32, sql.NewRow(int32(6)), int32(6), nil}, + {"int64 is nil", sql.Int64, sql.NewRow(nil), nil, nil}, + {"int64 is ok", sql.Int64, sql.NewRow(int64(6)), int64(6), nil}, + {"blob is nil", sql.Blob, sql.NewRow(nil), nil, nil}, + {"blob is ok", sql.Blob, sql.NewRow([]byte{1, 2, 3}), int32(0), nil}, + {"string int is ok", sql.Text, sql.NewRow("1"), int32(1), nil}, + {"string float is ok", sql.Text, sql.NewRow("1.2"), int32(1), nil}, + } + + for _, tt := range testCases { + f := NewFloor(expression.NewGetField(0, tt.rowType, "", false)) + + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + exprs := f.Children() + require.True(len(exprs) > 0 && len(exprs) < 3) + require.NotNil(exprs[0]) + + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + + switch { + case sql.IsDecimal(tt.rowType): + require.True(sql.IsDecimal(f.Type())) + require.False(f.IsNullable()) + case sql.IsInteger(tt.rowType): + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + default: + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + } + }) + } +} + +func TestRound(t *testing.T) { + testCases := []struct 
{ + name string + xType sql.Type + dType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"float64 is nil", sql.Float64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"float64 without d", sql.Float64, sql.Int32, sql.NewRow(5.8, nil), float64(6), nil}, + {"float64 with d", sql.Float64, sql.Int32, sql.NewRow(5.855, 2), float64(5.86), nil}, + {"float64 with negative d", sql.Float64, sql.Int32, sql.NewRow(52.855, -1), float64(50), nil}, + {"float64 with float d", sql.Float64, sql.Float64, sql.NewRow(5.855, float64(2.123)), float64(5.86), nil}, + {"float64 with float negative d", sql.Float64, sql.Float64, sql.NewRow(52.855, float64(-1)), float64(50), nil}, + {"float64 with blob d", sql.Float64, sql.Blob, sql.NewRow(5.855, []byte{1, 2, 3}), float64(6), nil}, + {"float32 is nil", sql.Float32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"float32 without d", sql.Float32, sql.Int32, sql.NewRow(float32(5.8), nil), float32(6), nil}, + {"float32 with d", sql.Float32, sql.Int32, sql.NewRow(float32(5.855), 2), float32(5.86), nil}, + {"float32 with negative d", sql.Float32, sql.Int32, sql.NewRow(float32(52.855), -1), float32(50), nil}, + {"float32 with float d", sql.Float32, sql.Float64, sql.NewRow(float32(5.855), float32(2.123)), float32(5.86), nil}, + {"float32 with float negative d", sql.Float32, sql.Float64, sql.NewRow(float32(52.855), float32(-1)), float32(50), nil}, + {"float32 with blob d", sql.Float32, sql.Blob, sql.NewRow(float32(5.855), []byte{1, 2, 3}), float32(6), nil}, + {"int64 is nil", sql.Int64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"int64 without d", sql.Int64, sql.Int32, sql.NewRow(int64(5), nil), int64(5), nil}, + {"int64 with d", sql.Int64, sql.Int32, sql.NewRow(int64(5), 2), int64(5), nil}, + {"int64 with negative d", sql.Int64, sql.Int32, sql.NewRow(int64(52), -1), int64(50), nil}, + {"int64 with float d", sql.Int64, sql.Float64, sql.NewRow(int64(5), float32(2.123)), int64(5), nil}, + {"int64 with float negative d", 
sql.Int64, sql.Float64, sql.NewRow(int64(52), float32(-1)), int64(50), nil}, + {"int32 with blob d", sql.Int32, sql.Blob, sql.NewRow(int32(5), []byte{1, 2, 3}), int32(5), nil}, + {"int32 is nil", sql.Int32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"int32 without d", sql.Int32, sql.Int32, sql.NewRow(int32(5), nil), int32(5), nil}, + {"int32 with d", sql.Int32, sql.Int32, sql.NewRow(int32(5), 2), int32(5), nil}, + {"int32 with negative d", sql.Int32, sql.Int32, sql.NewRow(int32(52), -1), int32(50), nil}, + {"int32 with float d", sql.Int32, sql.Float64, sql.NewRow(int32(5), float32(2.123)), int32(5), nil}, + {"int32 with float negative d", sql.Int32, sql.Float64, sql.NewRow(int32(52), float32(-1)), int32(50), nil}, + {"int32 with blob d", sql.Int32, sql.Blob, sql.NewRow(int32(5), []byte{1, 2, 3}), int32(5), nil}, + {"int16 is nil", sql.Int16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"int16 without d", sql.Int16, sql.Int16, sql.NewRow(int16(5), nil), int16(5), nil}, + {"int16 with d", sql.Int16, sql.Int16, sql.NewRow(int16(5), 2), int16(5), nil}, + {"int16 with negative d", sql.Int16, sql.Int16, sql.NewRow(int16(52), -1), int16(50), nil}, + {"int16 with float d", sql.Int16, sql.Float64, sql.NewRow(int16(5), float32(2.123)), int16(5), nil}, + {"int16 with float negative d", sql.Int16, sql.Float64, sql.NewRow(int16(52), float32(-1)), int16(50), nil}, + {"int16 with blob d", sql.Int16, sql.Blob, sql.NewRow(int16(5), []byte{1, 2, 3}), int16(5), nil}, + {"int8 is nil", sql.Int8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"int8 without d", sql.Int8, sql.Int8, sql.NewRow(int8(5), nil), int8(5), nil}, + {"int8 with d", sql.Int8, sql.Int8, sql.NewRow(int8(5), 2), int8(5), nil}, + {"int8 with negative d", sql.Int8, sql.Int8, sql.NewRow(int8(52), -1), int8(50), nil}, + {"int8 with float d", sql.Int8, sql.Float64, sql.NewRow(int8(5), float32(2.123)), int8(5), nil}, + {"int8 with float negative d", sql.Int8, sql.Float64, sql.NewRow(int8(52), float32(-1)), int8(50), 
nil}, + {"int8 with blob d", sql.Int8, sql.Blob, sql.NewRow(int8(5), []byte{1, 2, 3}), int8(5), nil}, + {"uint64 is nil", sql.Uint64, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint64 without d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), nil), uint64(5), nil}, + {"uint64 with d", sql.Uint64, sql.Int32, sql.NewRow(uint64(5), 2), uint64(5), nil}, + {"uint64 with negative d", sql.Uint64, sql.Int32, sql.NewRow(uint64(52), -1), uint64(50), nil}, + {"uint64 with float d", sql.Uint64, sql.Float64, sql.NewRow(uint64(5), float32(2.123)), uint64(5), nil}, + {"uint64 with float negative d", sql.Uint64, sql.Float64, sql.NewRow(uint64(52), float32(-1)), uint64(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint32 is nil", sql.Uint32, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"uint32 without d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), nil), uint32(5), nil}, + {"uint32 with d", sql.Uint32, sql.Int32, sql.NewRow(uint32(5), 2), uint32(5), nil}, + {"uint32 with negative d", sql.Uint32, sql.Int32, sql.NewRow(uint32(52), -1), uint32(50), nil}, + {"uint32 with float d", sql.Uint32, sql.Float64, sql.NewRow(uint32(5), float32(2.123)), uint32(5), nil}, + {"uint32 with float negative d", sql.Uint32, sql.Float64, sql.NewRow(uint32(52), float32(-1)), uint32(50), nil}, + {"uint32 with blob d", sql.Uint32, sql.Blob, sql.NewRow(uint32(5), []byte{1, 2, 3}), uint32(5), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint16 is nil", sql.Uint16, sql.Int16, sql.NewRow(nil, nil), nil, nil}, + {"uint16 without d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), nil), uint16(5), nil}, + {"uint16 with d", sql.Uint16, sql.Int16, sql.NewRow(uint16(5), 2), uint16(5), nil}, + {"uint16 with negative d", sql.Uint16, sql.Int16, sql.NewRow(uint16(52), -1), uint16(50), nil}, + {"uint16 with float d", sql.Uint16, sql.Float64, sql.NewRow(uint16(5), 
float32(2.123)), uint16(5), nil}, + {"uint16 with float negative d", sql.Uint16, sql.Float64, sql.NewRow(uint16(52), float32(-1)), uint16(50), nil}, + {"uint16 with blob d", sql.Uint16, sql.Blob, sql.NewRow(uint16(5), []byte{1, 2, 3}), uint16(5), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, + {"uint8 is nil", sql.Uint8, sql.Int8, sql.NewRow(nil, nil), nil, nil}, + {"uint8 without d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), nil), uint8(5), nil}, + {"uint8 with d", sql.Uint8, sql.Int8, sql.NewRow(uint8(5), 2), uint8(5), nil}, + {"uint8 with negative d", sql.Uint8, sql.Int8, sql.NewRow(uint8(52), -1), uint8(50), nil}, + {"uint8 with float d", sql.Uint8, sql.Float64, sql.NewRow(uint8(5), float32(2.123)), uint8(5), nil}, + {"uint8 with float negative d", sql.Uint8, sql.Float64, sql.NewRow(uint8(52), float32(-1)), uint8(50), nil}, + {"uint8 with blob d", sql.Uint8, sql.Blob, sql.NewRow(uint8(5), []byte{1, 2, 3}), uint8(5), nil}, + {"blob is nil", sql.Blob, sql.Int32, sql.NewRow(nil, nil), nil, nil}, + {"blob is ok", sql.Blob, sql.Int32, sql.NewRow([]byte{1, 2, 3}, nil), int32(0), nil}, + {"text int without d", sql.Text, sql.Int32, sql.NewRow("5", nil), int32(5), nil}, + {"text int with d", sql.Text, sql.Int32, sql.NewRow("5", 2), int32(5), nil}, + {"text int with negative d", sql.Text, sql.Int32, sql.NewRow("52", -1), int32(50), nil}, + {"text int with float d", sql.Text, sql.Float64, sql.NewRow("5", float32(2.123)), int32(5), nil}, + {"text int with float negative d", sql.Text, sql.Float64, sql.NewRow("52", float32(-1)), int32(50), nil}, + {"text float without d", sql.Text, sql.Int32, sql.NewRow("5.8", nil), int32(6), nil}, + {"text float with d", sql.Text, sql.Int32, sql.NewRow("5.855", 2), int32(5), nil}, + {"text float with negative d", sql.Text, sql.Int32, sql.NewRow("52.855", -1), int32(50), nil}, + {"text float with float d", sql.Text, sql.Float64, sql.NewRow("5.855", float64(2.123)), int32(5), nil}, + 
{"text float with float negative d", sql.Text, sql.Float64, sql.NewRow("52.855", float64(-1)), int32(50), nil}, + {"text float with blob d", sql.Text, sql.Blob, sql.NewRow("5.855", []byte{1, 2, 3}), int32(6), nil}, + } + + for _, tt := range testCases { + var args = make([]sql.Expression, 2) + args[0] = expression.NewGetField(0, tt.xType, "", false) + args[1] = expression.NewGetField(1, tt.dType, "", false) + f, err := NewRound(args...) + + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + require.Nil(err) + + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, result) + } + + switch { + case sql.IsDecimal(tt.xType): + require.True(sql.IsDecimal(f.Type())) + require.False(f.IsNullable()) + case sql.IsInteger(tt.xType): + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + default: + require.True(sql.IsInteger(f.Type())) + require.False(f.IsNullable()) + } + }) + } + + // Test on invalid type return 0 + var args = make([]sql.Expression, 2) + args[0] = expression.NewGetField(0, sql.Blob, "", false) + args[1] = expression.NewGetField(1, sql.Int32, "", false) + + f, err := NewRound(args...) + req := require.New(t) + req.Nil(err) + + exprs := f.Children() + req.True(len(exprs) > 0 && len(exprs) < 3) + req.NotNil(exprs[0]) + + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow([]byte{1, 2, 3}, 2)) + req.NoError(err) + req.Equal(int32(0), result) +} diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go new file mode 100644 index 000000000..07f7f64e6 --- /dev/null +++ b/sql/expression/function/coalesce.go @@ -0,0 +1,106 @@ +package function + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. 
+type Coalesce struct { + args []sql.Expression +} + +// NewCoalesce creates a new Coalesce sql.Expression. +func NewCoalesce(args ...sql.Expression) (sql.Expression, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0) + } + + return &Coalesce{args}, nil +} + +// Type implements the sql.Expression interface. +// The return type of Type() is the aggregated type of the argument types. +func (c *Coalesce) Type() sql.Type { + for _, arg := range c.args { + if arg == nil { + continue + } + t := arg.Type() + if t == nil { + continue + } + return t + } + + return nil +} + +// IsNullable implements the sql.Expression interface. +// Returns true if all arguments are nil, +// or if the first non-nil argument is nullable, otherwise false. +func (c *Coalesce) IsNullable() bool { + for _, arg := range c.args { + if arg == nil { + continue + } + return arg.IsNullable() + } + return true +} + +func (c *Coalesce) String() string { + var args = make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + return fmt.Sprintf("coalesce(%s)", strings.Join(args, ", ")) +} + +// WithChildren implements the Expression interface. +func (*Coalesce) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewCoalesce(children...) +} + +// Resolved implements the sql.Expression interface. +// The function checks that all non-nil arguments are resolved. +func (c *Coalesce) Resolved() bool { + for _, arg := range c.args { + if arg == nil { + continue + } + if !arg.Resolved() { + return false + } + } + return true +} + +// Children implements the sql.Expression interface. +func (c *Coalesce) Children() []sql.Expression { return c.args } + +// Eval implements the sql.Expression interface. +// The function evaluates the first non-nil argument. If the value is nil, +// then we keep going, otherwise we return the first non-nil value. 
+func (c *Coalesce) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + for _, arg := range c.args { + if arg == nil { + continue + } + + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + continue + } + + return val, nil + } + + return nil, nil +} diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go new file mode 100644 index 000000000..a1fd6a89d --- /dev/null +++ b/sql/expression/function/coalesce_test.go @@ -0,0 +1,64 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestEmptyCoalesce(t *testing.T) { + _, err := NewCoalesce() + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) +} + +func TestCoalesce(t *testing.T) { + testCases := []struct { + name string + input []sql.Expression + expected interface{} + typ sql.Type + nullable bool + }{ + {"coalesce(1, 2, 3)", []sql.Expression{expression.NewLiteral(1, sql.Int32), expression.NewLiteral(2, sql.Int32), expression.NewLiteral(3, sql.Int32)}, 1, sql.Int32, false}, + {"coalesce(NULL, NULL, 3)", []sql.Expression{nil, nil, expression.NewLiteral(3, sql.Int32)}, 3, sql.Int32, false}, + {"coalesce(NULL, NULL, '3')", []sql.Expression{nil, nil, expression.NewLiteral("3", sql.Text)}, "3", sql.Text, false}, + {"coalesce(NULL, '2', 3)", []sql.Expression{nil, expression.NewLiteral("2", sql.Text), expression.NewLiteral(3, sql.Int32)}, "2", sql.Text, false}, + {"coalesce(NULL, NULL, NULL)", []sql.Expression{nil, nil, nil}, nil, nil, true}, + } + + for _, tt := range testCases { + c, err := NewCoalesce(tt.input...) 
+ require.NoError(t, err) + + require.Equal(t, tt.typ, c.Type()) + require.Equal(t, tt.nullable, c.IsNullable()) + v, err := c.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, tt.expected, v) + } +} + +func TestComposeCoalasce(t *testing.T) { + c1, err := NewCoalesce(nil) + require.NoError(t, err) + require.Equal(t, nil, c1.Type()) + v, err := c1.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, nil, v) + + c2, err := NewCoalesce(nil, expression.NewLiteral(1, sql.Int32)) + require.NoError(t, err) + require.Equal(t, sql.Int32, c2.Type()) + v, err = c2.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, 1, v) + + c, err := NewCoalesce(nil, c1, c2) + require.NoError(t, err) + require.Equal(t, sql.Int32, c.Type()) + v, err = c.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, 1, v) +} diff --git a/sql/expression/function/common_test.go b/sql/expression/function/common_test.go index 2d86f8c86..3fbc45989 100644 --- a/sql/expression/function/common_test.go +++ b/sql/expression/function/common_test.go @@ -3,8 +3,8 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { diff --git a/sql/expression/function/concat.go b/sql/expression/function/concat.go index 75bb23a0d..56e7bcbab 100644 --- a/sql/expression/function/concat.go +++ b/sql/expression/function/concat.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" + "github.com/src-d/go-mysql-server/sql" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // Concat joins several strings together. @@ -20,10 +20,16 @@ var ErrConcatArrayWithOthers = errors.NewKind("can't concat a string array with // NewConcat creates a new Concat UDF. 
func NewConcat(args ...sql.Expression) (sql.Expression, error) { if len(args) == 0 { - return nil, sql.ErrInvalidArgumentNumber.New("1 or more", 0) + return nil, sql.ErrInvalidArgumentNumber.New("CONCAT", "1 or more", 0) } for _, arg := range args { + // Don't perform this check until it's resolved. Otherwise we + // can't get the type for sure. + if !arg.Resolved() { + continue + } + if len(args) > 1 && sql.IsArray(arg.Type()) { return nil, ErrConcatArrayWithOthers.New() } @@ -57,17 +63,9 @@ func (f *Concat) String() string { return fmt.Sprintf("concat(%s)", strings.Join(args, ", ")) } -// TransformUp implements the Expression interface. -func (f *Concat) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - var args = make([]sql.Expression, len(f.args)) - for i, arg := range f.args { - arg, err := arg.TransformUp(fn) - if err != nil { - return nil, err - } - args[i] = arg - } - return fn(&Concat{args}) +// WithChildren implements the Expression interface. +func (*Concat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcat(children...) } // Resolved implements the Expression interface. 
diff --git a/sql/expression/function/concat_test.go b/sql/expression/function/concat_test.go index dccadab0b..450c0c442 100644 --- a/sql/expression/function/concat_test.go +++ b/sql/expression/function/concat_test.go @@ -3,9 +3,9 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestConcat(t *testing.T) { diff --git a/sql/expression/function/concat_ws.go b/sql/expression/function/concat_ws.go new file mode 100644 index 000000000..c1e2dacc1 --- /dev/null +++ b/sql/expression/function/concat_ws.go @@ -0,0 +1,120 @@ +package function + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +// ConcatWithSeparator joins several strings together. The first argument is +// the separator for the rest of the arguments. The separator is added between +// the strings to be concatenated. The separator can be a string, as can the +// rest of the arguments. If the separator is NULL, the result is NULL. +type ConcatWithSeparator struct { + args []sql.Expression +} + +// NewConcatWithSeparator creates a new NewConcatWithSeparator UDF. +func NewConcatWithSeparator(args ...sql.Expression) (sql.Expression, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("CONCAT_WS", "1 or more", 0) + } + + for _, arg := range args { + // Don't perform this check until it's resolved. Otherwise we + // can't get the type for sure. + if !arg.Resolved() { + continue + } + + if len(args) > 1 && sql.IsArray(arg.Type()) { + return nil, ErrConcatArrayWithOthers.New() + } + + if sql.IsTuple(arg.Type()) { + return nil, sql.ErrInvalidType.New("tuple") + } + } + + return &ConcatWithSeparator{args}, nil +} + +// Type implements the Expression interface. 
+func (f *ConcatWithSeparator) Type() sql.Type { return sql.Text } + +// IsNullable implements the Expression interface. +func (f *ConcatWithSeparator) IsNullable() bool { + for _, arg := range f.args { + if arg.IsNullable() { + return true + } + } + return false +} + +func (f *ConcatWithSeparator) String() string { + var args = make([]string, len(f.args)) + for i, arg := range f.args { + args[i] = arg.String() + } + return fmt.Sprintf("concat_ws(%s)", strings.Join(args, ", ")) +} + +// WithChildren implements the Expression interface. +func (*ConcatWithSeparator) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewConcatWithSeparator(children...) +} + +// Resolved implements the Expression interface. +func (f *ConcatWithSeparator) Resolved() bool { + for _, arg := range f.args { + if !arg.Resolved() { + return false + } + } + return true +} + +// Children implements the Expression interface. +func (f *ConcatWithSeparator) Children() []sql.Expression { return f.args } + +// Eval implements the Expression interface. 
+func (f *ConcatWithSeparator) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + var parts []string + + for i, arg := range f.args { + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil && i == 0 { + return nil, nil + } + + if val == nil { + continue + } + + if sql.IsArray(arg.Type()) { + val, err = sql.Array(sql.Text).Convert(val) + if err != nil { + return nil, err + } + + for _, v := range val.([]interface{}) { + parts = append(parts, v.(string)) + } + } else { + val, err = sql.Text.Convert(val) + if err != nil { + return nil, err + } + + parts = append(parts, val.(string)) + } + } + + return strings.Join(parts[1:], parts[0]), nil +} diff --git a/sql/expression/function/concat_ws_test.go b/sql/expression/function/concat_ws_test.go new file mode 100644 index 000000000..7ec7bdc0c --- /dev/null +++ b/sql/expression/function/concat_ws_test.go @@ -0,0 +1,106 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestConcatWithSeparator(t *testing.T) { + t.Run("multiple arguments", func(t *testing.T) { + require := require.New(t) + f, err := NewConcatWithSeparator( + expression.NewLiteral(",", sql.Text), + expression.NewLiteral("foo", sql.Text), + expression.NewLiteral(5, sql.Text), + expression.NewLiteral(true, sql.Boolean), + ) + require.NoError(err) + + v, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal("foo,5,true", v) + }) + + t.Run("some argument is empty", func(t *testing.T) { + require := require.New(t) + f, err := NewConcatWithSeparator( + expression.NewLiteral(",", sql.Text), + expression.NewLiteral("foo", sql.Text), + expression.NewLiteral("", sql.Text), + expression.NewLiteral(true, sql.Boolean), + ) + require.NoError(err) + + v, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal("foo,,true", v) + }) + + 
t.Run("some argument is nil", func(t *testing.T) { + require := require.New(t) + f, err := NewConcatWithSeparator( + expression.NewLiteral(",", sql.Text), + expression.NewLiteral("foo", sql.Text), + expression.NewLiteral(nil, sql.Text), + expression.NewLiteral(true, sql.Boolean), + ) + require.NoError(err) + + v, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal("foo,true", v) + }) + + t.Run("separator is nil", func(t *testing.T) { + require := require.New(t) + f, err := NewConcatWithSeparator( + expression.NewLiteral(nil, sql.Text), + expression.NewLiteral("foo", sql.Text), + expression.NewLiteral(5, sql.Text), + expression.NewLiteral(true, sql.Boolean), + ) + require.NoError(err) + + v, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(nil, v) + }) + + t.Run("concat_ws array", func(t *testing.T) { + require := require.New(t) + f, err := NewConcatWithSeparator( + expression.NewLiteral([]interface{}{",", 5, "bar", true}, sql.Array(sql.Text)), + ) + require.NoError(err) + + v, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal("5,bar,true", v) + }) +} + +func TestNewConcatWithSeparator(t *testing.T) { + require := require.New(t) + + _, err := NewConcatWithSeparator(expression.NewLiteral(nil, sql.Array(sql.Text))) + require.NoError(err) + + _, err = NewConcatWithSeparator(expression.NewLiteral(nil, sql.Array(sql.Text)), expression.NewLiteral(nil, sql.Int64)) + require.Error(err) + require.True(ErrConcatArrayWithOthers.Is(err)) + + _, err = NewConcatWithSeparator(expression.NewLiteral(nil, sql.Tuple(sql.Text, sql.Text))) + require.Error(err) + require.True(sql.ErrInvalidType.Is(err)) + + _, err = NewConcatWithSeparator( + expression.NewLiteral(nil, sql.Text), + expression.NewLiteral(nil, sql.Boolean), + expression.NewLiteral(nil, sql.Int64), + expression.NewLiteral(nil, sql.Text), + ) + require.NoError(err) +} diff --git a/sql/expression/function/connection_id.go 
b/sql/expression/function/connection_id.go new file mode 100644 index 000000000..7ee96ef58 --- /dev/null +++ b/sql/expression/function/connection_id.go @@ -0,0 +1,39 @@ +package function + +import "github.com/src-d/go-mysql-server/sql" + +// ConnectionID returns the current connection id. +type ConnectionID struct{} + +// NewConnectionID creates a new ConnectionID UDF node. +func NewConnectionID() sql.Expression { + return ConnectionID{} +} + +// Children implements the sql.Expression interface. +func (ConnectionID) Children() []sql.Expression { return nil } + +// Type implements the sql.Expression interface. +func (ConnectionID) Type() sql.Type { return sql.Uint32 } + +// Resolved implements the sql.Expression interface. +func (ConnectionID) Resolved() bool { return true } + +// WithChildren implements the Expression interface. +func (c ConnectionID) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } + return c, nil +} + +// IsNullable implements the sql.Expression interface. +func (ConnectionID) IsNullable() bool { return false } + +// String implements the fmt.Stringer interface. +func (ConnectionID) String() string { return "connection_id()" } + +// Eval implements the sql.Expression interface. 
+func (ConnectionID) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { + return ctx.ID(), nil +} diff --git a/sql/expression/function/connection_id_test.go b/sql/expression/function/connection_id_test.go new file mode 100644 index 000000000..5a9699899 --- /dev/null +++ b/sql/expression/function/connection_id_test.go @@ -0,0 +1,21 @@ +package function + +import ( + "context" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestConnectionID(t *testing.T) { + require := require.New(t) + + session := sql.NewSession("", "", "", 2) + ctx := sql.NewContext(context.Background(), sql.WithSession(session)) + + f := NewConnectionID() + result, err := f.Eval(ctx, nil) + require.NoError(err) + require.Equal(uint32(2), result) +} diff --git a/sql/expression/function/database.go b/sql/expression/function/database.go new file mode 100644 index 000000000..1246e488c --- /dev/null +++ b/sql/expression/function/database.go @@ -0,0 +1,51 @@ +package function + +import ( + "github.com/src-d/go-mysql-server/sql" +) + +// Database stands for DATABASE() function +type Database struct { + catalog *sql.Catalog +} + +// NewDatabase returns a new Database function +func NewDatabase(c *sql.Catalog) func() sql.Expression { + return func() sql.Expression { + return &Database{c} + } +} + +// Type implements the sql.Expression (sql.Text) +func (db *Database) Type() sql.Type { return sql.Text } + +// IsNullable implements the sql.Expression interface. +// The function returns always true +func (db *Database) IsNullable() bool { + return true +} + +func (*Database) String() string { + return "DATABASE()" +} + +// WithChildren implements the Expression interface. 
+func (d *Database) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return NewDatabase(d.catalog)(), nil +} + +// Resolved implements the sql.Expression interface. +func (db *Database) Resolved() bool { + return true +} + +// Children implements the sql.Expression interface. +func (db *Database) Children() []sql.Expression { return nil } + +// Eval implements the sql.Expression interface. +func (db *Database) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return db.catalog.CurrentDatabase(), nil +} diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go new file mode 100644 index 000000000..919775a42 --- /dev/null +++ b/sql/expression/function/date.go @@ -0,0 +1,159 @@ +package function + +import ( + "fmt" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// DateAdd adds an interval to a date. +type DateAdd struct { + Date sql.Expression + Interval *expression.Interval +} + +// NewDateAdd creates a new date add function. +func NewDateAdd(args ...sql.Expression) (sql.Expression, error) { + if len(args) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New("DATE_ADD", 2, len(args)) + } + + i, ok := args[1].(*expression.Interval) + if !ok { + return nil, fmt.Errorf("DATE_ADD expects an interval as second parameter") + } + + return &DateAdd{args[0], i}, nil +} + +// Children implements the sql.Expression interface. +func (d *DateAdd) Children() []sql.Expression { + return []sql.Expression{d.Date, d.Interval} +} + +// Resolved implements the sql.Expression interface. +func (d *DateAdd) Resolved() bool { + return d.Date.Resolved() && d.Interval.Resolved() +} + +// IsNullable implements the sql.Expression interface. +func (d *DateAdd) IsNullable() bool { + return true +} + +// Type implements the sql.Expression interface. 
+func (d *DateAdd) Type() sql.Type { return sql.Date } + +// WithChildren implements the Expression interface. +func (d *DateAdd) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateAdd(children...) +} + +// Eval implements the sql.Expression interface. +func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + date, err := d.Date.Eval(ctx, row) + if err != nil { + return nil, err + } + + if date == nil { + return nil, nil + } + + date, err = sql.Timestamp.Convert(date) + if err != nil { + return nil, err + } + + delta, err := d.Interval.EvalDelta(ctx, row) + if err != nil { + return nil, err + } + + if delta == nil { + return nil, nil + } + + return sql.ValidateTime(delta.Add(date.(time.Time))), nil +} + +func (d *DateAdd) String() string { + return fmt.Sprintf("DATE_ADD(%s, %s)", d.Date, d.Interval) +} + +// DateSub subtracts an interval from a date. +type DateSub struct { + Date sql.Expression + Interval *expression.Interval +} + +// NewDateSub creates a new date sub function. +func NewDateSub(args ...sql.Expression) (sql.Expression, error) { + if len(args) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New("DATE_SUB", 2, len(args)) + } + + i, ok := args[1].(*expression.Interval) + if !ok { + return nil, fmt.Errorf("DATE_SUB expects an interval as second parameter") + } + + return &DateSub{args[0], i}, nil +} + +// Children implements the sql.Expression interface. +func (d *DateSub) Children() []sql.Expression { + return []sql.Expression{d.Date, d.Interval} +} + +// Resolved implements the sql.Expression interface. +func (d *DateSub) Resolved() bool { + return d.Date.Resolved() && d.Interval.Resolved() +} + +// IsNullable implements the sql.Expression interface. +func (d *DateSub) IsNullable() bool { + return true +} + +// Type implements the sql.Expression interface. +func (d *DateSub) Type() sql.Type { return sql.Date } + +// WithChildren implements the Expression interface. 
+func (d *DateSub) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDateSub(children...) +} + +// Eval implements the sql.Expression interface. +func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + date, err := d.Date.Eval(ctx, row) + if err != nil { + return nil, err + } + + if date == nil { + return nil, nil + } + + date, err = sql.Timestamp.Convert(date) + if err != nil { + return nil, err + } + + delta, err := d.Interval.EvalDelta(ctx, row) + if err != nil { + return nil, err + } + + if delta == nil { + return nil, nil + } + + return sql.ValidateTime(delta.Sub(date.(time.Time))), nil +} + +func (d *DateSub) String() string { + return fmt.Sprintf("DATE_SUB(%s, %s)", d.Date, d.Interval) +} diff --git a/sql/expression/function/date_test.go b/sql/expression/function/date_test.go new file mode 100644 index 000000000..18f436ed9 --- /dev/null +++ b/sql/expression/function/date_test.go @@ -0,0 +1,87 @@ +package function + +import ( + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestDateAdd(t *testing.T) { + require := require.New(t) + + _, err := NewDateAdd() + require.Error(err) + + _, err = NewDateAdd(expression.NewLiteral("2018-05-02", sql.Text)) + require.Error(err) + + _, err = NewDateAdd( + expression.NewLiteral("2018-05-02", sql.Text), + expression.NewLiteral(int64(1), sql.Int64), + ) + require.Error(err) + + f, err := NewDateAdd( + expression.NewGetField(0, sql.Text, "foo", false), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ) + require.NoError(err) + + ctx := sql.NewEmptyContext() + expected := time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC) + + result, err := f.Eval(ctx, sql.Row{"2018-05-02"}) + require.NoError(err) + require.Equal(expected, result) + + result, err = f.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + 
_, err = f.Eval(ctx, sql.Row{"asdasdasd"}) + require.Error(err) +} +func TestDateSub(t *testing.T) { + require := require.New(t) + + _, err := NewDateSub() + require.Error(err) + + _, err = NewDateSub(expression.NewLiteral("2018-05-02", sql.Text)) + require.Error(err) + + _, err = NewDateSub( + expression.NewLiteral("2018-05-02", sql.Text), + expression.NewLiteral(int64(1), sql.Int64), + ) + require.Error(err) + + f, err := NewDateSub( + expression.NewGetField(0, sql.Text, "foo", false), + expression.NewInterval( + expression.NewLiteral(int64(1), sql.Int64), + "DAY", + ), + ) + require.NoError(err) + + ctx := sql.NewEmptyContext() + expected := time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC) + + result, err := f.Eval(ctx, sql.Row{"2018-05-02"}) + require.NoError(err) + require.Equal(expected, result) + + result, err = f.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + _, err = f.Eval(ctx, sql.Row{"asdasdasd"}) + require.Error(err) +} diff --git a/sql/expression/function/explode.go b/sql/expression/function/explode.go new file mode 100644 index 000000000..51cd2b66b --- /dev/null +++ b/sql/expression/function/explode.go @@ -0,0 +1,91 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// Explode is a function that generates a row for each value of its child. +// It is a placeholder expression node. +type Explode struct { + Child sql.Expression +} + +// NewExplode creates a new Explode function. +func NewExplode(child sql.Expression) sql.Expression { + return &Explode{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Explode) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Explode) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. 
+func (e *Explode) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Explode) Type() sql.Type { + return sql.UnderlyingType(e.Child.Type()) +} + +// Eval implements the sql.Expression interface. +func (e *Explode) Eval(*sql.Context, sql.Row) (interface{}, error) { + panic("eval method of Explode is only a placeholder") +} + +func (e *Explode) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// WithChildren implements the Expression interface. +func (e *Explode) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + return NewExplode(children[0]), nil +} + +// Generate is a function that generates a row for each value of its child. +// This is the non-placeholder counterpart of Explode. +type Generate struct { + Child sql.Expression +} + +// NewGenerate creates a new Generate function. +func NewGenerate(child sql.Expression) sql.Expression { + return &Generate{child} +} + +// Resolved implements the sql.Expression interface. +func (e *Generate) Resolved() bool { return e.Child.Resolved() } + +// Children implements the sql.Expression interface. +func (e *Generate) Children() []sql.Expression { return []sql.Expression{e.Child} } + +// IsNullable implements the sql.Expression interface. +func (e *Generate) IsNullable() bool { return e.Child.IsNullable() } + +// Type implements the sql.Expression interface. +func (e *Generate) Type() sql.Type { + return e.Child.Type() +} + +// Eval implements the sql.Expression interface. +func (e *Generate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return e.Child.Eval(ctx, row) +} + +func (e *Generate) String() string { + return fmt.Sprintf("EXPLODE(%s)", e.Child) +} + +// WithChildren implements the Expression interface. 
+func (e *Generate) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + return NewGenerate(children[0]), nil +} diff --git a/sql/expression/function/greatest_least.go b/sql/expression/function/greatest_least.go new file mode 100644 index 000000000..9e822d8ed --- /dev/null +++ b/sql/expression/function/greatest_least.go @@ -0,0 +1,308 @@ +package function + +import ( + "fmt" + "strconv" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var ErrUintOverflow = errors.NewKind( + "Unsigned integer too big to fit on signed integer") + +// compEval is used to implement Greatest/Least Eval() using a comparison function +func compEval( + returnType sql.Type, + args []sql.Expression, + ctx *sql.Context, + row sql.Row, + cmp compareFn, +) (interface{}, error) { + + if returnType == sql.Null { + return nil, nil + } + + var selectedNum float64 + var selectedString string + + for i, arg := range args { + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + + switch t := val.(type) { + case int, int8, int16, int32, int64, uint, + uint8, uint16, uint32, uint64: + switch x := t.(type) { + case int: + t = int64(x) + case int8: + t = int64(x) + case int16: + t = int64(x) + case int32: + t = int64(x) + case uint: + i := int64(x) + if i < 0 { + return nil, ErrUintOverflow.New() + } + t = i + case uint64: + i := int64(x) + if i < 0 { + return nil, ErrUintOverflow.New() + } + t = i + case uint8: + t = int64(x) + case uint16: + t = int64(x) + case uint32: + t = int64(x) + } + ival := t.(int64) + if i == 0 || cmp(ival, int64(selectedNum)) { + selectedNum = float64(ival) + } + case float32, float64: + if x, ok := t.(float32); ok { + t = float64(x) + } + + fval := t.(float64) + if i == 0 || cmp(fval, float64(selectedNum)) { + selectedNum = fval + } + + case string: + if returnType == sql.Text && (i == 0 || cmp(t, 
selectedString)) { + selectedString = t + } + + fval, err := strconv.ParseFloat(t, 64) + if err != nil { + // MySQL just ignores non numerically convertible string arguments + // when mixed with numeric ones + continue + } + + if i == 0 || cmp(fval, selectedNum) { + selectedNum = fval + } + default: + return nil, ErrUnsupportedType.New(t) + } + + } + + switch returnType { + case sql.Int64: + return int64(selectedNum), nil + case sql.Text: + return selectedString, nil + } + + // sql.Float64 + return float64(selectedNum), nil +} + +// compRetType is used to determine the type from args based on the rules described for +// Greatest/Least +func compRetType(args ...sql.Expression) (sql.Type, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("LEAST", "1 or more", 0) + } + + allString := true + allInt := true + + for _, arg := range args { + argType := arg.Type() + if sql.IsTuple(argType) { + return nil, sql.ErrInvalidType.New("tuple") + } else if sql.IsNumber(argType) { + allString = false + if sql.IsDecimal(argType) { + allString = false + allInt = false + } + } else if sql.IsText(argType) { + allInt = false + } else if argType == sql.Null { + // When a Null is present the return will always be Null + return sql.Null, nil + } else { + return nil, ErrUnsupportedType.New(argType) + } + } + + if allString { + return sql.Text, nil + } else if allInt { + return sql.Int64, nil + } + + return sql.Float64, nil +} + +// Greatest returns the argument with the greatest numerical or string value. It allows for +// numeric (ints and floats) and string arguments and will return the used type +// when all arguments are of the same type or floats if there are numerically +// convertible strings or integers mixed with floats. When ints or floats +// are mixed with non numerically convertible strings, those are ignored. 
+type Greatest struct { + Args []sql.Expression + returnType sql.Type +} + +// ErrUnsupportedType is returned when an argument to Greatest or Least is not numeric or string +var ErrUnsupportedType = errors.NewKind("unsupported type for greatest/latest argument: %T") + +// NewGreatest creates a new Greatest UDF +func NewGreatest(args ...sql.Expression) (sql.Expression, error) { + retType, err := compRetType(args...) + if err != nil { + return nil, err + } + + return &Greatest{args, retType}, nil +} + +// Type implements the Expression interface. +func (f *Greatest) Type() sql.Type { return f.returnType } + +// IsNullable implements the Expression interface. +func (f *Greatest) IsNullable() bool { + for _, arg := range f.Args { + if arg.IsNullable() { + return true + } + } + return false +} + +func (f *Greatest) String() string { + var args = make([]string, len(f.Args)) + for i, arg := range f.Args { + args[i] = arg.String() + } + return fmt.Sprintf("greatest(%s)", strings.Join(args, ", ")) +} + +// WithChildren implements the Expression interface. +func (f *Greatest) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewGreatest(children...) +} + +// Resolved implements the Expression interface. +func (f *Greatest) Resolved() bool { + for _, arg := range f.Args { + if !arg.Resolved() { + return false + } + } + return true +} + +// Children implements the Expression interface. 
+func (f *Greatest) Children() []sql.Expression { return f.Args } + +type compareFn func(interface{}, interface{}) bool + +func greaterThan(a, b interface{}) bool { + switch i := a.(type) { + case int64: + return i > b.(int64) + case float64: + return i > b.(float64) + case string: + return i > b.(string) + } + panic("Implementation error on greaterThan") +} + +func lessThan(a, b interface{}) bool { + switch i := a.(type) { + case int64: + return i < b.(int64) + case float64: + return i < b.(float64) + case string: + return i < b.(string) + } + panic("Implementation error on lessThan") +} + +// Eval implements the Expression interface. +func (f *Greatest) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return compEval(f.returnType, f.Args, ctx, row, greaterThan) +} + +// Least returns the argument with the least numerical or string value. It allows for +// numeric (ints and floats) and string arguments and will return the used type +// when all arguments are of the same type or floats if there are numerically +// convertible strings or integers mixed with floats. When ints or floats +// are mixed with non numerically convertible strings, those are ignored. +type Least struct { + Args []sql.Expression + returnType sql.Type +} + +// NewLeast creates a new Least UDF +func NewLeast(args ...sql.Expression) (sql.Expression, error) { + retType, err := compRetType(args...) + if err != nil { + return nil, err + } + + return &Least{args, retType}, nil +} + +// Type implements the Expression interface. +func (f *Least) Type() sql.Type { return f.returnType } + +// IsNullable implements the Expression interface. 
+func (f *Least) IsNullable() bool { + for _, arg := range f.Args { + if arg.IsNullable() { + return true + } + } + return false +} + +func (f *Least) String() string { + var args = make([]string, len(f.Args)) + for i, arg := range f.Args { + args[i] = arg.String() + } + return fmt.Sprintf("least(%s)", strings.Join(args, ", ")) +} + +// WithChildren implements the Expression interface. +func (f *Least) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLeast(children...) +} + +// Resolved implements the Expression interface. +func (f *Least) Resolved() bool { + for _, arg := range f.Args { + if !arg.Resolved() { + return false + } + } + return true +} + +// Children implements the Expression interface. +func (f *Least) Children() []sql.Expression { return f.Args } + +// Eval implements the Expression interface. +func (f *Least) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return compEval(f.returnType, f.Args, ctx, row, lessThan) +} diff --git a/sql/expression/function/greatest_least_test.go b/sql/expression/function/greatest_least_test.go new file mode 100644 index 000000000..b75ea0bba --- /dev/null +++ b/sql/expression/function/greatest_least_test.go @@ -0,0 +1,214 @@ +package function + +import ( + "testing" + "unsafe" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestGreatest(t *testing.T) { + testCases := []struct { + name string + args []sql.Expression + expected interface{} + }{ + { + "null", + []sql.Expression{ + expression.NewLiteral(nil, sql.Null), + expression.NewLiteral(5, sql.Int64), + expression.NewLiteral(1, sql.Int64), + }, + nil, + }, + { + "negative and all ints", + []sql.Expression{ + expression.NewLiteral(int64(-1), sql.Int64), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + int64(5), + }, + { + "string mixed", + []sql.Expression{ + 
expression.NewLiteral(string("9"), sql.Text), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + float64(9), + }, + { + "unconvertible string mixed ignored", + []sql.Expression{ + expression.NewLiteral(string("10.5"), sql.Text), + expression.NewLiteral(string("foobar"), sql.Int64), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + float64(10.5), + }, + { + "float mixed", + []sql.Expression{ + expression.NewLiteral(float64(10.0), sql.Float64), + expression.NewLiteral(int(5), sql.Int64), + expression.NewLiteral(int(1), sql.Int64), + }, + float64(10.0), + }, + { + "all strings", + []sql.Expression{ + expression.NewLiteral("aaa", sql.Text), + expression.NewLiteral("bbb", sql.Text), + expression.NewLiteral("9999", sql.Text), + expression.NewLiteral("", sql.Text), + }, + "bbb", + }, + { + "all strings and empty", + []sql.Expression{ + expression.NewLiteral("aaa", sql.Text), + expression.NewLiteral("bbb", sql.Text), + expression.NewLiteral("9999", sql.Text), + expression.NewLiteral("", sql.Text), + }, + "bbb", + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + f, err := NewGreatest(tt.args...) + require.NoError(err) + + output, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(tt.expected, output) + }) + } +} + +func TestGreatestUnsignedOverflow(t *testing.T) { + require := require.New(t) + + var x int + var gr sql.Expression + var err error + + switch unsafe.Sizeof(x) { + case 4: + gr, err = NewGreatest( + expression.NewLiteral(int32(1), sql.Int32), + expression.NewLiteral(uint32(4294967295), sql.Uint32), + ) + require.NoError(err) + case 8: + gr, err = NewGreatest( + expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(uint64(18446744073709551615), sql.Uint64), + ) + require.NoError(err) + default: + // non 32/64 bits?? 
+ return + } + + _, err = gr.Eval(sql.NewEmptyContext(), nil) + require.EqualError(err, "Unsigned integer too big to fit on signed integer") +} + +func TestLeast(t *testing.T) { + testCases := []struct { + name string + args []sql.Expression + expected interface{} + }{ + { + "null", + []sql.Expression{ + expression.NewLiteral(nil, sql.Null), + expression.NewLiteral(5, sql.Int64), + expression.NewLiteral(1, sql.Int64), + }, + nil, + }, + { + "negative and all ints", + []sql.Expression{ + expression.NewLiteral(int64(-1), sql.Int64), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + int64(-1), + }, + { + "string mixed", + []sql.Expression{ + expression.NewLiteral(string("10"), sql.Text), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + float64(1), + }, + { + "unconvertible string mixed ignored", + []sql.Expression{ + expression.NewLiteral(string("10.5"), sql.Text), + expression.NewLiteral(string("foobar"), sql.Int64), + expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int64(1), sql.Int64), + }, + float64(1), + }, + { + "float mixed", + []sql.Expression{ + expression.NewLiteral(float64(10.0), sql.Float64), + expression.NewLiteral(int(5), sql.Int64), + expression.NewLiteral(int(1), sql.Int64), + }, + float64(1.0), + }, + { + "all strings", + []sql.Expression{ + expression.NewLiteral("aaa", sql.Text), + expression.NewLiteral("bbb", sql.Text), + expression.NewLiteral("9999", sql.Text), + }, + "9999", + }, + { + "all strings and empty", + []sql.Expression{ + expression.NewLiteral("aaa", sql.Text), + expression.NewLiteral("bbb", sql.Text), + expression.NewLiteral("9999", sql.Text), + expression.NewLiteral("", sql.Text), + }, + "", + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + f, err := NewLeast(tt.args...) 
+ require.NoError(err) + + output, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(tt.expected, output) + }) + } +} diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go new file mode 100644 index 000000000..566d62ea6 --- /dev/null +++ b/sql/expression/function/ifnull.go @@ -0,0 +1,74 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// IfNull function returns the specified value IF the expression is NULL, otherwise return the expression. +type IfNull struct { + expression.BinaryExpression +} + +// NewIfNull returns a new IFNULL UDF +func NewIfNull(ex, value sql.Expression) sql.Expression { + return &IfNull{ + expression.BinaryExpression{ + Left: ex, + Right: value, + }, + } +} + +// Eval implements the Expression interface. +func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + left, err := f.Left.Eval(ctx, row) + if err != nil { + return nil, err + } + if left != nil { + return left, nil + } + + right, err := f.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + return right, nil +} + +// Type implements the Expression interface. +func (f *IfNull) Type() sql.Type { + if sql.IsNull(f.Left) { + if sql.IsNull(f.Right) { + return sql.Null + } + return f.Right.Type() + } + return f.Left.Type() +} + +// IsNullable implements the Expression interface. +func (f *IfNull) IsNullable() bool { + if sql.IsNull(f.Left) { + if sql.IsNull(f.Right) { + return true + } + return f.Right.IsNullable() + } + return f.Left.IsNullable() +} + +func (f *IfNull) String() string { + return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right) +} + +// WithChildren implements the Expression interface. 
+func (f *IfNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) + } + return NewIfNull(children[0], children[1]), nil +} diff --git a/sql/expression/function/ifnull_test.go b/sql/expression/function/ifnull_test.go new file mode 100644 index 000000000..1751fe215 --- /dev/null +++ b/sql/expression/function/ifnull_test.go @@ -0,0 +1,36 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestIfNull(t *testing.T) { + testCases := []struct { + expression interface{} + value interface{} + expected interface{} + }{ + {"foo", "bar", "foo"}, + {"foo", "foo", "foo"}, + {nil, "foo", "foo"}, + {"foo", nil, "foo"}, + {nil, nil, nil}, + {"", nil, ""}, + } + + f := NewIfNull( + expression.NewGetField(0, sql.Text, "expression", true), + expression.NewGetField(1, sql.Text, "value", true), + ) + require.Equal(t, sql.Text, f.Type()) + + for _, tc := range testCases { + v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value)) + require.NoError(t, err) + require.Equal(t, tc.expected, v) + } +} diff --git a/sql/expression/function/isbinary.go b/sql/expression/function/isbinary.go index 499b2618d..e3edf74d6 100644 --- a/sql/expression/function/isbinary.go +++ b/sql/expression/function/isbinary.go @@ -3,13 +3,9 @@ package function import ( "bytes" "fmt" - "time" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" - - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // IsBinary is a function that returns whether a blob is binary or not. 
@@ -27,21 +23,6 @@ func (ib *IsBinary) Eval( ctx *sql.Context, row sql.Row, ) (interface{}, error) { - var blobSize int - span, ctx := ctx.Span("function.IsBinary") - defer func() { - span.FinishWithOptions(opentracing.FinishOptions{ - LogRecords: []opentracing.LogRecord{ - { - Timestamp: time.Now(), - Fields: []log.Field{ - log.Int("blobsize", blobSize), - }, - }, - }, - }) - }() - v, err := ib.Child.Eval(ctx, row) if err != nil { return nil, err @@ -57,7 +38,6 @@ func (ib *IsBinary) Eval( } blobBytes := blob.([]byte) - blobSize = len(blobBytes) return isBinary(blobBytes), nil } @@ -65,13 +45,12 @@ func (ib *IsBinary) String() string { return fmt.Sprintf("IS_BINARY(%s)", ib.Child) } -// TransformUp implements the Expression interface. -func (ib *IsBinary) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := ib.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (ib *IsBinary) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(ib, len(children), 1) } - return f(NewIsBinary(child)) + return NewIsBinary(children[0]), nil } // Type implements the Expression interface. 
diff --git a/sql/expression/function/isbinary_test.go b/sql/expression/function/isbinary_test.go index d8cc4ae2a..375ac03e5 100644 --- a/sql/expression/function/isbinary_test.go +++ b/sql/expression/function/isbinary_test.go @@ -3,9 +3,9 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestIsBinary(t *testing.T) { diff --git a/sql/expression/function/json_extract.go b/sql/expression/function/json_extract.go new file mode 100644 index 000000000..c3f8f65eb --- /dev/null +++ b/sql/expression/function/json_extract.go @@ -0,0 +1,123 @@ +package function + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/oliveagle/jsonpath" + "github.com/src-d/go-mysql-server/sql" +) + +// JSONExtract extracts data from a json document using json paths. +type JSONExtract struct { + JSON sql.Expression + Paths []sql.Expression +} + +// NewJSONExtract creates a new JSONExtract UDF. +func NewJSONExtract(args ...sql.Expression) (sql.Expression, error) { + if len(args) < 2 { + return nil, sql.ErrInvalidArgumentNumber.New("JSON_EXTRACT", 2, len(args)) + } + + return &JSONExtract{args[0], args[1:]}, nil +} + +// Resolved implements the sql.Expression interface. +func (j *JSONExtract) Resolved() bool { + for _, p := range j.Paths { + if !p.Resolved() { + return false + } + } + return j.JSON.Resolved() +} + +// Type implements the sql.Expression interface. +func (j *JSONExtract) Type() sql.Type { return sql.JSON } + +// Eval implements the sql.Expression interface. 
+func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + span, ctx := ctx.Span("function.JSONExtract") + defer span.Finish() + + js, err := j.JSON.Eval(ctx, row) + if err != nil { + return nil, err + } + + doc, err := unmarshalVal(js) + if err != nil { + return nil, err + } + + var result = make([]interface{}, len(j.Paths)) + for i, p := range j.Paths { + path, err := p.Eval(ctx, row) + if err != nil { + return nil, err + } + + path, err = sql.Text.Convert(path) + if err != nil { + return nil, err + } + + c, err := jsonpath.Compile(path.(string)) + if err != nil { + return nil, err + } + + result[i], _ = c.Lookup(doc) // err ignored + } + + if len(result) == 1 { + return result[0], nil + } + + return result, nil +} + +func unmarshalVal(v interface{}) (interface{}, error) { + v, err := sql.JSON.Convert(v) + if err != nil { + return nil, err + } + + var doc interface{} + if err := json.Unmarshal(v.([]byte), &doc); err != nil { + return nil, err + } + + return doc, nil +} + +// IsNullable implements the sql.Expression interface. +func (j *JSONExtract) IsNullable() bool { + for _, p := range j.Paths { + if p.IsNullable() { + return true + } + } + return j.JSON.IsNullable() +} + +// Children implements the sql.Expression interface. +func (j *JSONExtract) Children() []sql.Expression { + return append([]sql.Expression{j.JSON}, j.Paths...) +} + +// WithChildren implements the Expression interface. +func (j *JSONExtract) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewJSONExtract(children...) 
+} + +func (j *JSONExtract) String() string { + children := j.Children() + var parts = make([]string, len(children)) + for i, c := range children { + parts[i] = c.String() + } + return fmt.Sprintf("JSON_EXTRACT(%s)", strings.Join(parts, ", ")) +} diff --git a/sql/expression/function/json_extract_test.go b/sql/expression/function/json_extract_test.go new file mode 100644 index 000000000..06f94cbfc --- /dev/null +++ b/sql/expression/function/json_extract_test.go @@ -0,0 +1,77 @@ +package function + +import ( + "errors" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestJSONExtract(t *testing.T) { + f2, err := NewJSONExtract( + expression.NewGetField(0, sql.Text, "arg1", false), + expression.NewGetField(1, sql.Text, "arg2", false), + ) + require.NoError(t, err) + + f3, err := NewJSONExtract( + expression.NewGetField(0, sql.Text, "arg1", false), + expression.NewGetField(1, sql.Text, "arg2", false), + expression.NewGetField(2, sql.Text, "arg3", false), + ) + require.NoError(t, err) + + f4, err := NewJSONExtract( + expression.NewGetField(0, sql.Text, "arg1", false), + expression.NewGetField(1, sql.Text, "arg2", false), + expression.NewGetField(2, sql.Text, "arg3", false), + expression.NewGetField(3, sql.Text, "arg4", false), + ) + require.NoError(t, err) + + json := map[string]interface{}{ + "a": []interface{}{1, 2, 3, 4}, + "b": map[string]interface{}{ + "c": "foo", + "d": true, + }, + "e": []interface{}{ + []interface{}{1, 2}, + []interface{}{3, 4}, + }, + } + + testCases := []struct { + f sql.Expression + row sql.Row + expected interface{} + err error + }{ + {f2, sql.Row{json, "FOO"}, nil, errors.New("should start with '$'")}, + {f2, sql.Row{nil, "$.b.c"}, nil, nil}, + {f2, sql.Row{json, "$.foo"}, nil, nil}, + {f2, sql.Row{json, "$.b.c"}, "foo", nil}, + {f3, sql.Row{json, "$.b.c", "$.b.d"}, []interface{}{"foo", true}, nil}, + {f4, sql.Row{json, "$.b.c", 
"$.b.d", "$.e[0][*]"}, []interface{}{ + "foo", + true, + []interface{}{1., 2.}, + }, nil}, + } + + for _, tt := range testCases { + t.Run(tt.f.String(), func(t *testing.T) { + require := require.New(t) + result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err == nil { + require.NoError(err) + } else { + require.Equal(err.Error(), tt.err.Error()) + } + + require.Equal(tt.expected, result) + }) + } +} diff --git a/sql/expression/function/json_unquote.go b/sql/expression/function/json_unquote.go new file mode 100644 index 000000000..8a0c42de3 --- /dev/null +++ b/sql/expression/function/json_unquote.go @@ -0,0 +1,138 @@ +package function + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "reflect" + "unicode/utf8" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// JSONUnquote unquotes JSON value and returns the result as a utf8mb4 string. +// Returns NULL if the argument is NULL. +// An error occurs if the value starts and ends with double quotes but is not a valid JSON string literal. +type JSONUnquote struct { + expression.UnaryExpression +} + +// NewJSONUnquote creates a new JSONUnquote UDF. +func NewJSONUnquote(json sql.Expression) sql.Expression { + return &JSONUnquote{expression.UnaryExpression{Child: json}} +} + +func (js *JSONUnquote) String() string { + return fmt.Sprintf("JSON_UNQUOTE(%s)", js.Child) +} + +// Type implements the Expression interface. +func (*JSONUnquote) Type() sql.Type { + return sql.Text +} + +// WithChildren implements the Expression interface. +func (js *JSONUnquote) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(js, len(children), 1) + } + return NewJSONUnquote(children[0]), nil +} + +// Eval implements the Expression interface. 
+func (js *JSONUnquote) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + json, err := js.Child.Eval(ctx, row) + if json == nil || err != nil { + return json, err + } + + ex, err := sql.Text.Convert(json) + if err != nil { + return nil, err + } + str, ok := ex.(string) + if !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String()) + } + + return unquote(str) +} + +// The implementation is taken from TiDB +// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L89 +func unquote(s string) (string, error) { + ret := new(bytes.Buffer) + for i := 0; i < len(s); i++ { + if s[i] == '\\' { + i++ + if i == len(s) { + return "", fmt.Errorf("Missing a closing quotation mark in string") + } + switch s[i] { + case '"': + ret.WriteByte('"') + case 'b': + ret.WriteByte('\b') + case 'f': + ret.WriteByte('\f') + case 'n': + ret.WriteByte('\n') + case 'r': + ret.WriteByte('\r') + case 't': + ret.WriteByte('\t') + case '\\': + ret.WriteByte('\\') + case 'u': + if i+4 > len(s) { + return "", fmt.Errorf("Invalid unicode: %s", s[i+1:]) + } + char, size, err := decodeEscapedUnicode([]byte(s[i+1 : i+5])) + if err != nil { + return "", err + } + ret.Write(char[0:size]) + i += 4 + default: + // For all other escape sequences, backslash is ignored. + ret.WriteByte(s[i]) + } + } else { + ret.WriteByte(s[i]) + } + } + + str := ret.String() + strlen := len(str) + // Remove prefix and suffix '"'. + if strlen > 1 { + head, tail := str[0], str[strlen-1] + if head == '"' && tail == '"' { + return str[1 : strlen-1], nil + } + } + return str, nil +} + +// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629. +// According RFC 3629, the max length of utf8 characters is 4 bytes. +// And MySQL use 4 bytes to represent the unicode which must be in [0, 65536). 
+// The implementation is taken from TiDB: +// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L136 +func decodeEscapedUnicode(s []byte) (char [4]byte, size int, err error) { + size, err = hex.Decode(char[0:2], s) + if err != nil || size != 2 { + // The unicode must can be represented in 2 bytes. + return char, 0, err + } + var unicode uint16 + err = binary.Read(bytes.NewReader(char[0:2]), binary.BigEndian, &unicode) + if err != nil { + return char, 0, err + } + size = utf8.RuneLen(rune(unicode)) + utf8.EncodeRune(char[0:size], rune(unicode)) + return +} diff --git a/sql/expression/function/json_unquote_test.go b/sql/expression/function/json_unquote_test.go new file mode 100644 index 000000000..d5d054f10 --- /dev/null +++ b/sql/expression/function/json_unquote_test.go @@ -0,0 +1,37 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestJSONUnquote(t *testing.T) { + require := require.New(t) + js := NewJSONUnquote(expression.NewGetField(0, sql.Text, "json", false)) + + testCases := []struct { + row sql.Row + expected interface{} + err bool + }{ + {sql.Row{nil}, nil, false}, + {sql.Row{"\"abc\""}, `abc`, false}, + {sql.Row{"[1, 2, 3]"}, `[1, 2, 3]`, false}, + {sql.Row{"\"\t\u0032\""}, "\t2", false}, + {sql.Row{"\\"}, nil, true}, + } + + for _, tt := range testCases { + result, err := js.Eval(sql.NewEmptyContext(), tt.row) + + if !tt.err { + require.NoError(err) + require.Equal(tt.expected, result) + } else { + require.NotNil(err) + } + } +} diff --git a/sql/expression/function/length.go b/sql/expression/function/length.go new file mode 100644 index 000000000..49d46aaf8 --- /dev/null +++ b/sql/expression/function/length.go @@ -0,0 +1,91 @@ +package function + +import ( + "fmt" + "unicode/utf8" + + "github.com/src-d/go-mysql-server/sql" + 
"github.com/src-d/go-mysql-server/sql/expression" +) + +// Length returns the length of a string or binary content, either in bytes +// or characters. +type Length struct { + expression.UnaryExpression + CountType CountType +} + +// CountType is the kind of length count. +type CountType bool + +const ( + // NumBytes counts the number of bytes in a string or binary content. + NumBytes = CountType(false) + // NumChars counts the number of characters in a string or binary content. + NumChars = CountType(true) +) + +// NewLength returns a new LENGTH function. +func NewLength(e sql.Expression) sql.Expression { + return &Length{expression.UnaryExpression{Child: e}, NumBytes} +} + +// NewCharLength returns a new CHAR_LENGTH function. +func NewCharLength(e sql.Expression) sql.Expression { + return &Length{expression.UnaryExpression{Child: e}, NumChars} +} + +// WithChildren implements the Expression interface. +func (l *Length) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) + } + + return &Length{expression.UnaryExpression{Child: children[0]}, l.CountType}, nil +} + +// Type implements the sql.Expression interface. +func (l *Length) Type() sql.Type { return sql.Int32 } + +func (l *Length) String() string { + if l.CountType == NumBytes { + return fmt.Sprintf("LENGTH(%s)", l.Child) + } + return fmt.Sprintf("CHAR_LENGTH(%s)", l.Child) +} + +// Eval implements the sql.Expression interface. 
+func (l *Length) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + val, err := l.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + return nil, nil + } + + var content string + switch l.Child.Type() { + case sql.Blob: + val, err = sql.Blob.Convert(val) + if err != nil { + return nil, err + } + + content = string(val.([]byte)) + default: + val, err = sql.Text.Convert(val) + if err != nil { + return nil, err + } + + content = string(val.(string)) + } + + if l.CountType == NumBytes { + return int32(len(content)), nil + } + + return int32(utf8.RuneCountInString(content)), nil +} diff --git a/sql/expression/function/length_test.go b/sql/expression/function/length_test.go new file mode 100644 index 000000000..59c6f95d9 --- /dev/null +++ b/sql/expression/function/length_test.go @@ -0,0 +1,104 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestLength(t *testing.T) { + testCases := []struct { + name string + input interface{} + inputType sql.Type + fn func(sql.Expression) sql.Expression + expected interface{} + }{ + { + "length string", + "fóo", + sql.Text, + NewLength, + int32(4), + }, + { + "length binary", + []byte("fóo"), + sql.Blob, + NewLength, + int32(4), + }, + { + "length empty", + "", + sql.Blob, + NewLength, + int32(0), + }, + { + "length empty binary", + []byte{}, + sql.Blob, + NewLength, + int32(0), + }, + { + "length nil", + nil, + sql.Blob, + NewLength, + nil, + }, + { + "char_length string", + "fóo", + sql.Text, + NewCharLength, + int32(3), + }, + { + "char_length binary", + []byte("fóo"), + sql.Blob, + NewCharLength, + int32(3), + }, + { + "char_length empty", + "", + sql.Blob, + NewCharLength, + int32(0), + }, + { + "char_length empty binary", + []byte{}, + sql.Blob, + NewCharLength, + int32(0), + }, + { + "char_length nil", + nil, + sql.Blob, + NewCharLength, + nil, + }, 
+ } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + result, err := tt.fn(expression.NewGetField(0, tt.inputType, "foo", false)).Eval( + sql.NewEmptyContext(), + sql.Row{tt.input}, + ) + + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go new file mode 100644 index 000000000..eb6355a83 --- /dev/null +++ b/sql/expression/function/logarithm.go @@ -0,0 +1,185 @@ +package function + +import ( + "fmt" + "math" + "reflect" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" +) + +// ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a +// logarithm function +var ErrInvalidArgumentForLogarithm = errors.NewKind("invalid argument value for logarithm: %v") + +// NewLogBaseFunc returns LogBase creator function with a specific base. +func NewLogBaseFunc(base float64) func(e sql.Expression) sql.Expression { + return func(e sql.Expression) sql.Expression { + return NewLogBase(base, e) + } +} + +// LogBase is a function that returns the logarithm of a value with a specific base. +type LogBase struct { + expression.UnaryExpression + base float64 +} + +// NewLogBase creates a new LogBase expression. +func NewLogBase(base float64, e sql.Expression) sql.Expression { + return &LogBase{UnaryExpression: expression.UnaryExpression{Child: e}, base: base} +} + +func (l *LogBase) String() string { + switch l.base { + case float64(math.E): + return fmt.Sprintf("ln(%s)", l.Child) + case float64(10): + return fmt.Sprintf("log10(%s)", l.Child) + case float64(2): + return fmt.Sprintf("log2(%s)", l.Child) + default: + return fmt.Sprintf("log(%v, %s)", l.base, l.Child) + } +} + +// WithChildren implements the Expression interface. 
+func (l *LogBase) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) + } + return NewLogBase(l.base, children[0]), nil +} + +// Type returns the resultant type of the function. +func (l *LogBase) Type() sql.Type { + return sql.Float64 +} + +// IsNullable implements the sql.Expression interface. +func (l *LogBase) IsNullable() bool { + return l.base == float64(1) || l.base <= float64(0) || l.Child.IsNullable() +} + +// Eval implements the Expression interface. +func (l *LogBase) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + v, err := l.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if v == nil { + return nil, nil + } + + val, err := sql.Float64.Convert(v) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(v)) + } + return computeLog(val.(float64), l.base) +} + +// Log is a function that returns the natural logarithm of a value. +type Log struct { + expression.BinaryExpression +} + +// NewLog creates a new Log expression. +func NewLog(args ...sql.Expression) (sql.Expression, error) { + argLen := len(args) + if argLen == 0 || argLen > 2 { + return nil, sql.ErrInvalidArgumentNumber.New("LOG", "1 or 2", argLen) + } + + if argLen == 1 { + return &Log{expression.BinaryExpression{Left: expression.NewLiteral(math.E, sql.Float64), Right: args[0]}}, nil + } else { + return &Log{expression.BinaryExpression{Left: args[0], Right: args[1]}}, nil + } +} + +func (l *Log) String() string { + return fmt.Sprintf("log(%s, %s)", l.Left, l.Right) +} + +// WithChildren implements the Expression interface. +func (l *Log) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLog(children...) +} + +// Children implements the Expression interface. +func (l *Log) Children() []sql.Expression { + return []sql.Expression{l.Left, l.Right} +} + +// Type returns the resultant type of the function. 
+func (l *Log) Type() sql.Type { + return sql.Float64 +} + +// IsNullable implements the Expression interface. +func (l *Log) IsNullable() bool { + return l.Left.IsNullable() || l.Right.IsNullable() +} + +// Eval implements the Expression interface. +func (l *Log) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + left, err := l.Left.Eval(ctx, row) + if err != nil { + return nil, err + } + + if left == nil { + return nil, nil + } + + lhs, err := sql.Float64.Convert(left) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(left)) + } + + right, err := l.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + + if right == nil { + return nil, nil + } + + rhs, err := sql.Float64.Convert(right) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(right)) + } + + // rhs becomes value, lhs becomes base + return computeLog(rhs.(float64), lhs.(float64)) +} + +func computeLog(v float64, base float64) (float64, error) { + if v <= 0 { + return float64(0), ErrInvalidArgumentForLogarithm.New(v) + } + if base == float64(1) || base <= float64(0) { + return float64(0), ErrInvalidArgumentForLogarithm.New(base) + } + switch base { + case float64(2): + return math.Log2(v), nil + case float64(10): + return math.Log10(v), nil + case math.E: + return math.Log(v), nil + default: + // LOG(BASE,V) is equivalent to LOG(V) / LOG(BASE). 
+ return float64(math.Log(v) / math.Log(base)), nil + } +} diff --git a/sql/expression/function/logarithm_test.go b/sql/expression/function/logarithm_test.go new file mode 100644 index 000000000..7c5c7d59d --- /dev/null +++ b/sql/expression/function/logarithm_test.go @@ -0,0 +1,209 @@ +package function + +import ( + "fmt" + "math" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +var epsilon = math.Nextafter(1, 2) - 1 + +func TestLn(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(0.6931471805599453), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(1.0986122886681096), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(1.791759469228055), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(2.0794415416798357), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(2.302585092994046), nil}, + } + + for _, tt := range testCases { + f := NewLogBaseFunc(math.E)(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result)) + } + }) + } + + // Test Nil + f := 
NewLogBaseFunc(math.E)(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLog2(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(1), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(1.5849625007211563), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(2.584962500721156), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(3), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(3.321928094887362), nil}, + } + + for _, tt := range testCases { + f := NewLogBaseFunc(float64(2))(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result)) + } + }) + } + + // Test Nil + f := NewLogBaseFunc(float64(2))(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLog10(t *testing.T) { + var testCases = []struct { + name string + rowType sql.Type + row 
sql.Row + expected interface{} + err *errors.Kind + }{ + {"Input value is zero", sql.Float64, sql.NewRow(0), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", sql.Float64, sql.NewRow(-1), nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", sql.Float64, sql.NewRow("2"), float64(0.3010299956639812), nil}, + {"Input value is invalid string", sql.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, + {"Input value is valid float64", sql.Float64, sql.NewRow(3), float64(0.4771212547196624), nil}, + {"Input value is valid float32", sql.Float32, sql.NewRow(float32(6)), float64(0.7781512503836436), nil}, + {"Input value is valid int64", sql.Int64, sql.NewRow(int64(8)), float64(0.9030899869919435), nil}, + {"Input value is valid int32", sql.Int32, sql.NewRow(int32(10)), float64(1), nil}, + } + + for _, tt := range testCases { + f := NewLogBaseFunc(float64(10))(expression.NewGetField(0, tt.rowType, "", false)) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result)) + } + }) + } + + // Test Nil + f := NewLogBaseFunc(float64(10))(expression.NewGetField(0, sql.Float64, "", true)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(nil)) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} + +func TestLogInvalidArguments(t *testing.T) { + _, err := NewLog() + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) + + _, err = NewLog( + expression.NewLiteral(1, sql.Float64), + expression.NewLiteral(1, sql.Float64), + expression.NewLiteral(1, sql.Float64), + ) + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) +} + +func TestLog(t *testing.T) { + var testCases = []struct { + name string + input []sql.Expression 
+ expected interface{} + err *errors.Kind + }{ + {"Input base is 1", []sql.Expression{expression.NewLiteral(float64(1), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is zero", []sql.Expression{expression.NewLiteral(float64(0), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is negative", []sql.Expression{expression.NewLiteral(float64(-5), sql.Float64), expression.NewLiteral(float64(10), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input base is valid string", []sql.Expression{expression.NewLiteral("4", sql.Text), expression.NewLiteral(float64(10), sql.Float64)}, float64(1.6609640474436813), nil}, + {"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", sql.Text), expression.NewLiteral(float64(10), sql.Float64)}, nil, sql.ErrInvalidType}, + + {"Input value is zero", []sql.Expression{expression.NewLiteral(float64(0), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input value is negative", []sql.Expression{expression.NewLiteral(float64(-9), sql.Float64)}, nil, ErrInvalidArgumentForLogarithm}, + {"Input value is valid string", []sql.Expression{expression.NewLiteral("7", sql.Text)}, float64(1.9459101490553132), nil}, + {"Input value is invalid string", []sql.Expression{expression.NewLiteral("766j", sql.Text)}, nil, sql.ErrInvalidType}, + + {"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), sql.Float64), expression.NewLiteral(float64(99), sql.Float64)}, float64(2.855108491376949), nil}, + {"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), sql.Float32), expression.NewLiteral(float64(80), sql.Float64)}, float64(2.4456556306420936), nil}, + {"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), sql.Int64), expression.NewLiteral(float64(64), sql.Float64)}, float64(2), nil}, + {"Input base is valid int32", 
[]sql.Expression{expression.NewLiteral(int32(10), sql.Int32), expression.NewLiteral(float64(100), sql.Float64)}, float64(2), nil}, + + {"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), sql.Float64), expression.NewLiteral(float64(66), sql.Float64)}, float64(2.6031788549643564), nil}, + {"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), sql.Float32), expression.NewLiteral(float64(50), sql.Float64)}, float64(3.560876795007312), nil}, + {"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), sql.Int64), expression.NewLiteral(float64(77), sql.Float64)}, float64(2.698958057527146), nil}, + {"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), sql.Int32), expression.NewLiteral(float64(40), sql.Float64)}, float64(2.6609640474436813), nil}, + } + + for _, tt := range testCases { + f, _ := NewLog(tt.input...) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), nil) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.InEpsilonf(tt.expected, result, epsilon, fmt.Sprintf("Actual is: %v", result)) + } + }) + } + + // Test Nil + f, _ := NewLog(expression.NewLiteral(nil, sql.Float64)) + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Nil(result) + require.True(f.IsNullable()) +} diff --git a/sql/expression/function/lower_upper.go b/sql/expression/function/lower_upper.go new file mode 100644 index 000000000..f0d6cdb9d --- /dev/null +++ b/sql/expression/function/lower_upper.go @@ -0,0 +1,107 @@ +package function + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Lower is a function that returns the lowercase of the text provided. 
+type Lower struct { + expression.UnaryExpression +} + +// NewLower creates a new Lower expression. +func NewLower(e sql.Expression) sql.Expression { + return &Lower{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (l *Lower) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + v, err := l.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if v == nil { + return nil, nil + } + + v, err = sql.Text.Convert(v) + if err != nil { + return nil, err + } + + return strings.ToLower(v.(string)), nil +} + +func (l *Lower) String() string { + return fmt.Sprintf("LOWER(%s)", l.Child) +} + +// WithChildren implements the Expression interface. +func (l *Lower) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) + } + return NewLower(children[0]), nil +} + +// Type implements the Expression interface. +func (l *Lower) Type() sql.Type { + return l.Child.Type() +} + +// Upper is a function that returns the UPPERCASE of the text provided. +type Upper struct { + expression.UnaryExpression +} + +// NewUpper creates a new Lower expression. +func NewUpper(e sql.Expression) sql.Expression { + return &Upper{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (u *Upper) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + v, err := u.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if v == nil { + return nil, nil + } + + v, err = sql.Text.Convert(v) + if err != nil { + return nil, err + } + + return strings.ToUpper(v.(string)), nil +} + +func (u *Upper) String() string { + return fmt.Sprintf("UPPER(%s)", u.Child) +} + +// WithChildren implements the Expression interface. 
+func (u *Upper) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) + } + return NewUpper(children[0]), nil +} + +// Type implements the Expression interface. +func (u *Upper) Type() sql.Type { + return u.Child.Type() +} diff --git a/sql/expression/function/lower_upper_test.go b/sql/expression/function/lower_upper_test.go new file mode 100644 index 000000000..e06dca8a9 --- /dev/null +++ b/sql/expression/function/lower_upper_test.go @@ -0,0 +1,61 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestLower(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + }{ + {"text nil", sql.Text, sql.NewRow(nil), nil}, + {"text ok", sql.Text, sql.NewRow("LoWeR"), "lower"}, + {"binary ok", sql.Blob, sql.NewRow([]byte("LoWeR")), "lower"}, + {"other type", sql.Int32, sql.NewRow(int32(1)), "1"}, + } + + for _, tt := range testCases { + f := NewLower(expression.NewGetField(0, tt.rowType, "", true)) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, f, tt.row)) + }) + + req := require.New(t) + req.True(f.IsNullable()) + req.Equal(tt.rowType, f.Type()) + } +} + +func TestUpper(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + }{ + {"text nil", sql.Text, sql.NewRow(nil), nil}, + {"text ok", sql.Text, sql.NewRow("UpPeR"), "UPPER"}, + {"binary ok", sql.Blob, sql.NewRow([]byte("UpPeR")), "UPPER"}, + {"other type", sql.Int32, sql.NewRow(int32(1)), "1"}, + } + + for _, tt := range testCases { + f := NewUpper(expression.NewGetField(0, tt.rowType, "", true)) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, f, tt.row)) + }) + + req := require.New(t) + 
req.True(f.IsNullable()) + req.Equal(tt.rowType, f.Type()) + } +} diff --git a/sql/expression/function/nullif.go b/sql/expression/function/nullif.go new file mode 100644 index 000000000..49b5a5d9d --- /dev/null +++ b/sql/expression/function/nullif.go @@ -0,0 +1,66 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// NullIf function compares two expressions and returns NULL if they are equal. Otherwise, the first expression is returned. +type NullIf struct { + expression.BinaryExpression +} + +// NewNullIf returns a new NULLIF UDF +func NewNullIf(ex1, ex2 sql.Expression) sql.Expression { + return &NullIf{ + expression.BinaryExpression{ + Left: ex1, + Right: ex2, + }, + } +} + +// Eval implements the Expression interface. +func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + if sql.IsNull(f.Left) && sql.IsNull(f.Right) { + return sql.Null, nil + } + + val, err := expression.NewEquals(f.Left, f.Right).Eval(ctx, row) + if err != nil { + return nil, err + } + if b, ok := val.(bool); ok && b { + return sql.Null, nil + } + + return f.Left.Eval(ctx, row) +} + +// Type implements the Expression interface. +func (f *NullIf) Type() sql.Type { + if sql.IsNull(f.Left) { + return sql.Null + } + + return f.Left.Type() +} + +// IsNullable implements the Expression interface. +func (f *NullIf) IsNullable() bool { + return true +} + +func (f *NullIf) String() string { + return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right) +} + +// WithChildren implements the Expression interface. 
+func (f *NullIf) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) + } + return NewNullIf(children[0], children[1]), nil +} diff --git a/sql/expression/function/nullif_test.go b/sql/expression/function/nullif_test.go new file mode 100644 index 000000000..c62284cbe --- /dev/null +++ b/sql/expression/function/nullif_test.go @@ -0,0 +1,43 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestNullIf(t *testing.T) { + testCases := []struct { + ex1 interface{} + ex2 interface{} + expected interface{} + }{ + {"foo", "bar", "foo"}, + {"foo", "foo", sql.Null}, + {nil, "foo", nil}, + {"foo", nil, "foo"}, + {nil, nil, nil}, + {"", nil, ""}, + } + + f := NewNullIf( + expression.NewGetField(0, sql.Text, "ex1", true), + expression.NewGetField(1, sql.Text, "ex2", true), + ) + require.Equal(t, sql.Text, f.Type()) + + var3 := sql.VarChar(3) + f = NewNullIf( + expression.NewGetField(0, var3, "ex1", true), + expression.NewGetField(1, var3, "ex2", true), + ) + require.Equal(t, var3, f.Type()) + + for _, tc := range testCases { + v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.ex1, tc.ex2)) + require.NoError(t, err) + require.Equal(t, tc.expected, v) + } +} diff --git a/sql/expression/function/regexp_matches.go b/sql/expression/function/regexp_matches.go new file mode 100644 index 000000000..417e91f5e --- /dev/null +++ b/sql/expression/function/regexp_matches.go @@ -0,0 +1,204 @@ +package function + +import ( + "fmt" + "regexp" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + errors "gopkg.in/src-d/go-errors.v1" +) + +// RegexpMatches returns the matches of a regular expression. 
+type RegexpMatches struct { + Text sql.Expression + Pattern sql.Expression + Flags sql.Expression + + cacheable bool + re *regexp.Regexp +} + +// NewRegexpMatches creates a new RegexpMatches expression. +func NewRegexpMatches(args ...sql.Expression) (sql.Expression, error) { + var r RegexpMatches + switch len(args) { + case 3: + r.Flags = args[2] + fallthrough + case 2: + r.Text = args[0] + r.Pattern = args[1] + default: + return nil, sql.ErrInvalidArgumentNumber.New("regexp_matches", "2 or 3", len(args)) + } + + if canBeCached(r.Pattern) && (r.Flags == nil || canBeCached(r.Flags)) { + r.cacheable = true + } + + return &r, nil +} + +// Type implements the sql.Expression interface. +func (r *RegexpMatches) Type() sql.Type { return sql.Array(sql.Text) } + +// IsNullable implements the sql.Expression interface. +func (r *RegexpMatches) IsNullable() bool { return true } + +// Children implements the sql.Expression interface. +func (r *RegexpMatches) Children() []sql.Expression { + var result = []sql.Expression{r.Text, r.Pattern} + if r.Flags != nil { + result = append(result, r.Flags) + } + return result +} + +// Resolved implements the sql.Expression interface. +func (r *RegexpMatches) Resolved() bool { + return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved()) +} + +// WithChildren implements the sql.Expression interface. +func (r *RegexpMatches) WithChildren(children ...sql.Expression) (sql.Expression, error) { + required := 2 + if r.Flags != nil { + required = 3 + } + + if len(children) != required { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required) + } + + return NewRegexpMatches(children...) +} + +func (r *RegexpMatches) String() string { + var args []string + for _, e := range r.Children() { + args = append(args, e.String()) + } + return fmt.Sprintf("regexp_matches(%s)", strings.Join(args, ", ")) +} + +// Eval implements the sql.Expression interface. 
+func (r *RegexpMatches) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + span, ctx := ctx.Span("function.RegexpMatches") + defer span.Finish() + + var re *regexp.Regexp + var err error + if r.cacheable { + if r.re == nil { + r.re, err = r.compileRegex(ctx, nil) + if err != nil { + return nil, err + } + + if r.re == nil { + return nil, nil + } + } + re = r.re + } else { + re, err = r.compileRegex(ctx, row) + if err != nil { + return nil, err + } + + if re == nil { + return nil, nil + } + } + + text, err := r.Text.Eval(ctx, row) + if err != nil { + return nil, err + } + + if text == nil { + return nil, nil + } + + text, err = sql.Text.Convert(text) + if err != nil { + return nil, err + } + + matches := re.FindAllStringSubmatch(text.(string), -1) + if len(matches) == 0 { + return nil, nil + } + + var result []interface{} + for _, m := range matches { + for _, sm := range m { + result = append(result, sm) + } + } + + return result, nil +} + +func (r *RegexpMatches) compileRegex(ctx *sql.Context, row sql.Row) (*regexp.Regexp, error) { + pattern, err := r.Pattern.Eval(ctx, row) + if err != nil { + return nil, err + } + + if pattern == nil { + return nil, nil + } + + pattern, err = sql.Text.Convert(pattern) + if err != nil { + return nil, err + } + + var flags string + if r.Flags != nil { + f, err := r.Flags.Eval(ctx, row) + if err != nil { + return nil, err + } + + if f == nil { + return nil, nil + } + + f, err = sql.Text.Convert(f) + if err != nil { + return nil, err + } + + flags = f.(string) + for _, f := range flags { + if !validRegexpFlags[f] { + return nil, errInvalidRegexpFlag.New(f) + } + } + + flags = fmt.Sprintf("(?%s)", flags) + } + + return regexp.Compile(flags + pattern.(string)) +} + +var errInvalidRegexpFlag = errors.NewKind("invalid regexp flag: %v") + +var validRegexpFlags = map[rune]bool{ + 'i': true, +} + +func canBeCached(e sql.Expression) bool { + var hasCols bool + expression.Inspect(e, func(e sql.Expression) bool { + if _, ok := 
e.(*expression.GetField); ok { + hasCols = true + } + return true + }) + return !hasCols +} diff --git a/sql/expression/function/regexp_matches_test.go b/sql/expression/function/regexp_matches_test.go new file mode 100644 index 000000000..4a7fc35c5 --- /dev/null +++ b/sql/expression/function/regexp_matches_test.go @@ -0,0 +1,146 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" + + errors "gopkg.in/src-d/go-errors.v1" +) + +func TestRegexpMatches(t *testing.T) { + testCases := []struct { + pattern interface{} + text interface{} + flags interface{} + expected interface{} + err *errors.Kind + }{ + { + `^foobar(.*)bye$`, + "foobarhellobye", + "", + []interface{}{"foobarhellobye", "hello"}, + nil, + }, + { + "bop", + "bopbeepbop", + "", + []interface{}{"bop", "bop"}, + nil, + }, + { + "bop", + "bopbeepBop", + "i", + []interface{}{"bop", "Bop"}, + nil, + }, + { + "bop", + "helloworld", + "", + nil, + nil, + }, + { + "foo", + "", + "", + nil, + nil, + }, + { + "", + "", + "", + []interface{}{""}, + nil, + }, + { + "bop", + nil, + "", + nil, + nil, + }, + { + "bop", + "beep", + nil, + nil, + nil, + }, + { + nil, + "bop", + "", + nil, + nil, + }, + { + "bop", + "bopbeepBop", + "ix", + nil, + errInvalidRegexpFlag, + }, + } + + t.Run("cacheable", func(t *testing.T) { + for _, tt := range testCases { + var flags sql.Expression + if tt.flags != "" { + flags = expression.NewLiteral(tt.flags, sql.Text) + } + f, err := NewRegexpMatches( + expression.NewLiteral(tt.text, sql.Text), + expression.NewLiteral(tt.pattern, sql.Text), + flags, + ) + require.NoError(t, err) + + t.Run(f.String(), func(t *testing.T) { + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), nil) + if tt.err == nil { + require.NoError(err) + require.Equal(tt.expected, result) + } else { + require.Error(err) + require.True(tt.err.Is(err)) + } + }) + } + }) + + 
t.Run("not cacheable", func(t *testing.T) { + for _, tt := range testCases { + var flags sql.Expression + if tt.flags != "" { + flags = expression.NewGetField(2, sql.Text, "x", false) + } + f, err := NewRegexpMatches( + expression.NewGetField(0, sql.Text, "x", false), + expression.NewGetField(1, sql.Text, "x", false), + flags, + ) + require.NoError(t, err) + + t.Run(f.String(), func(t *testing.T) { + require := require.New(t) + result, err := f.Eval(sql.NewEmptyContext(), sql.Row{tt.text, tt.pattern, tt.flags}) + if tt.err == nil { + require.NoError(err) + require.Equal(tt.expected, result) + } else { + require.Error(err) + require.True(tt.err.Is(err)) + } + }) + } + }) +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index dc0ebdd26..7d2586bba 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -1,32 +1,102 @@ package function import ( - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" + "math" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" ) // Defaults is the function map with all the default functions. 
-var Defaults = sql.Functions{ - "count": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewCount(e) - }), - "min": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewMin(e) - }), - "max": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewMax(e) - }), - "avg": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewAvg(e) - }), - "is_binary": sql.Function1(NewIsBinary), - "substring": sql.FunctionN(NewSubstring), - "year": sql.Function1(NewYear), - "month": sql.Function1(NewMonth), - "day": sql.Function1(NewDay), - "hour": sql.Function1(NewHour), - "minute": sql.Function1(NewMinute), - "second": sql.Function1(NewSecond), - "dayofyear": sql.Function1(NewDayOfYear), - "array_length": sql.Function1(NewArrayLength), +var Defaults = []sql.Function{ + sql.Function1{ + Name: "count", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewCount(e) }, + }, + sql.Function1{ + Name: "min", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMin(e) }, + }, + sql.Function1{ + Name: "max", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMax(e) }, + }, + sql.Function1{ + Name: "avg", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) }, + }, + sql.Function1{ + Name: "sum", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) }, + }, + sql.Function1{ + Name: "first", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewFirst(e) }, + }, + sql.Function1{ + Name: "last", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewLast(e) }, + }, + sql.Function1{Name: "is_binary", Fn: NewIsBinary}, + sql.FunctionN{Name: "substring", Fn: NewSubstring}, + sql.Function3{Name: "substring_index", Fn: NewSubstringIndex}, + sql.FunctionN{Name: "mid", Fn: NewSubstring}, + sql.FunctionN{Name: "substr", Fn: NewSubstring}, + sql.Function1{Name: "date", Fn: NewDate}, + 
sql.Function1{Name: "year", Fn: NewYear}, + sql.Function1{Name: "month", Fn: NewMonth}, + sql.Function1{Name: "day", Fn: NewDay}, + sql.Function1{Name: "weekday", Fn: NewWeekday}, + sql.Function1{Name: "hour", Fn: NewHour}, + sql.Function1{Name: "minute", Fn: NewMinute}, + sql.Function1{Name: "second", Fn: NewSecond}, + sql.Function1{Name: "dayofweek", Fn: NewDayOfWeek}, + sql.Function1{Name: "dayofmonth", Fn: NewDay}, + sql.Function1{Name: "dayofyear", Fn: NewDayOfYear}, + sql.FunctionN{Name: "yearweek", Fn: NewYearWeek}, + sql.Function1{Name: "array_length", Fn: NewArrayLength}, + sql.Function2{Name: "split", Fn: NewSplit}, + sql.FunctionN{Name: "concat", Fn: NewConcat}, + sql.FunctionN{Name: "concat_ws", Fn: NewConcatWithSeparator}, + sql.FunctionN{Name: "coalesce", Fn: NewCoalesce}, + sql.Function1{Name: "lower", Fn: NewLower}, + sql.Function1{Name: "upper", Fn: NewUpper}, + sql.Function1{Name: "ceiling", Fn: NewCeil}, + sql.Function1{Name: "ceil", Fn: NewCeil}, + sql.Function1{Name: "floor", Fn: NewFloor}, + sql.FunctionN{Name: "round", Fn: NewRound}, + sql.Function0{Name: "connection_id", Fn: NewConnectionID}, + sql.Function1{Name: "soundex", Fn: NewSoundex}, + sql.FunctionN{Name: "json_extract", Fn: NewJSONExtract}, + sql.Function1{Name: "json_unquote", Fn: NewJSONUnquote}, + sql.Function1{Name: "ln", Fn: NewLogBaseFunc(float64(math.E))}, + sql.Function1{Name: "log2", Fn: NewLogBaseFunc(float64(2))}, + sql.Function1{Name: "log10", Fn: NewLogBaseFunc(float64(10))}, + sql.FunctionN{Name: "log", Fn: NewLog}, + sql.FunctionN{Name: "rpad", Fn: NewPadFunc(rPadType)}, + sql.FunctionN{Name: "lpad", Fn: NewPadFunc(lPadType)}, + sql.Function1{Name: "sqrt", Fn: NewSqrt}, + sql.Function2{Name: "pow", Fn: NewPower}, + sql.Function2{Name: "power", Fn: NewPower}, + sql.Function1{Name: "ltrim", Fn: NewTrimFunc(lTrimType)}, + sql.Function1{Name: "rtrim", Fn: NewTrimFunc(rTrimType)}, + sql.Function1{Name: "trim", Fn: NewTrimFunc(bTrimType)}, + sql.Function1{Name: "reverse", 
Fn: NewReverse}, + sql.Function2{Name: "repeat", Fn: NewRepeat}, + sql.Function3{Name: "replace", Fn: NewReplace}, + sql.Function2{Name: "ifnull", Fn: NewIfNull}, + sql.Function2{Name: "nullif", Fn: NewNullIf}, + sql.Function0{Name: "now", Fn: NewNow}, + sql.Function1{Name: "sleep", Fn: NewSleep}, + sql.Function1{Name: "to_base64", Fn: NewToBase64}, + sql.Function1{Name: "from_base64", Fn: NewFromBase64}, + sql.FunctionN{Name: "date_add", Fn: NewDateAdd}, + sql.FunctionN{Name: "date_sub", Fn: NewDateSub}, + sql.FunctionN{Name: "greatest", Fn: NewGreatest}, + sql.FunctionN{Name: "least", Fn: NewLeast}, + sql.Function1{Name: "length", Fn: NewLength}, + sql.Function1{Name: "char_length", Fn: NewCharLength}, + sql.Function1{Name: "character_length", Fn: NewCharLength}, + sql.Function1{Name: "explode", Fn: NewExplode}, + sql.FunctionN{Name: "regexp_matches", Fn: NewRegexpMatches}, } diff --git a/sql/expression/function/reverse_repeat_replace.go b/sql/expression/function/reverse_repeat_replace.go new file mode 100644 index 000000000..cef9e9691 --- /dev/null +++ b/sql/expression/function/reverse_repeat_replace.go @@ -0,0 +1,209 @@ +package function + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" +) + +// Reverse is a function that returns the reverse of the text provided. +type Reverse struct { + expression.UnaryExpression +} + +// NewReverse creates a new Reverse expression. +func NewReverse(e sql.Expression) sql.Expression { + return &Reverse{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. 
+func (r *Reverse) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + v, err := r.Child.Eval(ctx, row) + if v == nil || err != nil { + return nil, err + } + + v, err = sql.Text.Convert(v) + if err != nil { + return nil, err + } + + return reverseString(v.(string)), nil +} + +func reverseString(s string) string { + r := []rune(s) + for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 { + r[i], r[j] = r[j], r[i] + } + return string(r) +} + +func (r *Reverse) String() string { + return fmt.Sprintf("reverse(%s)", r.Child) +} + +// WithChildren implements the Expression interface. +func (r *Reverse) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1) + } + return NewReverse(children[0]), nil +} + +// Type implements the Expression interface. +func (r *Reverse) Type() sql.Type { + return r.Child.Type() +} + +var ErrNegativeRepeatCount = errors.NewKind("negative Repeat count: %v") + +// Repeat is a function that returns the string repeated n times. +type Repeat struct { + expression.BinaryExpression +} + +// NewRepeat creates a new Repeat expression. +func NewRepeat(str sql.Expression, count sql.Expression) sql.Expression { + return &Repeat{expression.BinaryExpression{Left: str, Right: count}} +} + +func (r *Repeat) String() string { + return fmt.Sprintf("repeat(%s, %s)", r.Left, r.Right) +} + +// Type implements the Expression interface. +func (r *Repeat) Type() sql.Type { + return sql.Text +} + +// WithChildren implements the Expression interface. +func (r *Repeat) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 2) + } + return NewRepeat(children[0], children[1]), nil +} + +// Eval implements the Expression interface. 
+func (r *Repeat) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + str, err := r.Left.Eval(ctx, row) + if str == nil || err != nil { + return nil, err + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, err + } + + count, err := r.Right.Eval(ctx, row) + if count == nil || err != nil { + return nil, err + } + + count, err = sql.Int32.Convert(count) + if err != nil { + return nil, err + } + if count.(int32) < 0 { + return nil, ErrNegativeRepeatCount.New(count) + } + return strings.Repeat(str.(string), int(count.(int32))), nil +} + +// Replace is a function that returns a string with all occurrences of fromStr replaced by the +// string toStr +type Replace struct { + str sql.Expression + fromStr sql.Expression + toStr sql.Expression +} + +// NewReplace creates a new Replace expression. +func NewReplace(str sql.Expression, fromStr sql.Expression, toStr sql.Expression) sql.Expression { + return &Replace{str, fromStr, toStr} +} + +// Children implements the Expression interface. +func (r *Replace) Children() []sql.Expression { + return []sql.Expression{r.str, r.fromStr, r.toStr} +} + +// Resolved implements the Expression interface. +func (r *Replace) Resolved() bool { + return r.str.Resolved() && r.fromStr.Resolved() && r.toStr.Resolved() +} + +// IsNullable implements the Expression interface. +func (r *Replace) IsNullable() bool { + return r.str.IsNullable() || r.fromStr.IsNullable() || r.toStr.IsNullable() +} + +func (r *Replace) String() string { + return fmt.Sprintf("replace(%s, %s, %s)", r.str, r.fromStr, r.toStr) +} + +// Type implements the Expression interface. +func (r *Replace) Type() sql.Type { + return sql.Text +} + +// WithChildren implements the Expression interface. 
+func (r *Replace) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 3) + } + return NewReplace(children[0], children[1], children[2]), nil +} + +// Eval implements the Expression interface. +func (r *Replace) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + str, err := r.str.Eval(ctx, row) + if str == nil || err != nil { + return nil, err + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, err + } + + fromStr, err := r.fromStr.Eval(ctx, row) + if fromStr == nil || err != nil { + return nil, err + } + + fromStr, err = sql.Text.Convert(fromStr) + if err != nil { + return nil, err + } + + toStr, err := r.toStr.Eval(ctx, row) + if toStr == nil || err != nil { + return nil, err + } + + toStr, err = sql.Text.Convert(toStr) + if err != nil { + return nil, err + } + + if fromStr.(string) == "" { + return str, nil + } + + return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil +} diff --git a/sql/expression/function/reverse_repeat_replace_test.go b/sql/expression/function/reverse_repeat_replace_test.go new file mode 100644 index 000000000..f04b16bc6 --- /dev/null +++ b/sql/expression/function/reverse_repeat_replace_test.go @@ -0,0 +1,110 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestReverse(t *testing.T) { + f := NewReverse(expression.NewGetField(0, sql.Text, "", false)) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"empty string", sql.NewRow(""), "", false}, + {"handles numbers as strings", sql.NewRow(123), "321", false}, + {"valid string", sql.NewRow("foobar"), "raboof", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + 
t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestRepeat(t *testing.T) { + f := NewRepeat( + expression.NewGetField(0, sql.Text, "", false), + expression.NewGetField(1, sql.Int32, "", false), + ) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"empty string", sql.NewRow("", 2), "", false}, + {"count is zero", sql.NewRow("foo", 0), "", false}, + {"count is negative", sql.NewRow("foo", -2), "foo", true}, + {"valid string", sql.NewRow("foobar", 2), "foobarfoobar", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestReplace(t *testing.T) { + f := NewReplace( + expression.NewGetField(0, sql.Text, "", false), + expression.NewGetField(1, sql.Text, "", false), + expression.NewGetField(2, sql.Text, "", false), + ) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null inputs", sql.NewRow(nil), nil, false}, + {"empty str", sql.NewRow("", "foo", "bar"), "", false}, + {"empty fromStr", sql.NewRow("foobarfoobar", "", "car"), "foobarfoobar", false}, + {"empty toStr", sql.NewRow("foobarfoobar", "bar", ""), "foofoo", false}, + {"valid strings", sql.NewRow("foobarfoobar", "bar", "car"), "foocarfoocar", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + 
} +} diff --git a/sql/expression/function/rpad_lpad.go b/sql/expression/function/rpad_lpad.go new file mode 100644 index 000000000..12b33695b --- /dev/null +++ b/sql/expression/function/rpad_lpad.go @@ -0,0 +1,158 @@ +package function + +import ( + "fmt" + "reflect" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var ErrDivisionByZero = errors.NewKind("division by zero") + +type padType rune + +const ( + lPadType padType = 'l' + rPadType padType = 'r' +) + +// NewPadFunc returns a Pad creator function with a specific padType. +func NewPadFunc(pType padType) func(e ...sql.Expression) (sql.Expression, error) { + return func(e ...sql.Expression) (sql.Expression, error) { + return NewPad(pType, e...) + } +} + +// NewPad creates a new Pad expression. +func NewPad(pType padType, args ...sql.Expression) (sql.Expression, error) { + argLen := len(args) + if argLen != 3 { + return nil, sql.ErrInvalidArgumentNumber.New(string(pType)+"pad", "3", argLen) + } + + return &Pad{args[0], args[1], args[2], pType}, nil +} + +// Pad is a function that pads a string with another string. +type Pad struct { + str sql.Expression + length sql.Expression + padStr sql.Expression + padType padType +} + +// Children implements the Expression interface. +func (p *Pad) Children() []sql.Expression { + return []sql.Expression{p.str, p.length, p.padStr} +} + +// Resolved implements the Expression interface. +func (p *Pad) Resolved() bool { + return p.str.Resolved() && p.length.Resolved() && (p.padStr.Resolved()) +} + +// IsNullable implements the Expression interface. +func (p *Pad) IsNullable() bool { + return p.str.IsNullable() || p.length.IsNullable() || p.padStr.IsNullable() +} + +// Type implements the Expression interface. 
+func (p *Pad) Type() sql.Type { return sql.Text } + +func (p *Pad) String() string { + if p.padType == lPadType { + return fmt.Sprintf("lpad(%s, %s, %s)", p.str, p.length, p.padStr) + } + return fmt.Sprintf("rpad(%s, %s, %s)", p.str, p.length, p.padStr) +} + +// WithChildren implements the Expression interface. +func (p *Pad) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewPad(p.padType, children...) +} + +// Eval implements the Expression interface. +func (p *Pad) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + str, err := p.str.Eval(ctx, row) + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + length, err := p.length.Eval(ctx, row) + if err != nil { + return nil, err + } + + if length == nil { + return nil, nil + } + + length, err = sql.Int64.Convert(length) + if err != nil { + return nil, err + } + + padStr, err := p.padStr.Eval(ctx, row) + if err != nil { + return nil, err + } + + if padStr == nil { + return nil, nil + } + + padStr, err = sql.Text.Convert(padStr) + if err != nil { + return nil, err + } + + return padString(str.(string), length.(int64), padStr.(string), p.padType) +} + +func padString(str string, length int64, padStr string, padType padType) (string, error) { + if length <= 0 { + return "", nil + } + if int64(len(str)) >= length { + return str[:length], nil + } + if len(padStr) == 0 { + return "", nil + } + + padLen := int(length - int64(len(str))) + quo, rem, err := divmod(int64(padLen), int64(len(padStr))) + if err != nil { + return "", err + } + + if padType == lPadType { + result := strings.Repeat(padStr, int(quo)) + padStr[:rem] + str + return result[:length], nil + } + result := str + strings.Repeat(padStr, int(quo)) + padStr[:rem] + return result[(int64(len(result)) - length):], nil +} + +func divmod(a, b int64) (quotient, remainder 
int64, err error) { + if b == 0 { + return 0, 0, ErrDivisionByZero.New() + } + quotient = a / b + remainder = a % b + return +} diff --git a/sql/expression/function/rpad_lpad_test.go b/sql/expression/function/rpad_lpad_test.go new file mode 100644 index 000000000..81e1dbcae --- /dev/null +++ b/sql/expression/function/rpad_lpad_test.go @@ -0,0 +1,109 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestLPad(t *testing.T) { + f, err := NewPad( + lPadType, + expression.NewGetField(0, sql.Text, "str", false), + expression.NewGetField(1, sql.Int64, "len", false), + expression.NewGetField(2, sql.Text, "padStr", false), + ) + require.NoError(t, err) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null string", sql.NewRow(nil, 1, "bar"), nil, false}, + {"null len", sql.NewRow("foo", nil, "bar"), nil, false}, + {"null padStr", sql.NewRow("foo", 1, nil), nil, false}, + + {"negative length", sql.NewRow("foo", -1, "bar"), "", false}, + {"length 0", sql.NewRow("foo", 0, "bar"), "", false}, + {"invalid length", sql.NewRow("foo", "a", "bar"), "", true}, + + {"empty padStr and len < len(str)", sql.NewRow("foo", 1, ""), "f", false}, + {"empty padStr and len > len(str)", sql.NewRow("foo", 4, ""), "", false}, + {"empty padStr and len == len(str)", sql.NewRow("foo", 3, ""), "foo", false}, + + {"non empty padStr and len < len(str)", sql.NewRow("foo", 1, "abcd"), "f", false}, + {"non empty padStr and len == len(str)", sql.NewRow("foo", 3, "abcd"), "foo", false}, + + {"padStr repeats exactly once", sql.NewRow("foo", 6, "abc"), "abcfoo", false}, + {"padStr does not repeat once", sql.NewRow("foo", 5, "abc"), "abfoo", false}, + {"padStr repeats many times", sql.NewRow("foo", 10, "abc"), "abcabcafoo", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + 
require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestRPad(t *testing.T) { + f, err := NewPad( + rPadType, + expression.NewGetField(0, sql.Text, "str", false), + expression.NewGetField(1, sql.Int64, "len", false), + expression.NewGetField(2, sql.Text, "padStr", false), + ) + require.NoError(t, err) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null string", sql.NewRow(nil, 1, "bar"), nil, false}, + {"null len", sql.NewRow("foo", nil, "bar"), nil, false}, + {"null padStr", sql.NewRow("foo", 1, nil), nil, false}, + + {"negative length", sql.NewRow("foo", -1, "bar"), "", false}, + {"length 0", sql.NewRow("foo", 0, "bar"), "", false}, + {"invalid length", sql.NewRow("foo", "a", "bar"), "", true}, + + {"empty padStr and len < len(str)", sql.NewRow("foo", 1, ""), "f", false}, + {"empty padStr and len > len(str)", sql.NewRow("foo", 4, ""), "", false}, + {"empty padStr and len == len(str)", sql.NewRow("foo", 3, ""), "foo", false}, + + {"non empty padStr and len < len(str)", sql.NewRow("foo", 1, "abcd"), "f", false}, + {"non empty padStr and len == len(str)", sql.NewRow("foo", 3, "abcd"), "foo", false}, + + {"padStr repeats exactly once", sql.NewRow("foo", 6, "abc"), "fooabc", false}, + {"padStr does not repeat once", sql.NewRow("foo", 5, "abc"), "fooab", false}, + {"padStr repeats many times", sql.NewRow("foo", 10, "abc"), "fooabcabca", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} diff --git a/sql/expression/function/sleep.go b/sql/expression/function/sleep.go new file mode 100644 index 
000000000..230a9596f --- /dev/null +++ b/sql/expression/function/sleep.go @@ -0,0 +1,73 @@ +package function + +import ( + "context" + "fmt" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Sleep is a function that just waits for the specified number of seconds +// and returns 0. +// It can be useful to test timeouts or long queries. +type Sleep struct { + expression.UnaryExpression +} + +// NewSleep creates a new Sleep expression. +func NewSleep(e sql.Expression) sql.Expression { + return &Sleep{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (s *Sleep) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := s.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + child, err = sql.Float64.Convert(child) + if err != nil { + return nil, err + } + + t := time.NewTimer(time.Duration(child.(float64)*1000) * time.Millisecond) + defer t.Stop() + + select { + case <-ctx.Done(): + return 0, context.Canceled + case <-t.C: + return 0, nil + } +} + +// String implements the Stringer interface. +func (s *Sleep) String() string { + return fmt.Sprintf("SLEEP(%s)", s.Child) +} + +// IsNullable implements the Expression interface. +func (s *Sleep) IsNullable() bool { + return false +} + +// WithChildren implements the Expression interface. +func (s *Sleep) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + return NewSleep(children[0]), nil +} + +// Type implements the Expression interface. 
+func (s *Sleep) Type() sql.Type { + return sql.Int32 +} diff --git a/sql/expression/function/sleep_test.go b/sql/expression/function/sleep_test.go new file mode 100644 index 000000000..65ce02478 --- /dev/null +++ b/sql/expression/function/sleep_test.go @@ -0,0 +1,50 @@ +package function + +import ( + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestSleep(t *testing.T) { + f := NewSleep( + expression.NewGetField(0, sql.Text, "n", false), + ) + testCases := []struct { + name string + row sql.Row + expected interface{} + waitTime float64 + err bool + }{ + {"null input", sql.NewRow(nil), nil, 0, false}, + {"string input", sql.NewRow("foo"), nil, 0, true}, + {"int input", sql.NewRow(3), int(0), 3.0, false}, + {"number is zero", sql.NewRow(0), int(0), 0, false}, + {"negative number", sql.NewRow(-4), int(0), 0, false}, + {"positive number", sql.NewRow(4.48), int(0), 4.48, false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + t1 := time.Now() + v, err := f.Eval(ctx, tt.row) + t2 := time.Now() + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + + waited := t2.Sub(t1).Seconds() + require.InDelta(waited, tt.waitTime, 0.1) + } + }) + } +} diff --git a/sql/expression/function/soundex.go b/sql/expression/function/soundex.go new file mode 100644 index 000000000..37774228e --- /dev/null +++ b/sql/expression/function/soundex.go @@ -0,0 +1,101 @@ +package function + +import ( + "fmt" + "strings" + "unicode" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Soundex is a function that returns the soundex of a string. Two strings that +// sound almost the same should have identical soundex strings. 
A standard +// soundex string is four characters long, but the SOUNDEX() function returns +// an arbitrarily long string. +type Soundex struct { + expression.UnaryExpression +} + +// NewSoundex creates a new Soundex expression. +func NewSoundex(e sql.Expression) sql.Expression { + return &Soundex{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (s *Soundex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + v, err := s.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if v == nil { + return nil, nil + } + + v, err = sql.Text.Convert(v) + if err != nil { + return nil, err + } + + var b strings.Builder + var last rune + for _, c := range strings.ToUpper(v.(string)) { + if last == 0 && !unicode.IsLetter(c) { + continue + } + code := s.code(c) + if last == 0 { + b.WriteRune(c) + last = code + continue + } + if code == '0' || code == last { + continue + } + b.WriteRune(code) + last = code + } + if b.Len() == 0 { + return "0000", nil + } + for i := len([]rune(b.String())); i < 4; i++ { + b.WriteRune('0') + } + return b.String(), nil +} + +func (s *Soundex) code(c rune) rune { + switch c { + case 'B', 'F', 'P', 'V': + return '1' + case 'C', 'G', 'J', 'K', 'Q', 'S', 'X', 'Z': + return '2' + case 'D', 'T': + return '3' + case 'L': + return '4' + case 'M', 'N': + return '5' + case 'R': + return '6' + } + return '0' +} + +func (s *Soundex) String() string { + return fmt.Sprintf("SOUNDEX(%s)", s.Child) +} + +// WithChildren implements the Expression interface. +func (s *Soundex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + return NewSoundex(children[0]), nil +} + +// Type implements the Expression interface. 
+func (s *Soundex) Type() sql.Type { + return sql.Text +} diff --git a/sql/expression/function/soundex_test.go b/sql/expression/function/soundex_test.go new file mode 100644 index 000000000..eb5739e56 --- /dev/null +++ b/sql/expression/function/soundex_test.go @@ -0,0 +1,50 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestSoundex(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + }{ + {"text nil", sql.Text, sql.NewRow(nil), nil}, + {"text empty", sql.Text, sql.NewRow(""), "0000"}, + {"text ignored character", sql.Text, sql.NewRow("-"), "0000"}, + {"text runes", sql.Text, sql.NewRow("日本語"), "日000"}, + {"text Hello ok", sql.Text, sql.NewRow("Hello"), "H400"}, + {"text Quadratically ok", sql.Text, sql.NewRow("Quadratically"), "Q36324"}, + {"text Lee ok", sql.Text, sql.NewRow("Lee"), "L000"}, + {"text McKnockitter ok", sql.Text, sql.NewRow("McKnockitter"), "M25236"}, + {"text Honeyman ok", sql.Text, sql.NewRow("Honeyman"), "H500"}, + {"text Munn ok", sql.Text, sql.NewRow("Munn"), "M000"}, + {"text Poppett ok", sql.Text, sql.NewRow("Poppett"), "P300"}, + {"text Peachman ok", sql.Text, sql.NewRow("Peachman"), "P250"}, + {"text Cochrane ok", sql.Text, sql.NewRow("Cochrane"), "C650"}, + {"text Chesley ok", sql.Text, sql.NewRow("Chesley"), "C400"}, + {"text Tachenion ok", sql.Text, sql.NewRow("Tachenion"), "T250"}, + {"text Wilcox ok", sql.Text, sql.NewRow("Wilcox"), "W420"}, + {"binary ok", sql.Text, sql.NewRow([]byte("Harvey")), "H610"}, + {"string one", sql.Text, sql.NewRow("1"), "0000"}, + {"other type", sql.Text, sql.NewRow(int32(1)), "0000"}, + } + + for _, tt := range testCases { + f := NewSoundex(expression.NewGetField(0, tt.rowType, "", true)) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, f, tt.row)) + }) + + req := 
require.New(t) + req.True(f.IsNullable()) + req.Equal(tt.rowType, f.Type()) + } +} diff --git a/sql/expression/function/split.go b/sql/expression/function/split.go index 0c9955b7b..20e2a49f9 100644 --- a/sql/expression/function/split.go +++ b/sql/expression/function/split.go @@ -4,8 +4,8 @@ import ( "fmt" "regexp" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Split receives a string and returns the parts of it splitted by a @@ -76,17 +76,10 @@ func (f *Split) String() string { return fmt.Sprintf("split(%s, %s)", f.Left, f.Right) } -// TransformUp implements the Expression interface. -func (f *Split) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { - left, err := f.Left.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (f *Split) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) } - - right, err := f.Right.TransformUp(fn) - if err != nil { - return nil, err - } - - return fn(NewSplit(left, right)) + return NewSplit(children[0], children[1]), nil } diff --git a/sql/expression/function/split_test.go b/sql/expression/function/split_test.go index 0de25db93..478d3b749 100644 --- a/sql/expression/function/split_test.go +++ b/sql/expression/function/split_test.go @@ -3,9 +3,9 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestSplit(t *testing.T) { diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go new file mode 100644 index 000000000..020c8b7bf --- /dev/null +++ 
b/sql/expression/function/sqrt_power.go @@ -0,0 +1,127 @@ +package function + +import ( + "fmt" + "math" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Sqrt is a function that returns the square value of the number provided. +type Sqrt struct { + expression.UnaryExpression +} + +// NewSqrt creates a new Sqrt expression. +func NewSqrt(e sql.Expression) sql.Expression { + return &Sqrt{expression.UnaryExpression{Child: e}} +} + +func (s *Sqrt) String() string { + return fmt.Sprintf("sqrt(%s)", s.Child.String()) +} + +// Type implements the Expression interface. +func (s *Sqrt) Type() sql.Type { + return sql.Float64 +} + +// IsNullable implements the Expression interface. +func (s *Sqrt) IsNullable() bool { + return s.Child.IsNullable() +} + +// WithChildren implements the Expression interface. +func (s *Sqrt) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + return NewSqrt(children[0]), nil +} + +// Eval implements the Expression interface. +func (s *Sqrt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := s.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + child, err = sql.Float64.Convert(child) + if err != nil { + return nil, err + } + + return math.Sqrt(child.(float64)), nil +} + +// Power is a function that returns value of X raised to the power of Y. +type Power struct { + expression.BinaryExpression +} + +// NewPower creates a new Power expression. +func NewPower(e1, e2 sql.Expression) sql.Expression { + return &Power{ + expression.BinaryExpression{ + Left: e1, + Right: e2, + }, + } +} + +// Type implements the Expression interface. +func (p *Power) Type() sql.Type { return sql.Float64 } + +// IsNullable implements the Expression interface. 
+func (p *Power) IsNullable() bool { return p.Left.IsNullable() || p.Right.IsNullable() } + +func (p *Power) String() string { + return fmt.Sprintf("power(%s, %s)", p.Left, p.Right) +} + +// WithChildren implements the Expression interface. +func (p *Power) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) + } + return NewPower(children[0], children[1]), nil +} + +// Eval implements the Expression interface. +func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + left, err := p.Left.Eval(ctx, row) + if err != nil { + return nil, err + } + + if left == nil { + return nil, nil + } + + left, err = sql.Float64.Convert(left) + if err != nil { + return nil, err + } + + right, err := p.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + + if right == nil { + return nil, nil + } + + right, err = sql.Float64.Convert(right) + if err != nil { + return nil, err + } + + return math.Pow(left.(float64), right.(float64)), nil +} diff --git a/sql/expression/function/sqrt_power_test.go b/sql/expression/function/sqrt_power_test.go new file mode 100644 index 000000000..c148a445b --- /dev/null +++ b/sql/expression/function/sqrt_power_test.go @@ -0,0 +1,110 @@ +package function + +import ( + "math" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestSqrt(t *testing.T) { + f := NewSqrt( + expression.NewGetField(0, sql.Float64, "n", false), + ) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"invalid string", sql.NewRow("foo"), nil, true}, + {"valid string", sql.NewRow("9"), float64(3), false}, + {"number is zero", sql.NewRow(0), float64(0), false}, + {"positive number", sql.NewRow(8), float64(2.8284271247461903), false}, + } + for _, tt := range 
testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } + + // Test negative number + f = NewSqrt( + expression.NewGetField(0, sql.Float64, "n", false), + ) + require := require.New(t) + v, err := f.Eval(sql.NewEmptyContext(), []interface{}{float64(-4)}) + require.NoError(err) + require.IsType(float64(0), v) + require.True(math.IsNaN(v.(float64))) +} + +func TestPower(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + row sql.Row + expected interface{} + err bool + }{ + {"Base and exp are nil", sql.Float64, sql.NewRow(nil, nil), nil, false}, + {"Base is nil", sql.Float64, sql.NewRow(2, nil), nil, false}, + {"Exp is nil", sql.Float64, sql.NewRow(nil, 2), nil, false}, + + {"Base is 0", sql.Float64, sql.NewRow(0, 2), float64(0), false}, + {"Base and exp is 0", sql.Float64, sql.NewRow(0, 0), float64(1), false}, + {"Exp is 0", sql.Float64, sql.NewRow(2, 0), float64(1), false}, + {"Base is negative", sql.Float64, sql.NewRow(-2, 2), float64(4), false}, + {"Exp is negative", sql.Float64, sql.NewRow(2, -2), float64(0.25), false}, + {"Base and exp are invalid strings", sql.Float64, sql.NewRow("a", "b"), nil, true}, + {"Base and exp are valid strings", sql.Float64, sql.NewRow("2", "2"), float64(4), false}, + } + for _, tt := range testCases { + f := NewPower( + expression.NewGetField(0, tt.rowType, "", false), + expression.NewGetField(1, tt.rowType, "", false), + ) + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } + + // Test inf numbers + f := NewPower( + expression.NewGetField(0, sql.Float64, "", false), + expression.NewGetField(1, 
sql.Float64, "", false), + ) + require := require.New(t) + v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(2, math.Inf(1))) + require.NoError(err) + require.IsType(float64(0), v) + require.True(math.IsInf(v.(float64), 1)) + + v, err = f.Eval(sql.NewEmptyContext(), sql.NewRow(math.Inf(1), 2)) + require.NoError(err) + require.IsType(float64(0), v) + require.True(math.IsInf(v.(float64), 1)) +} diff --git a/sql/expression/function/substring.go b/sql/expression/function/substring.go index a9ddfc7c7..c5227b9bc 100644 --- a/sql/expression/function/substring.go +++ b/sql/expression/function/substring.go @@ -3,8 +3,9 @@ package function import ( "fmt" "reflect" + "strings" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Substring is a function to return a part of a string. @@ -32,7 +33,7 @@ func NewSubstring(args ...sql.Expression) (sql.Expression, error) { start = args[1] ln = args[2] default: - return nil, sql.ErrInvalidArgumentNumber.New("2 or 3", len(args)) + return nil, sql.ErrInvalidArgumentNumber.New("SUBSTRING", "2 or 3", len(args)) } return &Substring{str, start, ln}, nil } @@ -50,9 +51,6 @@ func (s *Substring) Eval( ctx *sql.Context, row sql.Row, ) (interface{}, error) { - span, ctx := ctx.Span("function.Substring") - defer span.Finish() - str, err := s.str.Eval(ctx, row) if err != nil { return nil, err @@ -144,30 +142,118 @@ func (s *Substring) Resolved() bool { // Type implements the Expression interface. func (*Substring) Type() sql.Type { return sql.Text } -// TransformUp implements the Expression interface. -func (s *Substring) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - str, err := s.str.TransformUp(f) +// WithChildren implements the Expression interface. +func (*Substring) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewSubstring(children...) +} + +// SubstringIndex returns the substring from string str before count occurrences of the delimiter delim. 
+// If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +// If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +// SUBSTRING_INDEX() performs a case-sensitive match when searching for delim. +type SubstringIndex struct { + str sql.Expression + delim sql.Expression + count sql.Expression +} + +// NewSubstringIndex creates a new SubstringIndex UDF. +func NewSubstringIndex(str, delim, count sql.Expression) sql.Expression { + return &SubstringIndex{str, delim, count} +} + +// Children implements the Expression interface. +func (s *SubstringIndex) Children() []sql.Expression { + return []sql.Expression{s.str, s.delim, s.count} +} + +// Eval implements the Expression interface. +func (s *SubstringIndex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + ex, err := s.str.Eval(ctx, row) + if ex == nil || err != nil { + return nil, err + } + ex, err = sql.Text.Convert(ex) if err != nil { return nil, err } + str, ok := ex.(string) + if !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String()) + } - start, err := s.start.TransformUp(f) + ex, err = s.delim.Eval(ctx, row) + if ex == nil || err != nil { + return nil, err + } + ex, err = sql.Text.Convert(ex) if err != nil { return nil, err } + delim, ok := ex.(string) + if !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String()) + } - // It is safe to omit the errors of NewSubstring here because to be able to call - // this method, you need a valid instance of Substring, so the arity must be correct - // and that's the only error NewSubstring can return. 
- var sub sql.Expression - if s.len != nil { - len, err := s.len.TransformUp(f) - if err != nil { - return nil, err + ex, err = s.count.Eval(ctx, row) + if ex == nil || err != nil { + return nil, err + } + ex, err = sql.Int64.Convert(ex) + if err != nil { + return nil, err + } + count, ok := ex.(int64) + if !ok { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String()) + } + + // Implementation taken from pingcap/tidb + // https://github.com/pingcap/tidb/blob/37c128b64f3ad2f08d52bc767b6e3320ecf429d8/expression/builtin_string.go#L1229 + strs := strings.Split(str, delim) + start, end := int64(0), int64(len(strs)) + if count > 0 { + // If count is positive, everything to the left of the final delimiter (counting from the left) is returned. + if count < end { + end = count } - sub, _ = NewSubstring(str, start, len) } else { - sub, _ = NewSubstring(str, start) + // If count is negative, everything to the right of the final delimiter (counting from the right) is returned. + count = -count + if count < 0 { + // -count overflows max int64, returns an empty string. + return "", nil + } + + if count < end { + start = end - count + } + } + + return strings.Join(strs[start:end], delim), nil +} + +// IsNullable implements the Expression interface. +func (s *SubstringIndex) IsNullable() bool { + return s.str.IsNullable() || s.delim.IsNullable() || s.count.IsNullable() +} + +func (s *SubstringIndex) String() string { + return fmt.Sprintf("SUBSTRING_INDEX(%s, %s, %d)", s.str, s.delim, s.count) +} + +// Resolved implements the Expression interface. +func (s *SubstringIndex) Resolved() bool { + return s.str.Resolved() && s.delim.Resolved() && s.count.Resolved() +} + +// Type implements the Expression interface. +func (*SubstringIndex) Type() sql.Type { return sql.Text } + +// WithChildren implements the Expression interface. 
+func (s *SubstringIndex) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 3 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 3) } - return f(sub) + return NewSubstringIndex(children[0], children[1], children[2]), nil } diff --git a/sql/expression/function/substring_test.go b/sql/expression/function/substring_test.go index 826594155..e24cbb13f 100644 --- a/sql/expression/function/substring_test.go +++ b/sql/expression/function/substring_test.go @@ -3,9 +3,9 @@ package function import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestSubstring(t *testing.T) { @@ -51,3 +51,45 @@ func TestSubstring(t *testing.T) { }) } } + +func TestSubstringIndex(t *testing.T) { + f := NewSubstringIndex( + expression.NewGetField(0, sql.Text, "str", true), + expression.NewGetField(1, sql.Text, "delim", true), + expression.NewGetField(2, sql.Int64, "count", false), + ) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null string", sql.NewRow(nil, ".", 1), nil, false}, + {"null delim", sql.NewRow("foo", nil, 1), nil, false}, + {"null count", sql.NewRow("foo", 1, nil), nil, false}, + {"positive count", sql.NewRow("a.b.c.d.e.f", ".", 2), "a.b", false}, + {"negative count", sql.NewRow("a.b.c.d.e.f", ".", -2), "e.f", false}, + {"count 0", sql.NewRow("a.b.c", ".", 0), "", false}, + {"long delim", sql.NewRow("a.b.c.d.e.f", "..", 5), "a.b.c.d.e.f", false}, + {"count > len", sql.NewRow("a.b.c", ".", 10), "a.b.c", false}, + {"-count > -len", sql.NewRow("a.b.c", ".", -10), "a.b.c", false}, + {"remove suffix", sql.NewRow("source{d}", "{d}", 1), "source", false}, + {"remove suffix with negtive count", sql.NewRow("source{d}", "{d}", -1), "", false}, + {"wrong count type", sql.NewRow("", "", "foo"), 
"", true}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index f374e71c3..c0dcaf3d4 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -1,19 +1,18 @@ package function import ( + "errors" "fmt" "time" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) -func getDatePart( - ctx *sql.Context, +func getDate(ctx *sql.Context, u expression.UnaryExpression, - row sql.Row, - f func(time.Time) int, -) (interface{}, error) { + row sql.Row) (interface{}, error) { + val, err := u.Child.Eval(ctx, row) if err != nil { return nil, err @@ -27,11 +26,24 @@ func getDatePart( if err != nil { date, err = sql.Date.Convert(val) if err != nil { - return nil, err + date = nil } } - return int32(f(date.(time.Time))), nil + return date, nil +} + +func getDatePart(ctx *sql.Context, + u expression.UnaryExpression, + row sql.Row, + f func(interface{}) interface{}) (interface{}, error) { + + date, err := getDate(ctx, u, row) + if err != nil { + return nil, err + } + + return f(date), nil } // Year is a function that returns the year of a date. @@ -51,19 +63,15 @@ func (y *Year) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (y *Year) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Year") - defer span.Finish() - return getDatePart(ctx, y.UnaryExpression, row, (time.Time).Year) + return getDatePart(ctx, y.UnaryExpression, row, year) } -// TransformUp implements the Expression interface. 
-func (y *Year) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := y.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (y *Year) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(y, len(children), 1) } - - return f(NewYear(child)) + return NewYear(children[0]), nil } // Month is a function that returns the month of a date. @@ -83,24 +91,15 @@ func (m *Month) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (m *Month) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Month") - defer span.Finish() - - monthFunc := func(t time.Time) int { - return int(t.Month()) - } - - return getDatePart(ctx, m.UnaryExpression, row, monthFunc) + return getDatePart(ctx, m.UnaryExpression, row, month) } -// TransformUp implements the Expression interface. -func (m *Month) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Month) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMonth(child)) + return NewMonth(children[0]), nil } // Day is a function that returns the day of a date. @@ -120,19 +119,44 @@ func (d *Day) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (d *Day) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Day") - defer span.Finish() - return getDatePart(ctx, d.UnaryExpression, row, (time.Time).Day) + return getDatePart(ctx, d.UnaryExpression, row, day) } -// TransformUp implements the Expression interface. 
-func (d *Day) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (d *Day) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } + return NewDay(children[0]), nil +} + +// Weekday is a function that returns the weekday of a date where 0 = Monday, +// ..., 6 = Sunday. +type Weekday struct { + expression.UnaryExpression +} - return f(NewDay(child)) +// NewWeekday creates a new Weekday UDF. +func NewWeekday(date sql.Expression) sql.Expression { + return &Weekday{expression.UnaryExpression{Child: date}} +} + +func (d *Weekday) String() string { return fmt.Sprintf("WEEKDAY(%s)", d.Child) } + +// Type implements the Expression interface. +func (d *Weekday) Type() sql.Type { return sql.Int32 } + +// Eval implements the Expression interface. +func (d *Weekday) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return getDatePart(ctx, d.UnaryExpression, row, weekday) +} + +// WithChildren implements the Expression interface. +func (d *Weekday) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + return NewWeekday(children[0]), nil } // Hour is a function that returns the hour of a date. @@ -152,19 +176,15 @@ func (h *Hour) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (h *Hour) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Hour") - defer span.Finish() - return getDatePart(ctx, h.UnaryExpression, row, (time.Time).Hour) + return getDatePart(ctx, h.UnaryExpression, row, hour) } -// TransformUp implements the Expression interface. 
-func (h *Hour) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := h.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (h *Hour) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) } - - return f(NewHour(child)) + return NewHour(children[0]), nil } // Minute is a function that returns the minute of a date. @@ -184,19 +204,15 @@ func (m *Minute) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (m *Minute) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Minute") - defer span.Finish() - return getDatePart(ctx, m.UnaryExpression, row, (time.Time).Minute) + return getDatePart(ctx, m.UnaryExpression, row, minute) } -// TransformUp implements the Expression interface. -func (m *Minute) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := m.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (m *Minute) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(m, len(children), 1) } - - return f(NewMinute(child)) + return NewMinute(children[0]), nil } // Second is a function that returns the second of a date. @@ -216,19 +232,44 @@ func (s *Second) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. func (s *Second) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.Second") - defer span.Finish() - return getDatePart(ctx, s.UnaryExpression, row, (time.Time).Second) + return getDatePart(ctx, s.UnaryExpression, row, second) } -// TransformUp implements the Expression interface. 
-func (s *Second) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (s *Second) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) } + return NewSecond(children[0]), nil +} - return f(NewSecond(child)) +// DayOfWeek is a function that returns the day of the week from a date where +// 1 = Sunday, ..., 7 = Saturday. +type DayOfWeek struct { + expression.UnaryExpression +} + +// NewDayOfWeek creates a new DayOfWeek UDF. +func NewDayOfWeek(date sql.Expression) sql.Expression { + return &DayOfWeek{expression.UnaryExpression{Child: date}} +} + +func (d *DayOfWeek) String() string { return fmt.Sprintf("DAYOFWEEK(%s)", d.Child) } + +// Type implements the Expression interface. +func (d *DayOfWeek) Type() sql.Type { return sql.Int32 } + +// Eval implements the Expression interface. +func (d *DayOfWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return getDatePart(ctx, d.UnaryExpression, row, dayOfWeek) +} + +// WithChildren implements the Expression interface. +func (d *DayOfWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + return NewDayOfWeek(children[0]), nil } // DayOfYear is a function that returns the day of the year from a date. @@ -248,17 +289,292 @@ func (d *DayOfYear) Type() sql.Type { return sql.Int32 } // Eval implements the Expression interface. 
func (d *DayOfYear) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("function.DayOfYear") - defer span.Finish() - return getDatePart(ctx, d.UnaryExpression, row, (time.Time).YearDay) + return getDatePart(ctx, d.UnaryExpression, row, dayOfYear) +} + +// WithChildren implements the Expression interface. +func (d *DayOfYear) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + return NewDayOfYear(children[0]), nil +} + +func datePartFunc(fn func(time.Time) int) func(interface{}) interface{} { + return func(v interface{}) interface{} { + if v == nil { + return nil + } + + return int32(fn(v.(time.Time))) + } } -// TransformUp implements the Expression interface. -func (d *DayOfYear) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := d.Child.TransformUp(f) +// YearWeek is a function that returns year and week for a date. +// The year in the result may be different from the year in the date argument for the first and the last week of the year. +// Details: https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_yearweek +type YearWeek struct { + date sql.Expression + mode sql.Expression +} + +// NewYearWeek creates a new YearWeek UDF +func NewYearWeek(args ...sql.Expression) (sql.Expression, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("YEARWEEK", "1 or more", 0) + } + + yw := &YearWeek{date: args[0]} + if len(args) > 1 && args[1].Resolved() && sql.IsInteger(args[1].Type()) { + yw.mode = args[1] + } else { + yw.mode = expression.NewLiteral(0, sql.Int64) + } + return yw, nil +} + +func (d *YearWeek) String() string { return fmt.Sprintf("YEARWEEK(%s, %d)", d.date, d.mode) } + +// Type implements the Expression interface. +func (d *YearWeek) Type() sql.Type { return sql.Int32 } + +// Eval implements the Expression interface. 
+func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + date, err := getDate(ctx, expression.UnaryExpression{Child: d.date}, row) + if err != nil { + return nil, err + } + yyyy, ok := year(date).(int32) + if !ok { + return nil, errors.New("YEARWEEK: invalid year") + } + mm, ok := month(date).(int32) + if !ok { + return nil, errors.New("YEARWEEK: invalid month") + } + dd, ok := day(date).(int32) + if !ok { + return nil, errors.New("YEARWEEK: invalid day") + } + + mode := int64(0) + val, err := d.mode.Eval(ctx, row) if err != nil { return nil, err } + if val != nil { + if i64, err := sql.Int64.Convert(val); err == nil { + if mode, ok = i64.(int64); ok { + mode %= 8 // mode in [0, 7] + } + } + } + yyyy, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear) + + return (yyyy * 100) + week, nil +} + +// Resolved implements the Expression interface. +func (d *YearWeek) Resolved() bool { + return d.date.Resolved() && d.mode.Resolved() +} + +// Children implements the Expression interface. +func (d *YearWeek) Children() []sql.Expression { return []sql.Expression{d.date, d.mode} } + +// IsNullable implements the Expression interface. +func (d *YearWeek) IsNullable() bool { + return d.date.IsNullable() +} + +// WithChildren implements the Expression interface. +func (*YearWeek) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewYearWeek(children...) +} + +// Following solution of YearWeek was taken from tidb: https://github.com/pingcap/tidb/blob/master/types/mytime.go +type weekBehaviour int64 + +const ( + // weekBehaviourMondayFirst set Monday as first day of week; otherwise Sunday is first day of week + weekBehaviourMondayFirst weekBehaviour = 1 << iota + // If set, Week is in range 1-53, otherwise Week is in range 0-53. + // Note that this flag is only relevant if WEEK_JANUARY is not set. + weekBehaviourYear + // If not set, Weeks are numbered according to ISO 8601:1988. 
+ // If set, the week that contains the first 'first-day-of-week' is week 1. + weekBehaviourFirstWeekday +) + +func (v weekBehaviour) test(flag weekBehaviour) bool { + return (v & flag) != 0 +} + +func weekMode(mode int64) weekBehaviour { + weekFormat := weekBehaviour(mode & 7) + if (weekFormat & weekBehaviourMondayFirst) == 0 { + weekFormat ^= weekBehaviourFirstWeekday + } + return weekFormat +} + +// calcWeekday calculates weekday from daynr, returns 0 for Monday, 1 for Tuesday ... +func calcWeekday(daynr int32, sundayFirstDayOfWeek bool) int32 { + daynr += 5 + if sundayFirstDayOfWeek { + daynr++ + } + return daynr % 7 +} + +// calcWeek calculates week and year for the time. +func calcWeek(yyyy, mm, dd int32, wb weekBehaviour) (int32, int32) { + daynr := calcDaynr(yyyy, mm, dd) + firstDaynr := calcDaynr(yyyy, 1, 1) + mondayFirst := wb.test(weekBehaviourMondayFirst) + weekYear := wb.test(weekBehaviourYear) + firstWeekday := wb.test(weekBehaviourFirstWeekday) + weekday := calcWeekday(firstDaynr, !mondayFirst) + + week, days := int32(0), int32(0) + if mm == 1 && dd <= 7-weekday { + if !weekYear && + ((firstWeekday && weekday != 0) || (!firstWeekday && weekday >= 4)) { + return yyyy, week + } + weekYear = true + yyyy-- + days = calcDaysInYear(yyyy) + firstDaynr -= days + weekday = (weekday + 53*7 - days) % 7 + } + + if (firstWeekday && weekday != 0) || + (!firstWeekday && weekday >= 4) { + days = daynr - (firstDaynr + 7 - weekday) + } else { + days = daynr - (firstDaynr - weekday) + } + + if weekYear && days >= 52*7 { + weekday = (weekday + calcDaysInYear(yyyy)) % 7 + if (!firstWeekday && weekday < 4) || + (firstWeekday && weekday == 0) { + yyyy++ + week = 1 + return yyyy, week + } + } + week = days/7 + 1 + return yyyy, week +} + +// calcDaysInYear calculates days in one year, it works with 0 <= yyyy <= 99. 
+func calcDaysInYear(yyyy int32) int32 { + if (yyyy&3) == 0 && (yyyy%100 != 0 || (yyyy%400 == 0 && (yyyy != 0))) { + return 366 + } + return 365 +} + +// calcDaynr calculates days since 0000-00-00. +func calcDaynr(yyyy, mm, dd int32) int32 { + if yyyy == 0 && mm == 0 { + return 0 + } + + delsum := 365*yyyy + 31*(mm-1) + dd + if mm <= 2 { + yyyy-- + } else { + delsum -= (mm*4 + 23) / 10 + } + return delsum + yyyy/4 - ((yyyy/100+1)*3)/4 +} + +var ( + year = datePartFunc((time.Time).Year) + month = datePartFunc(func(t time.Time) int { return int(t.Month()) }) + day = datePartFunc((time.Time).Day) + weekday = datePartFunc(func(t time.Time) int { return (int(t.Weekday()) + 6) % 7 }) + hour = datePartFunc((time.Time).Hour) + minute = datePartFunc((time.Time).Minute) + second = datePartFunc((time.Time).Second) + dayOfWeek = datePartFunc(func(t time.Time) int { return int(t.Weekday()) + 1 }) + dayOfYear = datePartFunc((time.Time).YearDay) +) + +type clock func() time.Time + +var defaultClock = time.Now + +// Now is a function that returns the current time. +type Now struct { + clock +} + +// NewNow returns a new Now node. +func NewNow() sql.Expression { + return &Now{defaultClock} +} + +// Type implements the sql.Expression interface. +func (*Now) Type() sql.Type { return sql.Timestamp } + +func (*Now) String() string { return "NOW()" } + +// IsNullable implements the sql.Expression interface. +func (*Now) IsNullable() bool { return false } + +// Resolved implements the sql.Expression interface. +func (*Now) Resolved() bool { return true } + +// Children implements the sql.Expression interface. +func (*Now) Children() []sql.Expression { return nil } - return f(NewDayOfYear(child)) +// Eval implements the sql.Expression interface. +func (n *Now) Eval(*sql.Context, sql.Row) (interface{}, error) { + return n.clock(), nil +} + +// WithChildren implements the Expression interface. 
+func (n *Now) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + return n, nil +} + +// Date a function takes the DATE part out from a datetime expression. +type Date struct { + expression.UnaryExpression +} + +// NewDate returns a new Date node. +func NewDate(date sql.Expression) sql.Expression { + return &Date{expression.UnaryExpression{Child: date}} +} + +func (d *Date) String() string { return fmt.Sprintf("DATE(%s)", d.Child) } + +// Type implements the Expression interface. +func (d *Date) Type() sql.Type { return sql.Text } + +// Eval implements the Expression interface. +func (d *Date) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + return getDatePart(ctx, d.UnaryExpression, row, func(v interface{}) interface{} { + if v == nil { + return nil + } + + return v.(time.Time).Format("2006-01-02") + }) +} + +// WithChildren implements the Expression interface. 
+func (d *Date) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + return NewDate(children[0]), nil } diff --git a/sql/expression/function/time_test.go b/sql/expression/function/time_test.go index 9f84c72e1..b9762ec1f 100644 --- a/sql/expression/function/time_test.go +++ b/sql/expression/function/time_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) const ( @@ -25,7 +25,7 @@ func TestTime_Year(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(2007), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Year()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(2009), false}, @@ -56,7 +56,7 @@ func TestTime_Month(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(1), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Month()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(11), false}, @@ -87,7 +87,7 @@ func TestTime_Day(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(2), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Day()), false}, {"date as unix timestamp", 
sql.NewRow(int64(tsDate)), int32(22), false}, @@ -107,6 +107,37 @@ func TestTime_Day(t *testing.T) { } } +func TestTime_Weekday(t *testing.T) { + f := NewWeekday(expression.NewGetField(0, sql.Text, "foo", false)) + ctx := sql.NewEmptyContext() + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null date", sql.NewRow(nil), nil, false}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, + {"date as string", sql.NewRow(stringDate), int32(1), false}, + {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Weekday()+6) % 7, false}, + {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(6), false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + val, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, val) + } + }) + } +} + func TestTime_Hour(t *testing.T) { f := NewHour(expression.NewGetField(0, sql.Text, "foo", false)) ctx := sql.NewEmptyContext() @@ -118,7 +149,7 @@ func TestTime_Hour(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(14), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Hour()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(9), false}, @@ -149,7 +180,7 @@ func TestTime_Minute(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(15), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Minute()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(35), false}, @@ -179,7 +210,7 @@ func 
TestTime_Second(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(16), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Second()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(45), false}, @@ -199,6 +230,37 @@ func TestTime_Second(t *testing.T) { } } +func TestTime_DayOfWeek(t *testing.T) { + f := NewDayOfWeek(expression.NewGetField(0, sql.Text, "foo", false)) + ctx := sql.NewEmptyContext() + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null date", sql.NewRow(nil), nil, false}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, + {"date as string", sql.NewRow(stringDate), int32(3), false}, + {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().Weekday() + 1), false}, + {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(1), false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + val, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, val) + } + }) + } +} + func TestTime_DayOfYear(t *testing.T) { f := NewDayOfYear(expression.NewGetField(0, sql.Text, "foo", false)) ctx := sql.NewEmptyContext() @@ -210,7 +272,7 @@ func TestTime_DayOfYear(t *testing.T) { err bool }{ {"null date", sql.NewRow(nil), nil, false}, - {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), int32(2), false}, {"date as time", sql.NewRow(time.Now()), int32(time.Now().UTC().YearDay()), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(326), false}, @@ -229,3 +291,97 @@ func TestTime_DayOfYear(t *testing.T) { }) } } + +func 
TestYearWeek(t *testing.T) { + f, err := NewYearWeek(expression.NewGetField(0, sql.Text, "foo", false)) + require.NoError(t, err) + ctx := sql.NewEmptyContext() + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null date", sql.NewRow(nil), nil, true}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, true}, + {"date as string", sql.NewRow(stringDate), int32(200653), false}, + {"date as unix timestamp", sql.NewRow(int64(tsDate)), int32(200947), false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + val, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, val) + } + }) + } +} + +func TestCalcDaynr(t *testing.T) { + require.EqualValues(t, calcDaynr(0, 0, 0), 0) + require.EqualValues(t, calcDaynr(9999, 12, 31), 3652424) + require.EqualValues(t, calcDaynr(1970, 1, 1), 719528) + require.EqualValues(t, calcDaynr(2006, 12, 16), 733026) + require.EqualValues(t, calcDaynr(10, 1, 2), 3654) + require.EqualValues(t, calcDaynr(2008, 2, 20), 733457) +} + +func TestCalcWeek(t *testing.T) { + _, w := calcWeek(2008, 2, 20, weekMode(0)) + + _, w = calcWeek(2008, 2, 20, weekMode(1)) + require.EqualValues(t, w, 8) + + _, w = calcWeek(2008, 12, 31, weekMode(1)) + require.EqualValues(t, w, 53) +} + +func TestNow(t *testing.T) { + require := require.New(t) + date := time.Date(2018, time.December, 2, 16, 25, 0, 0, time.Local) + clk := clock(func() time.Time { + return date + }) + f := &Now{clk} + + result, err := f.Eval(nil, nil) + require.NoError(err) + require.Equal(date, result) +} + +func TestDate(t *testing.T) { + f := NewDate(expression.NewGetField(0, sql.Text, "foo", false)) + ctx := sql.NewEmptyContext() + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null date", sql.NewRow(nil), nil, false}, + {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, 
false}, + {"date as string", sql.NewRow(stringDate), "2007-01-02", false}, + {"date as time", sql.NewRow(time.Now()), time.Now().Format("2006-01-02"), false}, + {"date as unix timestamp", sql.NewRow(int64(tsDate)), "2009-11-22", false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + val, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, val) + } + }) + } +} diff --git a/sql/expression/function/tobase64_frombase64.go b/sql/expression/function/tobase64_frombase64.go new file mode 100644 index 000000000..f3c638983 --- /dev/null +++ b/sql/expression/function/tobase64_frombase64.go @@ -0,0 +1,145 @@ +package function + +import ( + "encoding/base64" + "fmt" + "reflect" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// ToBase64 is a function to encode a string to the Base64 format +// using the same dialect that MySQL's TO_BASE64 uses +type ToBase64 struct { + expression.UnaryExpression +} + +// NewToBase64 creates a new ToBase64 expression. +func NewToBase64(e sql.Expression) sql.Expression { + return &ToBase64{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. 
+func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := t.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + encoded := base64.StdEncoding.EncodeToString([]byte(str.(string))) + + lenEncoded := len(encoded) + if lenEncoded <= 76 { + return encoded, nil + } + + // Split into max 76 chars lines + var out strings.Builder + start := 0 + end := 76 + for { + out.WriteString(encoded[start:end] + "\n") + start += 76 + end += 76 + if end >= lenEncoded { + out.WriteString(encoded[start:lenEncoded]) + break + } + } + + return out.String(), nil +} + +// String implements the Stringer interface. +func (t *ToBase64) String() string { + return fmt.Sprintf("TO_BASE64(%s)", t.Child) +} + +// IsNullable implements the Expression interface. +func (t *ToBase64) IsNullable() bool { + return t.Child.IsNullable() +} + +// WithChildren implements the Expression interface. +func (t *ToBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) + } + return NewToBase64(children[0]), nil +} + +// Type implements the Expression interface. +func (t *ToBase64) Type() sql.Type { + return sql.Text +} + +// FromBase64 is a function to decode a Base64-formatted string +// using the same dialect that MySQL's FROM_BASE64 uses +type FromBase64 struct { + expression.UnaryExpression +} + +// NewFromBase64 creates a new FromBase64 expression. +func NewFromBase64(e sql.Expression) sql.Expression { + return &FromBase64{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. 
+func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := t.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + decoded, err := base64.StdEncoding.DecodeString(str.(string)) + if err != nil { + return nil, err + } + + return string(decoded), nil +} + +// String implements the Stringer interface. +func (t *FromBase64) String() string { + return fmt.Sprintf("FROM_BASE64(%s)", t.Child) +} + +// IsNullable implements the Expression interface. +func (t *FromBase64) IsNullable() bool { + return t.Child.IsNullable() +} + +// WithChildren implements the Expression interface. +func (t *FromBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) + } + return NewFromBase64(children[0]), nil +} + +// Type implements the Expression interface. +func (t *FromBase64) Type() sql.Type { + return sql.Text +} diff --git a/sql/expression/function/tobase64_frombase64_test.go b/sql/expression/function/tobase64_frombase64_test.go new file mode 100644 index 000000000..f234728d2 --- /dev/null +++ b/sql/expression/function/tobase64_frombase64_test.go @@ -0,0 +1,56 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestBase64(t *testing.T) { + fTo := NewToBase64(expression.NewGetField(0, sql.Text, "", false)) + fFrom := NewFromBase64(expression.NewGetField(0, sql.Text, "", false)) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + // Use a MySQL server to get expected values if updating/adding to this! 
+ {"null input", sql.NewRow(nil), nil, false}, + {"single_line", sql.NewRow("foo"), string("Zm9v"), false}, + {"multi_line", sql.NewRow( + "Gallia est omnis divisa in partes tres, quarum unam " + + "incolunt Belgae, aliam Aquitani, tertiam qui ipsorum lingua Celtae, " + + "nostra Galli appellantur"), + "R2FsbGlhIGVzdCBvbW5pcyBkaXZpc2EgaW4gcGFydGVzIHRyZXMsIHF1YXJ1bSB1bmFtIGluY29s\n" + + "dW50IEJlbGdhZSwgYWxpYW0gQXF1aXRhbmksIHRlcnRpYW0gcXVpIGlwc29ydW0gbGluZ3VhIENl\n" + + "bHRhZSwgbm9zdHJhIEdhbGxpIGFwcGVsbGFudHVy", false}, + {"empty_input", sql.NewRow(""), string(""), false}, + {"symbols", sql.NewRow("!@#$% %^&*()_+\r\n\t{};"), string("IUAjJCUgJV4mKigpXysNCgl7fTs="), + false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + v, err := fTo.Eval(ctx, tt.row) + + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + + ctx = sql.NewEmptyContext() + v2, err := fFrom.Eval(ctx, sql.NewRow(v)) + require.NoError(err) + require.Equal(sql.NewRow(v2), tt.row) + } + }) + } +} diff --git a/sql/expression/function/trim_ltrim_rtrim.go b/sql/expression/function/trim_ltrim_rtrim.go new file mode 100644 index 000000000..b08704dfb --- /dev/null +++ b/sql/expression/function/trim_ltrim_rtrim.go @@ -0,0 +1,93 @@ +package function + +import ( + "fmt" + "reflect" + "strings" + "unicode" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +type trimType rune + +const ( + lTrimType trimType = 'l' + rTrimType trimType = 'r' + bTrimType trimType = 'b' +) + +// NewTrimFunc returns a Trim creator function with a specific trimType. +func NewTrimFunc(tType trimType) func(e sql.Expression) sql.Expression { + return func(e sql.Expression) sql.Expression { + return NewTrim(tType, e) + } +} + +// NewTrim creates a new Trim expression. 
+func NewTrim(tType trimType, str sql.Expression) sql.Expression { + return &Trim{expression.UnaryExpression{Child: str}, tType} +} + +// Trim is a function that returns the string with prefix or suffix spaces removed based on the trimType +type Trim struct { + expression.UnaryExpression + trimType +} + +// Type implements the Expression interface. +func (t *Trim) Type() sql.Type { return sql.Text } + +func (t *Trim) String() string { + switch t.trimType { + case lTrimType: + return fmt.Sprintf("ltrim(%s)", t.Child) + case rTrimType: + return fmt.Sprintf("rtrim(%s)", t.Child) + default: + return fmt.Sprintf("trim(%s)", t.Child) + } +} + +// IsNullable implements the Expression interface. +func (t *Trim) IsNullable() bool { + return t.Child.IsNullable() +} + +// WithChildren implements the Expression interface. +func (t *Trim) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) + } + return NewTrim(t.trimType, children[0]), nil +} + +// Eval implements the Expression interface. 
+func (t *Trim) Eval( + ctx *sql.Context, + row sql.Row, +) (interface{}, error) { + str, err := t.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + switch t.trimType { + case lTrimType: + return strings.TrimLeftFunc(str.(string), unicode.IsSpace), nil + case rTrimType: + return strings.TrimRightFunc(str.(string), unicode.IsSpace), nil + default: + return strings.TrimFunc(str.(string), unicode.IsSpace), nil + } +} diff --git a/sql/expression/function/trim_ltrim_rtrim_test.go b/sql/expression/function/trim_ltrim_rtrim_test.go new file mode 100644 index 000000000..abfaa9f4e --- /dev/null +++ b/sql/expression/function/trim_ltrim_rtrim_test.go @@ -0,0 +1,108 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestTrim(t *testing.T) { + f := NewTrimFunc(bTrimType)(expression.NewGetField(0, sql.Text, "", false)) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"trimmed string", sql.NewRow("foo"), "foo", false}, + {"spaces in both sides", sql.NewRow(" foo "), "foo", false}, + {"spaces in left side", sql.NewRow(" foo"), "foo", false}, + {"spaces in right side", sql.NewRow("foo "), "foo", false}, + {"two words with spaces", sql.NewRow(" foo bar "), "foo bar", false}, + {"different kinds of spaces", sql.NewRow("\r\tfoo bar \v"), "foo bar", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestLTrim(t *testing.T) { + 
f := NewTrimFunc(lTrimType)(expression.NewGetField(0, sql.Text, "", false)) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"trimmed string", sql.NewRow("foo"), "foo", false}, + {"spaces in both sides", sql.NewRow(" foo "), "foo ", false}, + {"spaces in left side", sql.NewRow(" foo"), "foo", false}, + {"spaces in right side", sql.NewRow("foo "), "foo ", false}, + {"two words with spaces", sql.NewRow(" foo bar "), "foo bar ", false}, + {"different kinds of spaces", sql.NewRow("\r\tfoo bar \v"), "foo bar \v", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} + +func TestRTrim(t *testing.T) { + f := NewTrimFunc(rTrimType)(expression.NewGetField(0, sql.Text, "", false)) + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null input", sql.NewRow(nil), nil, false}, + {"trimmed string", sql.NewRow("foo"), "foo", false}, + {"spaces in both sides", sql.NewRow(" foo "), " foo", false}, + {"spaces in left side", sql.NewRow(" foo"), " foo", false}, + {"spaces in right side", sql.NewRow("foo "), "foo", false}, + {"two words with spaces", sql.NewRow(" foo bar "), " foo bar", false}, + {"different kinds of spaces", sql.NewRow("\r\tfoo bar \v"), "\r\tfoo bar", false}, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} diff --git a/sql/expression/function/version.go b/sql/expression/function/version.go new file mode 100644 index 000000000..eafeb4454 
--- /dev/null +++ b/sql/expression/function/version.go @@ -0,0 +1,56 @@ +package function + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +const mysqlVersion = "8.0.11" + +// Version is a function that returns server version. +type Version string + +// NewVersion creates a new Version UDF. +func NewVersion(versionPostfix string) func(...sql.Expression) (sql.Expression, error) { + return func(...sql.Expression) (sql.Expression, error) { + return Version(versionPostfix), nil + } +} + +// Type implements the Expression interface. +func (f Version) Type() sql.Type { return sql.Text } + +// IsNullable implements the Expression interface. +func (f Version) IsNullable() bool { + return false +} + +func (f Version) String() string { + return "VERSION()" +} + +// WithChildren implements the Expression interface. +func (f Version) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil +} + +// Resolved implements the Expression interface. +func (f Version) Resolved() bool { + return true +} + +// Children implements the Expression interface. +func (f Version) Children() []sql.Expression { return nil } + +// Eval implements the Expression interface. 
+func (f Version) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + if f == "" { + return mysqlVersion, nil + } + + return fmt.Sprintf("%s-%s", mysqlVersion, string(f)), nil +} diff --git a/sql/expression/function/version_test.go b/sql/expression/function/version_test.go new file mode 100644 index 000000000..0ec10cfdb --- /dev/null +++ b/sql/expression/function/version_test.go @@ -0,0 +1,27 @@ +package function + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +const versionPostfix = "test" + +func TestNewVersion(t *testing.T) { + require := require.New(t) + + f, _ := NewVersion(versionPostfix)() + ctx := sql.NewEmptyContext() + + val, err := f.Eval(ctx, nil) + require.NoError(err) + require.Equal("8.0.11-"+versionPostfix, val) + + f, _ = NewVersion("")() + + val, err = f.Eval(ctx, nil) + require.NoError(err) + require.Equal("8.0.11", val) +} diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 4860e8fcb..e6a5884dc 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -3,8 +3,8 @@ package expression import ( "fmt" + "github.com/src-d/go-mysql-server/sql" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // GetField is an expression to get the field of a table. @@ -32,29 +32,32 @@ func NewGetFieldWithTable(index int, fieldType sql.Type, table, fieldName string } } +// Index returns the index where the GetField will look for the value from a sql.Row. +func (p *GetField) Index() int { return p.fieldIndex } + // Children implements the Expression interface. -func (GetField) Children() []sql.Expression { +func (*GetField) Children() []sql.Expression { return nil } // Table returns the name of the field table. -func (p GetField) Table() string { return p.table } +func (p *GetField) Table() string { return p.table } // Resolved implements the Expression interface. 
-func (p GetField) Resolved() bool { +func (p *GetField) Resolved() bool { return true } // Name implements the Nameable interface. -func (p GetField) Name() string { return p.name } +func (p *GetField) Name() string { return p.name } // IsNullable returns whether the field is nullable or not. -func (p GetField) IsNullable() bool { +func (p *GetField) IsNullable() bool { return p.nullable } // Type returns the type of the field. -func (p GetField) Type() sql.Type { +func (p *GetField) Type() sql.Type { return p.fieldType } @@ -62,22 +65,71 @@ func (p GetField) Type() sql.Type { var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns") // Eval implements the Expression interface. -func (p GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if p.fieldIndex < 0 || p.fieldIndex >= len(row) { return nil, ErrIndexOutOfBounds.New(p.fieldIndex, len(row)) } return row[p.fieldIndex], nil } -// TransformUp implements the Expression interface. -func (p *GetField) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } -func (p GetField) String() string { +func (p *GetField) String() string { if p.table == "" { return p.name } return fmt.Sprintf("%s.%s", p.table, p.name) } + +// WithIndex returns this same GetField with a new index. +func (p *GetField) WithIndex(n int) sql.Expression { + p2 := *p + p2.fieldIndex = n + return &p2 +} + +// GetSessionField is an expression that returns the value of a session configuration. 
+type GetSessionField struct { + name string + typ sql.Type + value interface{} +} + +// NewGetSessionField creates a new GetSessionField expression. +func NewGetSessionField(name string, typ sql.Type, value interface{}) *GetSessionField { + return &GetSessionField{name, typ, value} +} + +// Children implements the sql.Expression interface. +func (f *GetSessionField) Children() []sql.Expression { return nil } + +// Eval implements the sql.Expression interface. +func (f *GetSessionField) Eval(*sql.Context, sql.Row) (interface{}, error) { + return f.value, nil +} + +// Type implements the sql.Expression interface. +func (f *GetSessionField) Type() sql.Type { return f.typ } + +// IsNullable implements the sql.Expression interface. +func (f *GetSessionField) IsNullable() bool { return f.value == nil } + +// Resolved implements the sql.Expression interface. +func (f *GetSessionField) Resolved() bool { return true } + +// String implements the sql.Expression interface. +func (f *GetSessionField) String() string { return "@@" + f.name } + +// WithChildren implements the Expression interface. +func (f *GetSessionField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 0) + } + return f, nil +} diff --git a/sql/expression/interval.go b/sql/expression/interval.go new file mode 100644 index 000000000..a175d55b5 --- /dev/null +++ b/sql/expression/interval.go @@ -0,0 +1,284 @@ +package expression + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +// Interval defines a time duration. +type Interval struct { + UnaryExpression + Unit string +} + +// NewInterval creates a new interval expression. 
+func NewInterval(child sql.Expression, unit string) *Interval { + return &Interval{UnaryExpression{Child: child}, strings.ToUpper(unit)} +} + +// Type implements the sql.Expression interface. +func (i *Interval) Type() sql.Type { return sql.Uint64 } + +// IsNullable implements the sql.Expression interface. +func (i *Interval) IsNullable() bool { return i.Child.IsNullable() } + +// Eval implements the sql.Expression interface. +func (i *Interval) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + panic("Interval.Eval is just a placeholder method and should not be called directly") +} + +var ( + errInvalidIntervalUnit = errors.NewKind("invalid interval unit: %s") + errInvalidIntervalFormat = errors.NewKind("invalid interval format for %q: %s") +) + +// EvalDelta evaluates the expression returning a TimeDelta. This method should +// be used instead of Eval, as this expression returns a TimeDelta, which is not +// a valid value that can be returned in Eval. +func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) { + val, err := i.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + return nil, nil + } + + var td TimeDelta + + if r, ok := unitTextFormats[i.Unit]; ok { + val, err = sql.Text.Convert(val) + if err != nil { + return nil, err + } + + text := val.(string) + if !r.MatchString(text) { + return nil, errInvalidIntervalFormat.New(i.Unit, text) + } + + parts := textFormatParts(text, r) + + switch i.Unit { + case "DAY_HOUR": + td.Days = parts[0] + td.Hours = parts[1] + case "DAY_MICROSECOND": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + td.Seconds = parts[3] + td.Microseconds = parts[4] + case "DAY_MINUTE": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + case "DAY_SECOND": + td.Days = parts[0] + td.Hours = parts[1] + td.Minutes = parts[2] + td.Seconds = parts[3] + case "HOUR_MICROSECOND": + td.Hours = parts[0] + td.Minutes = parts[1] + td.Seconds = parts[2] 
+ td.Microseconds = parts[3] + case "HOUR_SECOND": + td.Hours = parts[0] + td.Minutes = parts[1] + td.Seconds = parts[2] + case "HOUR_MINUTE": + td.Hours = parts[0] + td.Minutes = parts[1] + case "MINUTE_MICROSECOND": + td.Minutes = parts[0] + td.Seconds = parts[1] + td.Microseconds = parts[2] + case "MINUTE_SECOND": + td.Minutes = parts[0] + td.Seconds = parts[1] + case "SECOND_MICROSECOND": + td.Seconds = parts[0] + td.Microseconds = parts[1] + case "YEAR_MONTH": + td.Years = parts[0] + td.Months = parts[1] + default: + return nil, errInvalidIntervalUnit.New(i.Unit) + } + } else { + val, err = sql.Int64.Convert(val) + if err != nil { + return nil, err + } + + num := val.(int64) + + switch i.Unit { + case "DAY": + td.Days = num + case "HOUR": + td.Hours = num + case "MINUTE": + td.Minutes = num + case "SECOND": + td.Seconds = num + case "MICROSECOND": + td.Microseconds = num + case "QUARTER": + td.Months = num * 3 + case "MONTH": + td.Months = num + case "WEEK": + td.Days = num * 7 + case "YEAR": + td.Years = num + default: + return nil, errInvalidIntervalUnit.New(i.Unit) + } + } + + return &td, nil +} + +// WithChildren implements the Expression interface. 
+func (i *Interval) WithChildren(children ...sql.Expression) (sql.Expression, error) {
+	if len(children) != 1 {
+		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
+	}
+	return NewInterval(children[0], i.Unit), nil
+}
+
+func (i *Interval) String() string {
+	return fmt.Sprintf("INTERVAL %s %s", i.Child, i.Unit)
+}
+
+var unitTextFormats = map[string]*regexp.Regexp{
+	"DAY_HOUR":           regexp.MustCompile(`^(\d+)\s+(\d+)$`),
+	"DAY_MICROSECOND":    regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+).(\d+)$`),
+	"DAY_MINUTE":         regexp.MustCompile(`^(\d+)\s+(\d+):(\d+)$`),
+	"DAY_SECOND":         regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+)$`),
+	"HOUR_MICROSECOND":   regexp.MustCompile(`^(\d+):(\d+):(\d+).(\d+)$`),
+	"HOUR_SECOND":        regexp.MustCompile(`^(\d+):(\d+):(\d+)$`),
+	"HOUR_MINUTE":        regexp.MustCompile(`^(\d+):(\d+)$`),
+	"MINUTE_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+).(\d+)$`),
+	"MINUTE_SECOND":      regexp.MustCompile(`^(\d+):(\d+)$`),
+	"SECOND_MICROSECOND": regexp.MustCompile(`^(\d+).(\d+)$`),
+	"YEAR_MONTH":         regexp.MustCompile(`^(\d+)-(\d+)$`),
+}
+
+func textFormatParts(text string, r *regexp.Regexp) []int64 {
+	parts := r.FindStringSubmatch(text)
+	var result []int64
+	for _, p := range parts[1:] {
+		// It is safe to ignore the error here, because at this point we know
+		// the string matches the regexp, and that means it can't be an
+		// invalid number.
+		n, _ := strconv.ParseInt(p, 10, 64)
+		result = append(result, n)
+	}
+	return result
+}
+
+// TimeDelta is the difference between a time and another time.
+type TimeDelta struct {
+	Years        int64
+	Months       int64
+	Days         int64
+	Hours        int64
+	Minutes      int64
+	Seconds      int64
+	Microseconds int64
+}
+
+// Add returns the given time plus the time delta.
+func (td TimeDelta) Add(t time.Time) time.Time {
+	return td.apply(t, 1)
+}
+
+// Sub returns the given time minus the time delta.
+func (td TimeDelta) Sub(t time.Time) time.Time { + return td.apply(t, -1) +} + +const ( + day = 24 * time.Hour + week = 7 * day +) + +func (td TimeDelta) apply(t time.Time, sign int64) time.Time { + y := int64(t.Year()) + mo := int64(t.Month()) + d := t.Day() + h := t.Hour() + min := t.Minute() + s := t.Second() + ns := t.Nanosecond() + + if td.Years != 0 { + y += td.Years * sign + } + + if td.Months != 0 { + m := mo + td.Months*sign + if m < 1 { + mo = 12 + (m % 12) + y += m/12 - 1 + } else if m > 12 { + mo = m % 12 + y += m / 12 + } else { + mo = m + } + + // Due to the operations done before, month may be zero, which means it's + // december. + if mo == 0 { + mo = 12 + } + } + + if days := daysInMonth(time.Month(mo), int(y)); days < d { + d = days + } + + date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location()) + + if td.Days != 0 { + date = date.Add(time.Duration(td.Days) * day * time.Duration(sign)) + } + + if td.Hours != 0 { + date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign)) + } + + if td.Minutes != 0 { + date = date.Add(time.Duration(td.Minutes) * time.Minute * time.Duration(sign)) + } + + if td.Seconds != 0 { + date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign)) + } + + if td.Microseconds != 0 { + date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign)) + } + + return date +} + +func daysInMonth(month time.Month, year int) int { + if month == time.December { + return 31 + } + + date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local) + return date.Add(-1 * day).Day() +} diff --git a/sql/expression/interval_test.go b/sql/expression/interval_test.go new file mode 100644 index 000000000..24be71a5a --- /dev/null +++ b/sql/expression/interval_test.go @@ -0,0 +1,285 @@ +package expression + +import ( + "testing" + "time" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestTimeDelta(t *testing.T) { + 
leapYear := date(2004, time.February, 29, 0, 0, 0, 0) + testCases := []struct { + name string + delta TimeDelta + date time.Time + output time.Time + }{ + { + "leap year minus one year", + TimeDelta{Years: -1}, + leapYear, + date(2003, time.February, 28, 0, 0, 0, 0), + }, + { + "leap year plus one year", + TimeDelta{Years: 1}, + leapYear, + date(2005, time.February, 28, 0, 0, 0, 0), + }, + { + "plus overflowing months", + TimeDelta{Months: 13}, + leapYear, + date(2005, time.March, 29, 0, 0, 0, 0), + }, + { + "plus overflowing until december", + TimeDelta{Months: 22}, + leapYear, + date(2006, time.December, 29, 0, 0, 0, 0), + }, + { + "minus overflowing months", + TimeDelta{Months: -13}, + leapYear, + date(2003, time.January, 29, 0, 0, 0, 0), + }, + { + "minus overflowing until december", + TimeDelta{Months: -14}, + leapYear, + date(2002, time.December, 29, 0, 0, 0, 0), + }, + { + "minus months", + TimeDelta{Months: -1}, + leapYear, + date(2004, time.January, 29, 0, 0, 0, 0), + }, + { + "plus months", + TimeDelta{Months: 1}, + leapYear, + date(2004, time.March, 29, 0, 0, 0, 0), + }, + { + "minus days", + TimeDelta{Days: -2}, + leapYear, + date(2004, time.February, 27, 0, 0, 0, 0), + }, + { + "plus days", + TimeDelta{Days: 1}, + leapYear, + date(2004, time.March, 1, 0, 0, 0, 0), + }, + { + "minus hours", + TimeDelta{Hours: -2}, + leapYear, + date(2004, time.February, 28, 22, 0, 0, 0), + }, + { + "plus hours", + TimeDelta{Hours: 26}, + leapYear, + date(2004, time.March, 1, 2, 0, 0, 0), + }, + { + "minus minutes", + TimeDelta{Minutes: -2}, + leapYear, + date(2004, time.February, 28, 23, 58, 0, 0), + }, + { + "plus minutes", + TimeDelta{Minutes: 26}, + leapYear, + date(2004, time.February, 29, 0, 26, 0, 0), + }, + { + "minus seconds", + TimeDelta{Seconds: -2}, + leapYear, + date(2004, time.February, 28, 23, 59, 58, 0), + }, + { + "plus seconds", + TimeDelta{Seconds: 26}, + leapYear, + date(2004, time.February, 29, 0, 0, 26, 0), + }, + { + "minus microseconds", + 
TimeDelta{Microseconds: -2}, + leapYear, + date(2004, time.February, 28, 23, 59, 59, 999998), + }, + { + "plus microseconds", + TimeDelta{Microseconds: 26}, + leapYear, + date(2004, time.February, 29, 0, 0, 0, 26), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + result := tt.delta.Add(tt.date) + require.Equal(t, tt.output, result) + }) + } +} + +func TestIntervalEvalDelta(t *testing.T) { + testCases := []struct { + expr sql.Expression + unit string + row sql.Row + expected TimeDelta + }{ + { + NewGetField(0, sql.Int64, "foo", false), + "DAY", + sql.Row{int64(2)}, + TimeDelta{Days: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "DAY", + nil, + TimeDelta{Days: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MONTH", + nil, + TimeDelta{Months: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "YEAR", + nil, + TimeDelta{Years: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "QUARTER", + nil, + TimeDelta{Months: 6}, + }, + { + NewLiteral(int64(2), sql.Int64), + "WEEK", + nil, + TimeDelta{Days: 14}, + }, + { + NewLiteral(int64(2), sql.Int64), + "HOUR", + nil, + TimeDelta{Hours: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MINUTE", + nil, + TimeDelta{Minutes: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "SECOND", + nil, + TimeDelta{Seconds: 2}, + }, + { + NewLiteral(int64(2), sql.Int64), + "MICROSECOND", + nil, + TimeDelta{Microseconds: 2}, + }, + { + NewLiteral("2 3", sql.Text), + "DAY_HOUR", + nil, + TimeDelta{Days: 2, Hours: 3}, + }, + { + NewLiteral("2 3:04:05.06", sql.Text), + "DAY_MICROSECOND", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4, Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("2 3:04:05", sql.Text), + "DAY_SECOND", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("2 3:04", sql.Text), + "DAY_MINUTE", + nil, + TimeDelta{Days: 2, Hours: 3, Minutes: 4}, + }, + { + NewLiteral("3:04:05.06", sql.Text), + "HOUR_MICROSECOND", + nil, + TimeDelta{Hours: 3, Minutes: 4, 
Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("3:04:05", sql.Text), + "HOUR_SECOND", + nil, + TimeDelta{Hours: 3, Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("3:04", sql.Text), + "HOUR_MINUTE", + nil, + TimeDelta{Hours: 3, Minutes: 4}, + }, + { + NewLiteral("04:05.06", sql.Text), + "MINUTE_MICROSECOND", + nil, + TimeDelta{Minutes: 4, Seconds: 5, Microseconds: 6}, + }, + { + NewLiteral("04:05", sql.Text), + "MINUTE_SECOND", + nil, + TimeDelta{Minutes: 4, Seconds: 5}, + }, + { + NewLiteral("04.05", sql.Text), + "SECOND_MICROSECOND", + nil, + TimeDelta{Seconds: 4, Microseconds: 5}, + }, + { + NewLiteral("1-5", sql.Text), + "YEAR_MONTH", + nil, + TimeDelta{Years: 1, Months: 5}, + }, + } + + for _, tt := range testCases { + interval := NewInterval(tt.expr, tt.unit) + t.Run(interval.String(), func(t *testing.T) { + require := require.New(t) + result, err := interval.EvalDelta(sql.NewEmptyContext(), tt.row) + require.NoError(err) + require.Equal(tt.expected, *result) + }) + } +} + +func date(year int, month time.Month, day, hour, min, sec, micro int) time.Time { + return time.Date(year, month, day, hour, min, sec, micro*int(time.Microsecond), time.Local) +} diff --git a/sql/expression/isnull.go b/sql/expression/isnull.go index 6b6cdecb7..a9ae575d5 100644 --- a/sql/expression/isnull.go +++ b/sql/expression/isnull.go @@ -1,6 +1,6 @@ package expression -import "gopkg.in/src-d/go-mysql-server.v0/sql" +import "github.com/src-d/go-mysql-server/sql" // IsNull is an expression that checks if an expression is null. type IsNull struct { @@ -24,9 +24,6 @@ func (e *IsNull) IsNullable() bool { // Eval implements the Expression interface. 
func (e *IsNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.IsNull") - defer span.Finish() - v, err := e.Child.Eval(ctx, row) if err != nil { return nil, err @@ -39,11 +36,10 @@ func (e IsNull) String() string { return e.Child.String() + " IS NULL" } -// TransformUp implements the Expression interface. -func (e *IsNull) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - child, err := e.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (e *IsNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) } - return f(NewIsNull(child)) + return NewIsNull(children[0]), nil } diff --git a/sql/expression/isnull_test.go b/sql/expression/isnull_test.go index f158455cf..5e638ba6c 100644 --- a/sql/expression/isnull_test.go +++ b/sql/expression/isnull_test.go @@ -3,7 +3,7 @@ package expression import ( "testing" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" ) diff --git a/sql/expression/istrue.go b/sql/expression/istrue.go new file mode 100644 index 000000000..ba3e0c75f --- /dev/null +++ b/sql/expression/istrue.go @@ -0,0 +1,78 @@ +package expression + +import ( + "errors" + "github.com/src-d/go-mysql-server/sql" +) + +// IsTrue is an expression that checks if an expression is true. +type IsTrue struct { + UnaryExpression + invert bool +} + +const IsTrueStr = "IS TRUE" +const IsFalseStr = "IS FALSE" + +// NewIsTrue creates a new IsTrue expression. +func NewIsTrue(child sql.Expression) *IsTrue { + return &IsTrue{UnaryExpression: UnaryExpression{child}} +} + +// NewIsFalse creates a new IsTrue expression with its boolean sense inverted (IsFalse, effectively). 
+func NewIsFalse(child sql.Expression) *IsTrue { + return &IsTrue{UnaryExpression: UnaryExpression{child}, invert: true} +} + +// Type implements the Expression interface. +func (*IsTrue) Type() sql.Type { + return sql.Boolean +} + +// IsNullable implements the Expression interface. +func (*IsTrue) IsNullable() bool { + return false +} + +// Eval implements the Expression interface. +func (e *IsTrue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + v, err := e.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + var boolVal interface{} + if v == nil { + return false, nil + } else { + boolVal, err = sql.Boolean.Convert(v) + if err != nil { + return nil, err + } + } + + if e.invert { + return !boolVal.(bool), nil + } + return boolVal, nil +} + +func (e *IsTrue) String() string { + isStr := IsTrueStr + if e.invert { + isStr = IsFalseStr + } + return e.Child.String() + " " + isStr +} + +// WithChildren implements the Expression interface. +func (e *IsTrue) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, errors.New("incorrect number of children") + } + + if e.invert { + return NewIsFalse(children[0]), nil + } + return NewIsTrue(children[0]), nil +} diff --git a/sql/expression/istrue_test.go b/sql/expression/istrue_test.go new file mode 100644 index 000000000..0cc4e41d1 --- /dev/null +++ b/sql/expression/istrue_test.go @@ -0,0 +1,85 @@ +package expression + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" + "testing" +) + +func TestIsTrue(t *testing.T) { + require := require.New(t) + + boolF := NewGetField(0, sql.Boolean, "col1", true) + e := NewIsTrue(boolF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(true, eval(t, e, sql.NewRow(true))) + require.Equal(false, eval(t, e, sql.NewRow(false))) + + intF := NewGetField(0, sql.Int64, "col1", true) + e = 
NewIsTrue(intF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(true, eval(t, e, sql.NewRow(100))) + require.Equal(true, eval(t, e, sql.NewRow(-1))) + require.Equal(false, eval(t, e, sql.NewRow(0))) + + floatF := NewGetField(0, sql.Float64, "col1", true) + e = NewIsTrue(floatF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(true, eval(t, e, sql.NewRow(1.5))) + require.Equal(true, eval(t, e, sql.NewRow(-1.5))) + require.Equal(false, eval(t, e, sql.NewRow(0))) + + stringF := NewGetField(0, sql.Text, "col1", true) + e = NewIsTrue(stringF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(false, eval(t, e, sql.NewRow(""))) + require.Equal(false, eval(t, e, sql.NewRow("false"))) + require.Equal(false, eval(t, e, sql.NewRow("true"))) +} + +func TestIsFalse(t *testing.T) { + require := require.New(t) + + boolF := NewGetField(0, sql.Boolean, "col1", true) + e := NewIsFalse(boolF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(false, eval(t, e, sql.NewRow(true))) + require.Equal(true, eval(t, e, sql.NewRow(false))) + + intF := NewGetField(0, sql.Int64, "col1", true) + e = NewIsFalse(intF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(false, eval(t, e, sql.NewRow(100))) + require.Equal(false, eval(t, e, sql.NewRow(-1))) + require.Equal(true, eval(t, e, sql.NewRow(0))) + + floatF := NewGetField(0, sql.Float64, "col1", true) + e = NewIsFalse(floatF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(false, eval(t, e, sql.NewRow(1.5))) + 
require.Equal(false, eval(t, e, sql.NewRow(-1.5))) + require.Equal(true, eval(t, e, sql.NewRow(0))) + + stringF := NewGetField(0, sql.Text, "col1", true) + e = NewIsFalse(stringF) + require.Equal(sql.Boolean, e.Type()) + require.False(e.IsNullable()) + require.Equal(false, eval(t, e, sql.NewRow(nil))) + require.Equal(true, eval(t, e, sql.NewRow(""))) + require.Equal(true, eval(t, e, sql.NewRow("false"))) + require.Equal(true, eval(t, e, sql.NewRow("true"))) +} diff --git a/sql/expression/like.go b/sql/expression/like.go new file mode 100644 index 000000000..7913eb22c --- /dev/null +++ b/sql/expression/like.go @@ -0,0 +1,158 @@ +package expression + +import ( + "bytes" + "fmt" + "regexp" + "strings" + "sync" + + "github.com/src-d/go-mysql-server/internal/regex" + "github.com/src-d/go-mysql-server/sql" +) + +// Like performs pattern matching against two strings. +type Like struct { + BinaryExpression + pool *sync.Pool + cached bool +} + +// NewLike creates a new LIKE expression. +func NewLike(left, right sql.Expression) sql.Expression { + var cached = true + Inspect(right, func(e sql.Expression) bool { + if _, ok := e.(*GetField); ok { + cached = false + } + return true + }) + + return &Like{ + BinaryExpression: BinaryExpression{left, right}, + pool: nil, + cached: cached, + } +} + +// Type implements the sql.Expression interface. +func (l *Like) Type() sql.Type { return sql.Boolean } + +// Eval implements the sql.Expression interface. 
+func (l *Like) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + span, ctx := ctx.Span("expression.Like") + defer span.Finish() + + left, err := l.Left.Eval(ctx, row) + if err != nil || left == nil { + return nil, err + } + left, err = sql.Text.Convert(left) + if err != nil { + return nil, err + } + + var ( + matcher regex.Matcher + disposer regex.Disposer + right string + ) + // eval right and convert to text + if !l.cached || l.pool == nil { + var v interface{} + v, err = l.Right.Eval(ctx, row) + if err != nil || v == nil { + return nil, err + } + v, err = sql.Text.Convert(v) + if err != nil { + return nil, err + } + right = patternToGoRegex(v.(string)) + } + // for non-cached regex every time create a new matcher + if !l.cached { + matcher, disposer, err = regex.New("go", right) + } else { + if l.pool == nil { + l.pool = &sync.Pool{ + New: func() interface{} { + r, _, e := regex.New(regex.Default(), right) + if e != nil { + err = e + return nil + } + return r + }, + } + } + matcher = l.pool.Get().(regex.Matcher) + } + if matcher == nil { + return nil, err + } + + ok := matcher.Match(left.(string)) + if !l.cached { + disposer.Dispose() + } else if l.pool != nil { + l.pool.Put(matcher) + + } + + return ok, nil +} + +func (l *Like) String() string { + return fmt.Sprintf("%s LIKE %s", l.Left, l.Right) +} + +// WithChildren implements the Expression interface. 
+func (l *Like) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 2) + } + return NewLike(children[0], children[1]), nil +} + +func patternToGoRegex(pattern string) string { + var buf bytes.Buffer + buf.WriteString("(?s)") + buf.WriteRune('^') + var escaped bool + for _, r := range strings.Replace(regexp.QuoteMeta(pattern), `\\`, `\`, -1) { + switch r { + case '_': + if escaped { + buf.WriteRune(r) + } else { + buf.WriteRune('.') + } + case '%': + if !escaped { + buf.WriteString(".*") + } else { + buf.WriteRune(r) + } + case '\\': + if escaped { + buf.WriteString(`\\`) + } else { + escaped = true + continue + } + default: + if escaped { + buf.WriteString(`\`) + } + buf.WriteRune(r) + } + + if escaped { + escaped = false + } + } + + buf.WriteRune('$') + return buf.String() +} diff --git a/sql/expression/like_test.go b/sql/expression/like_test.go new file mode 100644 index 000000000..78b768bec --- /dev/null +++ b/sql/expression/like_test.go @@ -0,0 +1,65 @@ +package expression + +import ( + "fmt" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestPatternToRegex(t *testing.T) { + testCases := []struct { + in, out string + }{ + {`__`, `(?s)^..$`}, + {`_%_`, `(?s)^..*.$`}, + {`%_`, `(?s)^.*.$`}, + {`_%`, `(?s)^..*$`}, + {`a_b`, `(?s)^a.b$`}, + {`a%b`, `(?s)^a.*b$`}, + {`a.%b`, `(?s)^a\..*b$`}, + {`a\%b`, `(?s)^a%b$`}, + {`a\_b`, `(?s)^a_b$`}, + {`a\\b`, `(?s)^a\\b$`}, + {`a\\\_b`, `(?s)^a\\_b$`}, + {`(ab)`, `(?s)^\(ab\)$`}, + } + + for _, tt := range testCases { + t.Run(tt.in, func(t *testing.T) { + require.Equal(t, tt.out, patternToGoRegex(tt.in)) + }) + } +} + +func TestLike(t *testing.T) { + f := NewLike( + NewGetField(0, sql.Text, "", false), + NewGetField(1, sql.Text, "", false), + ) + + testCases := []struct { + pattern, value string + ok bool + }{ + {"a__", "abc", true}, + {"a__", "abcd", false}, + 
{"a%b", "acb", true}, + {"a%b", "acdkeflskjfdklb", true}, + {"a%b", "ab", true}, + {"a%b", "a", false}, + {"a_b", "ab", false}, + } + + for _, tt := range testCases { + t.Run(fmt.Sprintf("%q LIKE %q", tt.value, tt.pattern), func(t *testing.T) { + value, err := f.Eval(sql.NewEmptyContext(), sql.NewRow( + tt.value, + tt.pattern, + )) + require.NoError(t, err) + require.Equal(t, tt.ok, value) + }) + } +} diff --git a/sql/expression/literal.go b/sql/expression/literal.go index e26cc844c..02e88e0a9 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -3,7 +3,7 @@ package expression import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Literal represents a literal expression (string, number, bool, ...). @@ -14,6 +14,8 @@ type Literal struct { // NewLiteral creates a new Literal expression. func NewLiteral(value interface{}, fieldType sql.Type) *Literal { + // TODO(juanjux): we should probably check here if the type is sql.VarChar and the + // Capacity of the Type and the length of the value, but this can't return an error return &Literal{ value: value, fieldType: fieldType, @@ -51,13 +53,20 @@ func (p *Literal) String() string { } } -// TransformUp implements the Expression interface. -func (p *Literal) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *p - return f(&n) +// WithChildren implements the Expression interface. +func (p *Literal) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } // Children implements the Expression interface. func (*Literal) Children() []sql.Expression { return nil } + +// Value returns the literal value. 
+func (p *Literal) Value() interface{} { + return p.value +} diff --git a/sql/expression/logic.go b/sql/expression/logic.go index d5e032f0c..08a31c087 100644 --- a/sql/expression/logic.go +++ b/sql/expression/logic.go @@ -3,7 +3,7 @@ package expression import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // And checks whether two expressions are true. @@ -43,9 +43,6 @@ func (*And) Type() sql.Type { // Eval implements the Expression interface. func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.And") - defer span.Finish() - lval, err := a.Left.Eval(ctx, row) if err != nil { return nil, err @@ -71,19 +68,12 @@ func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return true, nil } -// TransformUp implements the Expression interface. -func (a *And) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := a.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := a.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (a *And) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) } - - return f(NewAnd(left, right)) + return NewAnd(children[0], children[1]), nil } // Or checks whether one of the two given expressions is true. @@ -107,9 +97,6 @@ func (*Or) Type() sql.Type { // Eval implements the Expression interface. func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Or") - defer span.Finish() - lval, err := o.Left.Eval(ctx, row) if err != nil { return nil, err @@ -131,17 +118,10 @@ func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return rval == true, nil } -// TransformUp implements the Expression interface. 
-func (o *Or) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - left, err := o.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := o.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Expression interface. +func (o *Or) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 2) } - - return f(NewOr(left, right)) + return NewOr(children[0], children[1]), nil } diff --git a/sql/expression/logic_test.go b/sql/expression/logic_test.go index 399da8fff..b71c71b7e 100644 --- a/sql/expression/logic_test.go +++ b/sql/expression/logic_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestAnd(t *testing.T) { diff --git a/sql/expression/set.go b/sql/expression/set.go new file mode 100644 index 000000000..d18bde374 --- /dev/null +++ b/sql/expression/set.go @@ -0,0 +1,61 @@ +package expression + +import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T") + +// SetField updates the value of a field from a row. +type SetField struct { + BinaryExpression +} + +// NewSetField creates a new SetField expression. +func NewSetField(colName, expr sql.Expression) sql.Expression { + return &SetField{BinaryExpression{Left: colName, Right: expr}} +} + +func (s *SetField) String() string { + return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right) +} + +// Type implements the Expression interface. +func (s *SetField) Type() sql.Type { + return s.Left.Type() +} + +// Eval implements the Expression interface. +// Returns a copy of the given row with an updated value. 
+func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + getField, ok := s.Left.(*GetField) + if !ok { + return nil, errCannotSetField.New(s.Left) + } + if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) { + return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row)) + } + val, err := s.Right.Eval(ctx, row) + if err != nil { + return nil, err + } + if val != nil { + val, err = getField.fieldType.Convert(val) + if err != nil { + return nil, err + } + } + updatedRow := row.Copy() + updatedRow[getField.fieldIndex] = val + return updatedRow, nil +} + +// WithChildren implements the Expression interface. +func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2) + } + return NewSetField(children[0], children[1]), nil +} \ No newline at end of file diff --git a/sql/expression/star.go b/sql/expression/star.go index 0fbd6d77b..5d6603be1 100644 --- a/sql/expression/star.go +++ b/sql/expression/star.go @@ -3,7 +3,7 @@ package expression import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Star represents the selection of all available fields. @@ -55,8 +55,10 @@ func (*Star) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { panic("star is just a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (s *Star) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *s - return f(&n) +// WithChildren implements the Expression interface. 
+func (s *Star) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil } diff --git a/sql/expression/subquery.go b/sql/expression/subquery.go new file mode 100644 index 000000000..faae15aa1 --- /dev/null +++ b/sql/expression/subquery.go @@ -0,0 +1,125 @@ +package expression + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errExpectedSingleRow = errors.NewKind("the subquery returned more than 1 row") + +// Subquery that is executed as an expression. +type Subquery struct { + Query sql.Node + value interface{} +} + +// NewSubquery returns a new subquery node. +func NewSubquery(node sql.Node) *Subquery { + return &Subquery{node, nil} +} + +// Eval implements the Expression interface. +func (s *Subquery) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { + if s.value != nil { + if elems, ok := s.value.([]interface{}); ok { + if len(elems) > 1 { + return nil, errExpectedSingleRow.New() + } + return elems[0], nil + } + return s.value, nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = nil + return nil, nil + } + + if len(rows) > 1 { + return nil, errExpectedSingleRow.New() + } + + s.value = rows[0][0] + return s.value, nil +} + +// EvalMultiple returns all rows returned by a subquery. 
+func (s *Subquery) EvalMultiple(ctx *sql.Context) ([]interface{}, error) { + if s.value != nil { + return s.value.([]interface{}), nil + } + + iter, err := s.Query.RowIter(ctx) + if err != nil { + return nil, err + } + + rows, err := sql.RowIterToRows(iter) + if err != nil { + return nil, err + } + + if len(rows) == 0 { + s.value = []interface{}{} + return nil, nil + } + + var result = make([]interface{}, len(rows)) + for i, row := range rows { + result[i] = row[0] + } + s.value = result + + return result, nil +} + +// IsNullable implements the Expression interface. +func (s *Subquery) IsNullable() bool { + return s.Query.Schema()[0].Nullable +} + +func (s *Subquery) String() string { + return fmt.Sprintf("(%s)", s.Query) +} + +// Resolved implements the Expression interface. +func (s *Subquery) Resolved() bool { + return s.Query.Resolved() +} + +// Type implements the Expression interface. +func (s *Subquery) Type() sql.Type { + return s.Query.Schema()[0].Type +} + +// WithChildren implements the Expression interface. +func (s *Subquery) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + return s, nil +} + +// Children implements the Expression interface. +func (s *Subquery) Children() []sql.Expression { + return nil +} + +// WithQuery returns the subquery with the query node changed. 
+func (s *Subquery) WithQuery(node sql.Node) *Subquery { + ns := *s + ns.Query = node + return &ns +} diff --git a/sql/expression/subquery_test.go b/sql/expression/subquery_test.go new file mode 100644 index 000000000..0d3f353be --- /dev/null +++ b/sql/expression/subquery_test.go @@ -0,0 +1,69 @@ +package expression_test + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" +) + +func TestSubquery(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + value, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.NoError(err) + require.Equal(value, "one") +} + +func TestSubqueryTooManyRows(t *testing.T) { + require := require.New(t) + table := memory.NewTable("", nil) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + require.NoError(table.Insert(sql.NewEmptyContext(), nil)) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewLiteral("one", sql.Text), + }, + plan.NewResolvedTable(table), + )) + + _, err := subquery.Eval(sql.NewEmptyContext(), nil) + require.Error(err) +} + +func TestSubqueryMultipleRows(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + table := memory.NewTable("foo", sql.Schema{ + {Name: "t", Source: "foo", Type: sql.Text}, + }) + + require.NoError(table.Insert(ctx, sql.Row{"one"})) + require.NoError(table.Insert(ctx, sql.Row{"two"})) + require.NoError(table.Insert(ctx, sql.Row{"three"})) + + subquery := expression.NewSubquery(plan.NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Text, "t", false), + }, + 
plan.NewResolvedTable(table), + )) + + values, err := subquery.EvalMultiple(ctx) + require.NoError(err) + require.Equal(values, []interface{}{"one", "two", "three"}) +} diff --git a/sql/expression/transform.go b/sql/expression/transform.go new file mode 100644 index 000000000..05195c7fd --- /dev/null +++ b/sql/expression/transform.go @@ -0,0 +1,30 @@ +package expression + +import ( + "github.com/src-d/go-mysql-server/sql" +) + +// TransformUp applies a transformation function to the given expression from the +// bottom up. +func TransformUp(e sql.Expression, f sql.TransformExprFunc) (sql.Expression, error) { + children := e.Children() + if len(children) == 0 { + return f(e) + } + + newChildren := make([]sql.Expression, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + e, err := e.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(e) +} diff --git a/sql/expression/tuple.go b/sql/expression/tuple.go index 799103013..11d35e1f5 100644 --- a/sql/expression/tuple.go +++ b/sql/expression/tuple.go @@ -4,8 +4,7 @@ import ( "fmt" "strings" - opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Tuple is a fixed-size collection of expressions. @@ -19,9 +18,6 @@ func NewTuple(exprs ...sql.Expression) Tuple { // Eval implements the Expression interface. func (t Tuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - span, ctx := ctx.Span("expression.Tuple", opentracing.Tag{Key: "elems", Value: len(t)}) - defer span.Finish() - if len(t) == 1 { return t[0].Eval(ctx, row) } @@ -81,18 +77,12 @@ func (t Tuple) Type() sql.Type { return sql.Tuple(types...) } -// TransformUp implements the Expression interface. 
-func (t Tuple) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var exprs = make([]sql.Expression, len(t)) - for i, e := range t { - var err error - exprs[i], err = f(e) - if err != nil { - return nil, err - } +// WithChildren implements the Expression interface. +func (t Tuple) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(t) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t)) } - - return f(Tuple(exprs)) + return NewTuple(children...), nil } // Children implements the Expression interface. diff --git a/sql/expression/tuple_test.go b/sql/expression/tuple_test.go index 947463d89..25fc411a7 100644 --- a/sql/expression/tuple_test.go +++ b/sql/expression/tuple_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestTuple(t *testing.T) { diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 4587edd48..6580655d2 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // UnresolvedColumn is an expression of a column that is not yet resolved. @@ -64,10 +64,12 @@ func (*UnresolvedColumn) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) panic("unresolved column is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uc *UnresolvedColumn) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - n := *uc - return f(&n) +// WithChildren implements the Expression interface. 
+func (uc *UnresolvedColumn) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(uc, len(children), 0) + } + return uc, nil } // UnresolvedFunction represents a function that is not yet resolved. @@ -126,16 +128,10 @@ func (*UnresolvedFunction) Eval(ctx *sql.Context, r sql.Row) (interface{}, error panic("unresolved function is a placeholder node, but Eval was called") } -// TransformUp implements the Expression interface. -func (uf *UnresolvedFunction) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { - var rc []sql.Expression - for _, c := range uf.Arguments { - c, err := c.TransformUp(f) - if err != nil { - return nil, err - } - rc = append(rc, c) +// WithChildren implements the Expression interface. +func (uf *UnresolvedFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != len(uf.Arguments) { + return nil, sql.ErrInvalidChildrenNumber.New(uf, len(children), len(uf.Arguments)) } - - return f(NewUnresolvedFunction(uf.name, uf.IsAggregate, rc...)) + return NewUnresolvedFunction(uf.name, uf.IsAggregate, children...), nil } diff --git a/sql/expression/unresolved_test.go b/sql/expression/unresolved_test.go index 59ceb3551..f4b451b3e 100644 --- a/sql/expression/unresolved_test.go +++ b/sql/expression/unresolved_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestUnresolvedExpression(t *testing.T) { diff --git a/sql/expression/walk.go b/sql/expression/walk.go index 36debed92..f04086781 100644 --- a/sql/expression/walk.go +++ b/sql/expression/walk.go @@ -1,6 +1,6 @@ package expression -import "gopkg.in/src-d/go-mysql-server.v0/sql" +import "github.com/src-d/go-mysql-server/sql" // Visitor visits exprs in the plan. 
type Visitor interface { diff --git a/sql/expression/walk_test.go b/sql/expression/walk_test.go index f093c4c6b..afbf4d7ad 100644 --- a/sql/expression/walk_test.go +++ b/sql/expression/walk_test.go @@ -3,8 +3,8 @@ package expression import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestWalk(t *testing.T) { diff --git a/sql/functionregistry.go b/sql/functionregistry.go index e3966ac04..0d9d6cfb9 100644 --- a/sql/functionregistry.go +++ b/sql/functionregistry.go @@ -1,114 +1,168 @@ package sql import ( + "github.com/src-d/go-mysql-server/internal/similartext" "gopkg.in/src-d/go-errors.v1" ) +// ErrFunctionAlreadyRegistered is thrown when a function is already registered +var ErrFunctionAlreadyRegistered = errors.NewKind("function '%s' is already registered") + // ErrFunctionNotFound is thrown when a function is not found -var ErrFunctionNotFound = errors.NewKind("function not found: %s") +var ErrFunctionNotFound = errors.NewKind("function: '%s' not found") // ErrInvalidArgumentNumber is returned when the number of arguments to call a // function is different from the function arity. -var ErrInvalidArgumentNumber = errors.NewKind("expecting %v arguments for calling this function, %d received") +var ErrInvalidArgumentNumber = errors.NewKind("function '%s' expected %v arguments, %v received") -// Function is a function defined by the user that can be applied in a SQL -// query. +// Function is a function defined by the user that can be applied in a SQL query. type Function interface { // Call invokes the function. Call(...Expression) (Expression, error) + // Function name + name() string // isFunction will restrict implementations of Function isFunction() } type ( + // Function0 is a function with 0 arguments. + Function0 struct { + Name string + Fn func() Expression + } // Function1 is a function with 1 argument. 
- Function1 func(e Expression) Expression + Function1 struct { + Name string + Fn func(e Expression) Expression + } // Function2 is a function with 2 arguments. - Function2 func(e1, e2 Expression) Expression + Function2 struct { + Name string + Fn func(e1, e2 Expression) Expression + } // Function3 is a function with 3 arguments. - Function3 func(e1, e2, e3 Expression) Expression + Function3 struct { + Name string + Fn func(e1, e2, e3 Expression) Expression + } // Function4 is a function with 4 arguments. - Function4 func(e1, e2, e3, e4 Expression) Expression + Function4 struct { + Name string + Fn func(e1, e2, e3, e4 Expression) Expression + } // Function5 is a function with 5 arguments. - Function5 func(e1, e2, e3, e4, e5 Expression) Expression + Function5 struct { + Name string + Fn func(e1, e2, e3, e4, e5 Expression) Expression + } // Function6 is a function with 6 arguments. - Function6 func(e1, e2, e3, e4, e5, e6 Expression) Expression + Function6 struct { + Name string + Fn func(e1, e2, e3, e4, e5, e6 Expression) Expression + } // Function7 is a function with 7 arguments. - Function7 func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression + Function7 struct { + Name string + Fn func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression + } // FunctionN is a function with variable number of arguments. This function // is expected to return ErrInvalidArgumentNumber if the arity does not // match, since the check has to be done in the implementation. - FunctionN func(...Expression) (Expression, error) + FunctionN struct { + Name string + Fn func(...Expression) (Expression, error) + } ) +// Call implements the Function interface. +func (fn Function0) Call(args ...Expression) (Expression, error) { + if len(args) != 0 { + return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args)) + } + + return fn.Fn(), nil +} + // Call implements the Function interface. 
func (fn Function1) Call(args ...Expression) (Expression, error) { if len(args) != 1 { - return nil, ErrInvalidArgumentNumber.New(1, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args)) } - return fn(args[0]), nil + return fn.Fn(args[0]), nil } // Call implements the Function interface. func (fn Function2) Call(args ...Expression) (Expression, error) { if len(args) != 2 { - return nil, ErrInvalidArgumentNumber.New(2, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args)) } - return fn(args[0], args[1]), nil + return fn.Fn(args[0], args[1]), nil } // Call implements the Function interface. func (fn Function3) Call(args ...Expression) (Expression, error) { if len(args) != 3 { - return nil, ErrInvalidArgumentNumber.New(3, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args)) } - return fn(args[0], args[1], args[2]), nil + return fn.Fn(args[0], args[1], args[2]), nil } // Call implements the Function interface. func (fn Function4) Call(args ...Expression) (Expression, error) { if len(args) != 4 { - return nil, ErrInvalidArgumentNumber.New(4, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args)) } - return fn(args[0], args[1], args[2], args[3]), nil + return fn.Fn(args[0], args[1], args[2], args[3]), nil } // Call implements the Function interface. func (fn Function5) Call(args ...Expression) (Expression, error) { if len(args) != 5 { - return nil, ErrInvalidArgumentNumber.New(5, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil } // Call implements the Function interface. 
func (fn Function6) Call(args ...Expression) (Expression, error) { if len(args) != 6 { - return nil, ErrInvalidArgumentNumber.New(6, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil } // Call implements the Function interface. func (fn Function7) Call(args ...Expression) (Expression, error) { if len(args) != 7 { - return nil, ErrInvalidArgumentNumber.New(7, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil } // Call implements the Function interface. func (fn FunctionN) Call(args ...Expression) (Expression, error) { - return fn(args...) + return fn.Fn(args...) } +func (fn Function0) name() string { return fn.Name } +func (fn Function1) name() string { return fn.Name } +func (fn Function2) name() string { return fn.Name } +func (fn Function3) name() string { return fn.Name } +func (fn Function4) name() string { return fn.Name } +func (fn Function5) name() string { return fn.Name } +func (fn Function6) name() string { return fn.Name } +func (fn Function7) name() string { return fn.Name } +func (fn FunctionN) name() string { return fn.Name } + +func (Function0) isFunction() {} func (Function1) isFunction() {} func (Function2) isFunction() {} func (Function3) isFunction() {} @@ -122,32 +176,41 @@ func (FunctionN) isFunction() {} // and User-Defined Functions. type FunctionRegistry map[string]Function -// Functions is a map of functions identified by their name. -type Functions map[string]Function - // NewFunctionRegistry creates a new FunctionRegistry. func NewFunctionRegistry() FunctionRegistry { return make(FunctionRegistry) } -// RegisterFunction registers a function with the given name. 
-func (r FunctionRegistry) RegisterFunction(name string, f Function) { - r[name] = f +// Register registers functions. +// If function with that name is already registered, +// the ErrFunctionAlreadyRegistered will be returned +func (r FunctionRegistry) Register(fn ...Function) error { + for _, f := range fn { + if _, ok := r[f.name()]; ok { + return ErrFunctionAlreadyRegistered.New(f.name()) + } + r[f.name()] = f + } + return nil } -// RegisterFunctions registers a map of functions. -func (r FunctionRegistry) RegisterFunctions(funcs Functions) { - for name, f := range funcs { - r[name] = f +// MustRegister registers functions. +// If function with that name is already registered, it will panic! +func (r FunctionRegistry) MustRegister(fn ...Function) { + if err := r.Register(fn...); err != nil { + panic(err) } } // Function returns a function with the given name. func (r FunctionRegistry) Function(name string) (Function, error) { - e, ok := r[name] - if !ok { + if len(r) == 0 { return nil, ErrFunctionNotFound.New(name) } - return e, nil + if fn, ok := r[name]; ok { + return fn, nil + } + similar := similartext.FindFromMap(r, name) + return nil, ErrFunctionNotFound.New(name + similar) } diff --git a/sql/functionregistry_test.go b/sql/functionregistry_test.go index 45da4adc9..d60042331 100644 --- a/sql/functionregistry_test.go +++ b/sql/functionregistry_test.go @@ -3,9 +3,9 @@ package sql_test import ( "testing" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestFunctionRegistry(t *testing.T) { @@ -14,12 +14,13 @@ func TestFunctionRegistry(t *testing.T) { c := sql.NewCatalog() name := "func" var expected sql.Expression = expression.NewStar() - c.RegisterFunction(name, sql.Function1(func(arg sql.Expression) sql.Expression { - return expected - })) + c.MustRegister(sql.Function1{ + 
Name: name, + Fn: func(arg sql.Expression) sql.Expression { return expected }, + }) f, err := c.Function(name) - require.Nil(err) + require.NoError(err) e, err := f.Call() require.Error(err) diff --git a/sql/generator.go b/sql/generator.go new file mode 100644 index 000000000..218b1db66 --- /dev/null +++ b/sql/generator.go @@ -0,0 +1,57 @@ +package sql + +import ( + "io" + + "gopkg.in/src-d/go-errors.v1" +) + +// Generator will generate a set of values for a given row. +type Generator interface { + // Next value in the generator. + Next() (interface{}, error) + // Close the generator and dispose resources. + Close() error +} + +// ErrNotGenerator is returned when the value cannot be converted to a +// generator. +var ErrNotGenerator = errors.NewKind("cannot convert value of type %T to a generator") + +// ToGenerator converts a value to a generator if possible. +func ToGenerator(v interface{}) (Generator, error) { + switch v := v.(type) { + case Generator: + return v, nil + case []interface{}: + return NewArrayGenerator(v), nil + case nil: + return NewArrayGenerator(nil), nil + default: + return nil, ErrNotGenerator.New(v) + } +} + +// NewArrayGenerator creates a generator for a given array. 
+func NewArrayGenerator(array []interface{}) Generator { + return &arrayGenerator{array, 0} +} + +type arrayGenerator struct { + array []interface{} + pos int +} + +func (g *arrayGenerator) Next() (interface{}, error) { + if g.pos >= len(g.array) { + return nil, io.EOF + } + + g.pos++ + return g.array[g.pos-1], nil +} + +func (g *arrayGenerator) Close() error { + g.pos = len(g.array) + return nil +} diff --git a/sql/generator_test.go b/sql/generator_test.go new file mode 100644 index 000000000..145411333 --- /dev/null +++ b/sql/generator_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestArrayGenerator(t *testing.T) { + require := require.New(t) + + expected := []interface{}{"a", "b", "c"} + gen := NewArrayGenerator(expected) + + var values []interface{} + for { + v, err := gen.Next() + if err != nil { + if err == io.EOF { + break + } + require.NoError(err) + } + values = append(values, v) + } + + require.Equal(expected, values) +} + +func TestToGenerator(t *testing.T) { + require := require.New(t) + + gen, err := ToGenerator([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(NewArrayGenerator([]interface{}{1, 2, 3}), gen) + + gen, err = ToGenerator(new(fakeGen)) + require.NoError(err) + require.Equal(new(fakeGen), gen) + + gen, err = ToGenerator(nil) + require.NoError(err) + require.Equal(NewArrayGenerator(nil), gen) + + _, err = ToGenerator("foo") + require.Error(err) +} + +type fakeGen struct{} + +func (fakeGen) Next() (interface{}, error) { return nil, fmt.Errorf("not implemented") } +func (fakeGen) Close() error { return nil } diff --git a/sql/index.go b/sql/index.go index 3da4a3155..c673341ca 100644 --- a/sql/index.go +++ b/sql/index.go @@ -1,29 +1,52 @@ package sql import ( - "context" "io" - "reflect" "strings" "sync" + "github.com/src-d/go-mysql-server/internal/similartext" + + "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" ) +// IndexBatchSize is the 
number of rows to save at a time when creating indexes. +const IndexBatchSize = uint64(10000) + +// ChecksumKey is the key in an index config to store the checksum. +const ChecksumKey = "checksum" + +// Checksumable provides the checksum of some data. +type Checksumable interface { + // Checksum returns a checksum and an error if there was any problem + // computing or obtaining the checksum. + Checksum() (string, error) +} + +// PartitionIndexKeyValueIter is an iterator of partitions that will return +// the partition and the IndexKeyValueIter of that partition. +type PartitionIndexKeyValueIter interface { + // Next returns the next partition and the IndexKeyValueIter for that + // partition. + Next() (Partition, IndexKeyValueIter, error) + io.Closer +} + // IndexKeyValueIter is an iterator of index key values, that is, a tuple of // the values that will be index keys. type IndexKeyValueIter interface { // Next returns the next tuple of index key values. The length of the // returned slice will be the same as the number of columns used to - // create this iterator. - Next() ([]interface{}, error) + // create this iterator. The second returned parameter is a repo's location. + Next() ([]interface{}, []byte, error) io.Closer } // IndexValueIter is an iterator of index values. type IndexValueIter interface { - // Next returns the next index value. - Next() (interface{}, error) + // Next returns the next value (repo's location) - see IndexKeyValueIter. + Next() ([]byte, error) io.Closer } @@ -31,9 +54,9 @@ type IndexValueIter interface { // more functionality by implementing more specific interfaces. type Index interface { // Get returns an IndexLookup for the given key in the index. - Get(key interface{}) (IndexLookup, error) + Get(key ...interface{}) (IndexLookup, error) // Has checks if the given key is present in the index. - Has(key interface{}) (bool, error) + Has(partition Partition, key ...interface{}) (bool, error) // ID returns the identifier of the index. 
ID() string // Database returns the database name this index belongs to. @@ -43,40 +66,52 @@ type Index interface { // Expressions returns the indexed expressions. If the result is more than // one expression, it means the index has multiple columns indexed. If it's // just one, it means it may be an expression or a column. - Expressions() []Expression + Expressions() []string + // Driver ID of the index. + Driver() string } // AscendIndex is an index that is sorted in ascending order. type AscendIndex interface { // AscendGreaterOrEqual returns an IndexLookup for keys that are greater - // or equal to the given key. - AscendGreaterOrEqual(key interface{}) (IndexLookup, error) + // or equal to the given keys. + AscendGreaterOrEqual(keys ...interface{}) (IndexLookup, error) // AscendLessThan returns an IndexLookup for keys that are less than the - // given key. - AscendLessThan(key interface{}) (IndexLookup, error) + // given keys. + AscendLessThan(keys ...interface{}) (IndexLookup, error) // AscendRange returns an IndexLookup for keys that are within the given // range. - AscendRange(greaterOrEqual, lessThan interface{}) (IndexLookup, error) + AscendRange(greaterOrEqual, lessThan []interface{}) (IndexLookup, error) } // DescendIndex is an index that is sorted in descending order. type DescendIndex interface { // DescendGreater returns an IndexLookup for keys that are greater - // than the given key. - DescendGreater(key interface{}) (IndexLookup, error) + // than the given keys. + DescendGreater(keys ...interface{}) (IndexLookup, error) // DescendLessOrEqual returns an IndexLookup for keys that are less than or - // equal to the given key. - DescendLessOrEqual(key interface{}) (IndexLookup, error) + // equal to the given keys. + DescendLessOrEqual(keys ...interface{}) (IndexLookup, error) // DescendRange returns an IndexLookup for keys that are within the given // range. 
- DescendRange(lessOrEqual, greaterThan interface{}) (IndexLookup, error) + DescendRange(lessOrEqual, greaterThan []interface{}) (IndexLookup, error) +} + +// NegateIndex is an index that supports retrieving negated values. +type NegateIndex interface { + // Not returns an IndexLookup for keys that are not equal + // to the given keys. + Not(keys ...interface{}) (IndexLookup, error) } // IndexLookup is a subset of an index. More specific interfaces can be // implemented to grant more capabilities to the index lookup. type IndexLookup interface { // Values returns the values in the subset of the index. - Values() IndexValueIter + Values(Partition) (IndexValueIter, error) + + // Indexes returns the IDs of all indexes involved in this lookup. + Indexes() []string } // SetOperations is a specialization of IndexLookup that enables set operations @@ -102,20 +137,20 @@ type Mergeable interface { } // IndexDriver manages the coordination between the indexes and their -// representation in disk. +// representation on disk. type IndexDriver interface { // ID returns the unique name of the driver. ID() string // Create a new index. If exprs is more than one expression, it means the // index has multiple columns indexed. If it's just one, it means it may // be an expression or a column. - Create(path, table, db, id string, exprs []Expression, config map[string]string) (Index, error) - // Load the index at the given path. - Load(path string) (Index, error) - // Save the given index at the given path. - Save(ctx context.Context, path string, index Index, iter IndexKeyValueIter) error - // Delete the index with the given path. - Delete(path string, index Index) error + Create(db, table, id string, expressions []Expression, config map[string]string) (Index, error) + // LoadAll loads all indexes for given db and table. + LoadAll(db, table string) ([]Index, error) + // Save the given index for all partitions. 
+ Save(*Context, Index, PartitionIndexKeyValueIter) error + // Delete the given index for all partitions in the iterator. + Delete(Index, PartitionIter) error } type indexKey struct { @@ -127,9 +162,10 @@ type IndexRegistry struct { // Root path where all the data of the indexes is stored on disk. Root string - mut sync.RWMutex - indexes map[indexKey]Index - statuses map[indexKey]IndexStatus + mut sync.RWMutex + indexes map[indexKey]Index + indexOrder []indexKey + statuses map[indexKey]IndexStatus driversMut sync.RWMutex drivers map[string]IndexDriver @@ -157,6 +193,21 @@ func (r *IndexRegistry) IndexDriver(id string) IndexDriver { return r.drivers[id] } +// DefaultIndexDriver returns the default index driver, which is the only +// driver when there is 1 driver in the registry. If there are more than +// 1 drivers in the registry, this will return the empty string, as there +// is no clear default driver. +func (r *IndexRegistry) DefaultIndexDriver() IndexDriver { + r.driversMut.RLock() + defer r.driversMut.RUnlock() + if len(r.drivers) == 1 { + for _, d := range r.drivers { + return d + } + } + return nil +} + // RegisterIndexDriver registers a new index driver. func (r *IndexRegistry) RegisterIndexDriver(driver IndexDriver) { r.driversMut.Lock() @@ -164,18 +215,98 @@ func (r *IndexRegistry) RegisterIndexDriver(driver IndexDriver) { r.drivers[driver.ID()] = driver } +// LoadIndexes loads all indexes for all dbs, tables and drivers. 
+func (r *IndexRegistry) LoadIndexes(dbs Databases) error { + r.driversMut.RLock() + defer r.driversMut.RUnlock() + r.mut.Lock() + defer r.mut.Unlock() + + for _, driver := range r.drivers { + for _, db := range dbs { + for _, t := range db.Tables() { + indexes, err := driver.LoadAll(db.Name(), t.Name()) + if err != nil { + return err + } + + var checksum string + if c, ok := t.(Checksumable); ok && len(indexes) != 0 { + checksum, err = c.Checksum() + if err != nil { + return err + } + } + + for _, idx := range indexes { + k := indexKey{db.Name(), idx.ID()} + r.indexes[k] = idx + r.indexOrder = append(r.indexOrder, k) + + var idxChecksum string + if c, ok := idx.(Checksumable); ok { + idxChecksum, err = c.Checksum() + if err != nil { + return err + } + } + + if checksum == "" || checksum == idxChecksum { + r.statuses[k] = IndexReady + } else { + logrus.Warnf( + "index %q is outdated and will not be used, you can remove it using `DROP INDEX %s ON %s`", + idx.ID(), + idx.ID(), + idx.Table(), + ) + r.MarkOutdated(idx) + } + } + } + } + } + + return nil +} + +// MarkOutdated sets the index status as outdated. This method is not thread +// safe and should not be used directly except for testing. +func (r *IndexRegistry) MarkOutdated(idx Index) { + r.statuses[indexKey{idx.Database(), idx.ID()}] = IndexOutdated +} + func (r *IndexRegistry) retainIndex(db, id string) { r.rcmut.Lock() defer r.rcmut.Unlock() key := indexKey{db, id} - r.refCounts[key] = r.refCounts[key] + 1 + r.refCounts[key]++ } // CanUseIndex returns whether the given index is ready to use or not. func (r *IndexRegistry) CanUseIndex(idx Index) bool { r.mut.RLock() defer r.mut.RUnlock() - return bool(r.statuses[indexKey{idx.Database(), idx.ID()}]) + return r.canUseIndex(idx) +} + +// CanRemoveIndex returns whether the given index is ready to be removed. 
+func (r *IndexRegistry) CanRemoveIndex(idx Index) bool { + if idx == nil { + return false + } + + r.mut.RLock() + defer r.mut.RUnlock() + status := r.statuses[indexKey{idx.Database(), idx.ID()}] + return status == IndexReady || status == IndexOutdated +} + +func (r *IndexRegistry) canUseIndex(idx Index) bool { + if idx == nil { + return false + } + return r.statuses[indexKey{idx.Database(), idx.ID()}].IsUsable() } // setStatus is not thread-safe, it should be guarded using mut. @@ -188,8 +319,7 @@ func (r *IndexRegistry) ReleaseIndex(idx Index) { r.rcmut.Lock() defer r.rcmut.Unlock() key := indexKey{idx.Database(), idx.ID()} - r.refCounts[key] = r.refCounts[key] - 1 - + r.refCounts[key]-- if r.refCounts[key] > 0 { return } @@ -205,19 +335,48 @@ func (r *IndexRegistry) ReleaseIndex(idx Index) { func (r *IndexRegistry) Index(db, id string) Index { r.mut.RLock() defer r.mut.RUnlock() + + r.retainIndex(db, id) return r.indexes[indexKey{db, strings.ToLower(id)}] } +// IndexesByTable returns a slice of all the indexes existing on the given table. +func (r *IndexRegistry) IndexesByTable(db, table string) []Index { + r.mut.RLock() + defer r.mut.RUnlock() + + var indexes []Index + for _, key := range r.indexOrder { + idx := r.indexes[key] + if idx.Database() == db && idx.Table() == table { + indexes = append(indexes, idx) + r.retainIndex(db, idx.ID()) + } + } + + return indexes +} + // IndexByExpression returns an index by the given expression. It will return // nil it the index is not found. If more than one expression is given, all // of them must match for the index to be matched. 
-func (r *IndexRegistry) IndexByExpression(db string, exprs ...Expression) Index { +func (r *IndexRegistry) IndexByExpression(db string, expr ...Expression) Index { r.mut.RLock() defer r.mut.RUnlock() - for _, idx := range r.indexes { + expressions := make([]string, len(expr)) + for i, e := range expr { + expressions[i] = e.String() + } + + for _, k := range r.indexOrder { + idx := r.indexes[k] + if !r.canUseIndex(idx) { + continue + } + if idx.Database() == db { - if exprListsEqual(idx.Expressions(), exprs) { + if exprListsMatch(idx.Expressions(), expressions) { r.retainIndex(db, idx.ID()) return idx } @@ -227,6 +386,52 @@ func (r *IndexRegistry) IndexByExpression(db string, exprs ...Expression) Index return nil } +// ExpressionsWithIndexes finds all the combinations of expressions with +// matching indexes. This only matches multi-column indexes. +func (r *IndexRegistry) ExpressionsWithIndexes( + db string, + exprs ...Expression, +) [][]Expression { + r.mut.RLock() + defer r.mut.RUnlock() + + var results [][]Expression +Indexes: + for _, idx := range r.indexes { + if !r.canUseIndex(idx) { + continue + } + + if ln := len(idx.Expressions()); ln <= len(exprs) && ln > 1 { + var used = make(map[int]struct{}) + var matched []Expression + for _, ie := range idx.Expressions() { + var found bool + for i, e := range exprs { + if _, ok := used[i]; ok { + continue + } + + if ie == e.String() { + used[i] = struct{}{} + found = true + matched = append(matched, e) + break + } + } + + if !found { + continue Indexes + } + } + + results = append(results, matched) + } + } + + return results +} + var ( // ErrIndexIDAlreadyRegistered is the error returned when there is already // an index with the same ID. @@ -240,8 +445,8 @@ var ( ErrIndexNotFound = errors.NewKind("index %q was not found") // ErrIndexDeleteInvalidStatus is returned when the index trying to delete - // does not have a ready state. 
- ErrIndexDeleteInvalidStatus = errors.NewKind("can't delete index %q because it's not ready for usage") + // does not have a ready or outdated state. + ErrIndexDeleteInvalidStatus = errors.NewKind("can't delete index %q because it's not ready for removal") ) func (r *IndexRegistry) validateIndexToAdd(idx Index) error { @@ -258,27 +463,28 @@ func (r *IndexRegistry) validateIndexToAdd(idx Index) error { } if exprListsEqual(i.Expressions(), idx.Expressions()) { - var exprs = make([]string, len(idx.Expressions())) - for i, e := range idx.Expressions() { - exprs[i] = e.String() - } - return ErrIndexExpressionAlreadyRegistered.New(strings.Join(exprs, ", ")) + return ErrIndexExpressionAlreadyRegistered.New( + strings.Join(idx.Expressions(), ", "), + ) } } return nil } -func exprListsEqual(a, b []Expression) bool { +// exprListsMatch returns whether any subset of a is the entirety of b. +func exprListsMatch(a, b []string) bool { var visited = make([]bool, len(b)) + for _, va := range a { found := false + for j, vb := range b { if visited[j] { continue } - if reflect.DeepEqual(va, vb) { + if va == vb { visited[j] = true found = true break @@ -293,42 +499,67 @@ func exprListsEqual(a, b []Expression) bool { return true } +// exprListsEqual returns whether a and b have the same items. +func exprListsEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + + return exprListsMatch(a, b) +} + // AddIndex adds the given index to the registry. The added index will be // marked as creating, so nobody can't register two indexes with the same // expression or id while the other is still being created. // When something is sent through the returned channel, it means the index has // finished it's creation and will be marked as ready. -func (r *IndexRegistry) AddIndex(idx Index) (chan<- struct{}, error) { +// Another channel is returned to notify the user when the index is ready. 
+func (r *IndexRegistry) AddIndex( + idx Index, +) (created chan<- struct{}, ready <-chan struct{}, err error) { if err := r.validateIndexToAdd(idx); err != nil { - return nil, err + return nil, nil, err } r.mut.Lock() r.setStatus(idx, IndexNotReady) - r.indexes[indexKey{idx.Database(), idx.ID()}] = idx + key := indexKey{idx.Database(), idx.ID()} + r.indexes[key] = idx + r.indexOrder = append(r.indexOrder, key) r.mut.Unlock() - var created = make(chan struct{}) + var _created = make(chan struct{}) + var _ready = make(chan struct{}) go func() { - <-created + <-_created r.mut.Lock() defer r.mut.Unlock() r.setStatus(idx, IndexReady) + close(_ready) }() - return created, nil + return _created, _ready, nil } // DeleteIndex deletes an index from the registry by its id. First, it marks // the index for deletion but does not remove it, so queries that are using it // may still do so. The returned channel will send a message when the index can // be deleted from disk. -func (r *IndexRegistry) DeleteIndex(db, id string) (<-chan struct{}, error) { +// If force is true, it will delete the index even if it's not ready for usage. +// Only use that parameter if you know what you're doing. 
+func (r *IndexRegistry) DeleteIndex(db, id string, force bool) (<-chan struct{}, error) { r.mut.RLock() var key indexKey + + if len(r.indexes) == 0 { + return nil, ErrIndexNotFound.New(id) + } + + var indexNames []string + for k, idx := range r.indexes { if strings.ToLower(id) == idx.ID() { - if !r.CanUseIndex(idx) { + if !force && !r.CanRemoveIndex(idx) { r.mut.RUnlock() return nil, ErrIndexDeleteInvalidStatus.New(id) } @@ -336,23 +567,35 @@ func (r *IndexRegistry) DeleteIndex(db, id string) (<-chan struct{}, error) { key = k break } + indexNames = append(indexNames, idx.ID()) } r.mut.RUnlock() if key.id == "" { - return nil, ErrIndexNotFound.New(id) + similar := similartext.Find(indexNames, id) + return nil, ErrIndexNotFound.New(id + similar) } var done = make(chan struct{}, 1) r.rcmut.Lock() // If no query is using this index just delete it right away - if r.refCounts[key] == 0 { + if force || r.refCounts[key] <= 0 { r.mut.Lock() defer r.mut.Unlock() defer r.rcmut.Unlock() delete(r.indexes, key) + var pos = -1 + for i, k := range r.indexOrder { + if k == key { + pos = i + break + } + } + if pos >= 0 { + r.indexOrder = append(r.indexOrder[:pos], r.indexOrder[pos+1:]...) + } close(done) return done, nil } @@ -374,13 +617,16 @@ func (r *IndexRegistry) DeleteIndex(db, id string) (<-chan struct{}, error) { } // IndexStatus represents the current status in which the index is. -type IndexStatus bool +type IndexStatus byte const ( // IndexNotReady means the index is not ready to be used. - IndexNotReady IndexStatus = false + IndexNotReady IndexStatus = iota // IndexReady means the index can be used. - IndexReady IndexStatus = true + IndexReady + // IndexOutdated means the index is loaded but will not be used because the + // contents in it are outdated. + IndexOutdated ) // IsUsable returns whether the index can be used or not based on the status. 
diff --git a/sql/index/config.go b/sql/index/config.go new file mode 100644 index 000000000..7d5b7c9fd --- /dev/null +++ b/sql/index/config.go @@ -0,0 +1,124 @@ +package index + +import ( + "io" + "io/ioutil" + "os" + + yaml "gopkg.in/yaml.v2" +) + +// Config represents index configuration +type Config struct { + DB string + Table string + ID string + Expressions []string + Drivers map[string]map[string]string +} + +// NewConfig creates a new Config instance for given driver's configuration +func NewConfig( + db, table, id string, + expressions []string, + driverID string, + driverConfig map[string]string, +) *Config { + cfg := &Config{ + DB: db, + Table: table, + ID: id, + Expressions: expressions, + Drivers: make(map[string]map[string]string), + } + cfg.Drivers[driverID] = driverConfig + + return cfg +} + +// Driver returns an configuration for the particular driverID. +func (cfg *Config) Driver(driverID string) map[string]string { + return cfg.Drivers[driverID] +} + +// WriteConfig writes the configuration to the passed writer (w). +func WriteConfig(w io.Writer, cfg *Config) error { + data, err := yaml.Marshal(cfg) + + if err != nil { + return err + } + + _, err = w.Write(data) + return err +} + +// WriteConfigFile writes the configuration to file. +func WriteConfigFile(path string, cfg *Config) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + return WriteConfig(f, cfg) +} + +// ReadConfig reads an configuration from the passed reader (r). +func ReadConfig(r io.Reader) (*Config, error) { + data, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + + var cfg Config + err = yaml.Unmarshal(data, &cfg) + return &cfg, err +} + +// ReadConfigFile reads an configuration from file. 
+func ReadConfigFile(path string) (*Config, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return ReadConfig(f) +} + +// CreateProcessingFile creates a file saying whether the index is being created. +func CreateProcessingFile(path string) error { + f, err := os.Create(path) + if err != nil { + return err + } + + // we don't care about errors closing here + _ = f.Close() + return nil +} + +// WriteProcessingFile write data to the processing file either truncating it +// before or creating it if it doesn't exist. +func WriteProcessingFile(path string, data []byte) error { + return ioutil.WriteFile(path, data, 0666) +} + +// RemoveProcessingFile removes the file that says whether the index is still being created. +func RemoveProcessingFile(path string) error { + return os.Remove(path) +} + +// ExistsProcessingFile returns whether the processing file exists. +func ExistsProcessingFile(path string) (bool, error) { + _, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + + return true, nil +} diff --git a/sql/index/config_test.go b/sql/index/config_test.go new file mode 100644 index 000000000..c5004d7b7 --- /dev/null +++ b/sql/index/config_test.go @@ -0,0 +1,75 @@ +package index + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfig(t *testing.T) { + require := require.New(t) + tmpDir, err := ioutil.TempDir("", "index") + require.NoError(err) + defer func() { require.NoError(os.RemoveAll(tmpDir)) }() + + driver := "driver" + db, table, id := "db_name", "table_name", "index_id" + dir := filepath.Join(tmpDir, driver) + subdir := filepath.Join(dir, db, table) + err = os.MkdirAll(subdir, 0750) + require.NoError(err) + file := filepath.Join(subdir, id+".cfg") + + cfg1 := NewConfig( + db, + table, + id, + []string{"h1", "h2"}, + "DriverID", + map[string]string{ + "port": "10101", + "host": 
"localhost", + }, + ) + + err = WriteConfigFile(file, cfg1) + require.NoError(err) + + cfg2, err := ReadConfigFile(file) + require.NoError(err) + require.Equal(cfg1, cfg2) +} + +func TestProcessingFile(t *testing.T) { + require := require.New(t) + tmpDir, err := ioutil.TempDir("", "index") + require.NoError(err) + defer func() { require.NoError(os.RemoveAll(tmpDir)) }() + + file := filepath.Join(tmpDir, ".processing") + + ok, err := ExistsProcessingFile(file) + require.NoError(err) + require.False(ok) + + require.NoError(CreateProcessingFile(file)) + + ok, err = ExistsProcessingFile(file) + require.NoError(err) + require.True(ok) + + require.NoError(WriteProcessingFile(file, []byte("test"))) + + ok, err = ExistsProcessingFile(file) + require.NoError(err) + require.True(ok) + + require.NoError(RemoveProcessingFile(file)) + + ok, err = ExistsProcessingFile(file) + require.NoError(err) + require.False(ok) +} diff --git a/sql/index/pilosa/driver.go b/sql/index/pilosa/driver.go new file mode 100644 index 000000000..4f9146b21 --- /dev/null +++ b/sql/index/pilosa/driver.go @@ -0,0 +1,690 @@ +// +build !windows + +package pilosa + +import ( + "context" + "crypto/sha1" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-kit/kit/metrics/discard" + opentracing "github.com/opentracing/opentracing-go" + pilosa "github.com/pilosa/pilosa" + "github.com/pilosa/pilosa/syswrap" + "github.com/sirupsen/logrus" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/index" + errors "gopkg.in/src-d/go-errors.v1" +) + +const ( + // DriverID the unique name of the pilosa driver. + DriverID = "pilosa" + + // IndexNamePrefix the pilosa's indexes prefix + IndexNamePrefix = "idx" + + // FieldNamePrefix the pilosa's field prefix + FieldNamePrefix = "fld" + + // ConfigFileName is the extension of an index config file. 
+ ConfigFileName = "config.yml" + + // ProcessingFileName is the extension of the lock/processing index file. + ProcessingFileName = ".processing" + + // MappingFileNamePrefix is the prefix in mapping file - + MappingFileNamePrefix = "map" + // MappingFileNameExtension is the extension in mapping file - + MappingFileNameExtension = ".db" +) + +const ( + processingFileOnCreate = 'C' + processingFileOnSave = 'S' +) + +var ( + errCorruptedIndex = errors.NewKind("the index db: %s, table: %s, id: %s is corrupted") + errInvalidIndexType = errors.NewKind("expecting a pilosa index, instead got %T") +) + +const ( + pilosaIndexThreadsKey = "PILOSA_INDEX_THREADS" + pilosaIndexThreadsVar = "pilosa_index_threads" +) + +type ( + bitBatch struct { + size uint64 + rows []uint64 + cols []uint64 + pos uint64 + } + + // used for saving + batch struct { + bitBatches []*bitBatch + fields []*pilosa.Field + timePilosa time.Duration + timeMapping time.Duration + } + + // Driver implements sql.IndexDriver interface. + Driver struct { + root string + } +) + +var ( + // RowsGauge describes a metric that takes number of indexes rows over time. + RowsGauge = discard.NewGauge() + // TotalHistogram describes a metric that takes repeated observations of the total time to index values. + TotalHistogram = discard.NewHistogram() + // MappingHistogram describes a metric that takes repeated observations of the total time to map values. + MappingHistogram = discard.NewHistogram() + // BitmapHistogram describes a metric that takes repeated observations of the total time to store values in bitmaps + BitmapHistogram = discard.NewHistogram() +) + +// NewDriver returns a new instance of pilosa.Driver +// which satisfies sql.IndexDriver interface +func NewDriver(root string) *Driver { + return &Driver{ + root: root, + } +} + +// ID returns the unique name of the driver. 
+func (*Driver) ID() string { + return DriverID +} + +var errWriteConfigFile = errors.NewKind("unable to write indexes configuration file") + +// Create a new index. +func (d *Driver) Create( + db, table, id string, + expressions []sql.Expression, + config map[string]string, +) (sql.Index, error) { + _, err := mkdir(d.root, db, table, id) + if err != nil { + return nil, err + } + + if config == nil { + config = make(map[string]string) + } + + exprs := make([]string, len(expressions)) + for i, e := range expressions { + exprs[i] = e.String() + } + + cfg := index.NewConfig(db, table, id, exprs, d.ID(), config) + err = index.WriteConfigFile(d.configFilePath(db, table, id), cfg) + if err != nil { + return nil, errWriteConfigFile.Wrap(err) + } + + idx, err := d.newPilosaIndex(db, table) + if err != nil { + return nil, err + } + + processingFile := d.processingFilePath(db, table, id) + if err := index.WriteProcessingFile( + processingFile, + []byte{processingFileOnCreate}, + ); err != nil { + return nil, errWriteConfigFile.Wrap(err) + } + + return newPilosaIndex(idx, cfg), nil +} + +var errReadIndexes = errors.NewKind("error loading all indexes for table %s of database %s: %s") + +// LoadAll loads all indexes for given db and table +func (d *Driver) LoadAll(db, table string) ([]sql.Index, error) { + var ( + indexes []sql.Index + errors []string + root = filepath.Join(d.root, db, table) + ) + + dirs, err := ioutil.ReadDir(root) + if err != nil { + if os.IsNotExist(err) { + return indexes, nil + } + return nil, errReadIndexes.New(table, db, err) + } + for _, info := range dirs { + if info.IsDir() && !strings.HasPrefix(info.Name(), ".") { + idx, err := d.loadIndex(db, table, info.Name()) + if err != nil { + if !errCorruptedIndex.Is(err) { + errors = append(errors, err.Error()) + } + continue + } + + indexes = append(indexes, idx) + } + } + if len(errors) > 0 { + return nil, fmt.Errorf(strings.Join(errors, "\n")) + } + + return indexes, nil +} + +var ( + 
errLoadingIndexConfig = errors.NewKind("unable to load index configuration") + errReadIndexConfig = errors.NewKind("unable to read index configuration") +) + +func (d *Driver) loadIndex(db, table, id string) (*pilosaIndex, error) { + idx, err := d.newPilosaIndex(db, table) + if err != nil { + return nil, err + } + if err := idx.Open(); err != nil { + return nil, err + } + defer idx.Close() + + dir := filepath.Join(d.root, db, table, id) + config := d.configFilePath(db, table, id) + if _, err = os.Stat(config); err != nil { + return nil, errCorruptedIndex.New(dir) + } + + processing := d.processingFilePath(db, table, id) + ok, err := index.ExistsProcessingFile(processing) + if err != nil { + return nil, errLoadingIndexConfig.Wrap(err) + } + if ok { + log := logrus.WithFields(logrus.Fields{ + "err": err, + "db": db, + "table": table, + "id": id, + "dir": dir, + }) + log.Warn("could not read index file, index is corrupt and will be deleted") + if err = os.RemoveAll(dir); err != nil { + log.Warn("unable to remove corrupted index: " + dir) + } + + return nil, errCorruptedIndex.New(dir) + } + + cfg, err := index.ReadConfigFile(config) + if err != nil { + return nil, errReadIndexConfig.Wrap(err) + } + cfgDriver := cfg.Driver(DriverID) + if cfgDriver == nil { + return nil, errCorruptedIndex.New(dir) + } + + pilosaIndex := newPilosaIndex(idx, cfg) + for k, v := range cfgDriver { + if strings.HasPrefix(v, MappingFileNamePrefix) && strings.HasSuffix(v, MappingFileNameExtension) { + path := d.mappingFilePath(db, table, id, k) + if _, err := os.Stat(path); err != nil { + continue + } + pilosaIndex.mapping[k] = newMapping(path) + } + } + + return pilosaIndex, nil +} + +func (d *Driver) savePartition( + ctx *sql.Context, + p sql.Partition, + kviter sql.IndexKeyValueIter, + idx *pilosaIndex, + pilosaIndex *concurrentPilosaIndex, + b *batch, +) (uint64, error) { + var ( + colID uint64 + err error + ) + + for i, e := range idx.Expressions() { + name := fieldName(idx.ID(), e, p) + 
pilosaIndex.DeleteField(name) + field, err := pilosaIndex.CreateField(name, pilosa.OptFieldTypeDefault()) + if err != nil { + return 0, err + } + b.fields[i] = field + b.bitBatches[i] = newBitBatch(sql.IndexBatchSize) + } + + rollback := true + mk := mappingKey(p) + mapping, ok := idx.mapping[mk] + if !ok { + return 0, errMappingNotFound.New(mk) + } + if err := mapping.openCreate(true); err != nil { + return 0, err + } + + defer func() { + if rollback { + mapping.rollback() + } else { + e := d.saveMapping(ctx, mapping, colID, false, b) + if e != nil && err == nil { + err = e + } + } + + mapping.close() + kviter.Close() + }() + + for colID = 0; err == nil; colID++ { + // commit each batch of objects (pilosa and boltdb) + if colID%sql.IndexBatchSize == 0 && colID != 0 { + if err = d.saveBatch(ctx, mapping, colID, b); err != nil { + return 0, err + } + } + + select { + case <-ctx.Context.Done(): + return 0, ctx.Context.Err() + default: + } + + values, location, err := kviter.Next() + if err != nil { + break + } + + for i, field := range b.fields { + if values[i] == nil { + continue + } + + var rowID uint64 + rowID, err = mapping.getRowID(field.Name(), values[i]) + if err != nil { + return 0, err + } + + b.bitBatches[i].Add(rowID, colID) + } + + err = mapping.putLocation(pilosaIndex.Name(), colID, location) + if err != nil { + return 0, err + } + } + + if err != nil && err != io.EOF { + return 0, err + } + + rollback = false + + err = d.savePilosa(ctx, colID, b) + if err != nil { + return 0, err + } + + for _, f := range b.fields { + if err := f.Close(); err != nil { + return 0, err + } + } + + return colID, err +} + +// Save the given index (mapping and bitmap) +func (d *Driver) Save( + ctx *sql.Context, + i sql.Index, + iter sql.PartitionIndexKeyValueIter, +) (err error) { + start := time.Now() + + idx, ok := i.(*pilosaIndex) + if !ok { + return errInvalidIndexType.New(i) + } + + if err := idx.index.Open(); err != nil { + return err + } + defer idx.index.Close() + + 
idx.wg.Add(1) + defer idx.wg.Done() + + ctx.Context, idx.cancel = context.WithCancel(ctx.Context) + processingFile := d.processingFilePath(i.Database(), i.Table(), i.ID()) + err = index.WriteProcessingFile( + processingFile, + []byte{processingFileOnSave}, + ) + if err != nil { + return errWriteConfigFile.Wrap(err) + } + + cfgPath := d.configFilePath(i.Database(), i.Table(), i.ID()) + cfg, err := index.ReadConfigFile(cfgPath) + if err != nil { + return errReadIndexConfig.Wrap(err) + } + driverCfg := cfg.Driver(DriverID) + + defer iter.Close() + pilosaIndex := idx.index + + var ( + rows, timePilosa, timeMapping uint64 + + wg sync.WaitGroup + tokens = make(chan struct{}, indexThreads(ctx)) + + errors []error + errmut sync.Mutex + ) + + for { + select { + case <-ctx.Done(): + return + default: + } + + p, kviter, err := iter.Next() + if err != nil { + if err == io.EOF { + break + } + + idx.cancel() + wg.Wait() + return err + } + mk := mappingKey(p) + driverCfg[mk] = mappingFileName(mk) + mapping := newMapping(d.mappingFilePath(idx.Database(), idx.Table(), idx.ID(), mk)) + idx.mapping[mk] = mapping + + wg.Add(1) + + go func() { + defer func() { + wg.Done() + <-tokens + }() + + tokens <- struct{}{} + + var b = &batch{ + fields: make([]*pilosa.Field, len(idx.Expressions())), + bitBatches: make([]*bitBatch, len(idx.Expressions())), + } + + numRows, err := d.savePartition(ctx, p, kviter, idx, pilosaIndex, b) + if err != nil { + errmut.Lock() + errors = append(errors, err) + idx.cancel() + errmut.Unlock() + return + } + + atomic.AddUint64(&timeMapping, uint64(b.timeMapping)) + atomic.AddUint64(&timePilosa, uint64(b.timePilosa)) + atomic.AddUint64(&rows, numRows) + }() + } + + wg.Wait() + if len(errors) > 0 { + return errors[0] + } + if err = index.WriteConfigFile(cfgPath, cfg); err != nil { + return errWriteConfigFile.Wrap(err) + } + + observeIndex(time.Since(start), timePilosa, timeMapping, rows) + + return index.RemoveProcessingFile(processingFile) +} + +// Delete the 
given index for all partitions in the iterator. +func (d *Driver) Delete(i sql.Index, partitions sql.PartitionIter) error { + idx, ok := i.(*pilosaIndex) + if !ok { + return errInvalidIndexType.New(i) + } + if idx.cancel != nil { + idx.cancel() + idx.wg.Wait() + } + + if err := idx.index.Open(); err != nil { + return err + } + defer idx.index.Close() + + if err := os.RemoveAll(filepath.Join(d.root, i.Database(), i.Table(), i.ID())); err != nil { + return err + } + + for { + p, err := partitions.Next() + if err != nil { + if err == io.EOF { + break + } + return err + } + + for _, ex := range idx.Expressions() { + name := fieldName(idx.ID(), ex, p) + field := idx.index.Field(name) + if field == nil { + continue + } + + if err = idx.index.DeleteField(name); err != nil { + return err + } + } + mk := mappingKey(p) + delete(idx.mapping, mk) + } + + return partitions.Close() +} + +func (d *Driver) saveBatch(ctx *sql.Context, m *mapping, cols uint64, b *batch) error { + err := d.savePilosa(ctx, cols, b) + if err != nil { + return err + } + + return d.saveMapping(ctx, m, cols, true, b) +} + +func (d *Driver) savePilosa(ctx *sql.Context, cols uint64, b *batch) error { + span, _ := ctx.Span("pilosa.Save.bitBatch", + opentracing.Tag{Key: "cols", Value: cols}, + opentracing.Tag{Key: "fields", Value: len(b.fields)}, + ) + defer span.Finish() + + start := time.Now() + + for i, fld := range b.fields { + err := fld.Import(b.bitBatches[i].rows, b.bitBatches[i].cols, nil) + if err != nil { + span.LogKV("error", err) + return err + } + + b.bitBatches[i].Clean() + } + + b.timePilosa += time.Since(start) + + return nil +} + +func (d *Driver) saveMapping( + ctx *sql.Context, + m *mapping, + cols uint64, + cont bool, + b *batch, +) error { + span, _ := ctx.Span("pilosa.Save.mapping", + opentracing.Tag{Key: "cols", Value: cols}, + opentracing.Tag{Key: "continues", Value: cont}, + ) + defer span.Finish() + + start := time.Now() + + err := m.commit(cont) + if err != nil { + 
span.LogKV("error", err) + return err + } + + b.timeMapping += time.Since(start) + + return nil +} + +func newBitBatch(size uint64) *bitBatch { + b := &bitBatch{size: size} + b.Clean() + + return b +} + +func (b *bitBatch) Clean() { + b.rows = make([]uint64, 0, b.size) + b.cols = make([]uint64, 0, b.size) + b.pos = 0 +} + +func (b *bitBatch) Add(row, col uint64) { + b.rows = append(b.rows, row) + b.cols = append(b.cols, col) +} + +func indexName(db, table string) string { + h := sha1.New() + io.WriteString(h, db) + io.WriteString(h, table) + + return fmt.Sprintf("%s-%x", IndexNamePrefix, h.Sum(nil)) +} + +func fieldName(id, ex string, p sql.Partition) string { + h := sha1.New() + io.WriteString(h, id) + io.WriteString(h, ex) + h.Write(p.Key()) + return fmt.Sprintf("%s-%x", FieldNamePrefix, h.Sum(nil)) +} + +// mkdir makes an empty index directory (if doesn't exist) and returns a path. +func mkdir(elem ...string) (string, error) { + path := filepath.Join(elem...) + return path, os.MkdirAll(path, 0750) +} + +func (d *Driver) configFilePath(db, table, id string) string { + return filepath.Join(d.root, db, table, id, ConfigFileName) +} + +func (d *Driver) processingFilePath(db, table, id string) string { + return filepath.Join(d.root, db, table, id, ProcessingFileName) +} + +func mappingFileName(key string) string { + h := sha1.New() + io.WriteString(h, key) + return fmt.Sprintf("%s-%x%s", MappingFileNamePrefix, h.Sum(nil), MappingFileNameExtension) +} +func (d *Driver) mappingFilePath(db, table, id string, key string) string { + return filepath.Join(d.root, db, table, id, mappingFileName(key)) +} + +func (d *Driver) newPilosaIndex(db, table string) (*pilosa.Index, error) { + name := indexName(db, table) + path := filepath.Join(d.root, "."+DriverID, name) + idx, err := pilosa.NewIndex(path, name) + if err != nil { + return nil, err + } + return idx, nil +} + +func indexThreads(ctx *sql.Context) int { + typ, val := ctx.Session.Get(pilosaIndexThreadsVar) + if val != nil 
&& typ == sql.Int64 { + return int(val.(int64)) + } + + var value int + if v, ok := os.LookupEnv(pilosaIndexThreadsKey); ok { + value, _ = strconv.Atoi(v) + } + + if value <= 0 { + value = runtime.NumCPU() + } + + return value +} + +func observeIndex(timeTotal time.Duration, timePilosa, timeMapping, rows uint64) { + logrus.WithFields(logrus.Fields{ + "duration": timeTotal, + "pilosa": timePilosa, + "mapping": timeMapping, + "rows": rows, + "id": DriverID, + }).Debugf("finished pilosa indexing") + + TotalHistogram.With("driver", DriverID, "duration", "seconds").Observe(timeTotal.Seconds()) + BitmapHistogram.With("driver", DriverID, "duration", "seconds").Observe(float64(timePilosa)) + MappingHistogram.With("driver", DriverID, "duration", "seconds").Observe(float64(timeMapping)) + RowsGauge.With("driver", DriverID).Set(float64(rows)) +} + +func init() { + syswrap.SetMaxMapCount(0) +} diff --git a/sql/index/pilosa/driver_test.go b/sql/index/pilosa/driver_test.go new file mode 100644 index 000000000..25bae3d7c --- /dev/null +++ b/sql/index/pilosa/driver_test.go @@ -0,0 +1,1498 @@ +// +build !windows + +package pilosa + +import ( + "context" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "os" + "testing" + "time" + + "github.com/pilosa/pilosa" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/test" + "github.com/stretchr/testify/require" +) + +var tmpDir string + +func setup(t *testing.T) { + var err error + + tmpDir, err = ioutil.TempDir("", "pilosa") + if err != nil { + t.Fatal(err) + } +} + +func cleanup(t *testing.T) { + err := os.RemoveAll(tmpDir) + if err != nil { + t.Fatal(err) + } +} + +func TestID(t *testing.T) { + d := &Driver{} + + require := require.New(t) + require.Equal(DriverID, d.ID()) +} + +func TestLoadAll(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + d := NewDriver(tmpDir) + idx1, err := d.Create("db", "table", "id1", makeExpressions("table", 
"hash1"), nil) + require.NoError(err) + it1 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: idx1.Expressions(), + location: randLocation, + } + require.NoError(d.Save(sql.NewEmptyContext(), idx1, it1)) + + idx2, err := d.Create("db", "table", "id2", makeExpressions("table", "hash1"), nil) + require.NoError(err) + it2 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: idx2.Expressions(), + location: randLocation, + } + require.NoError(d.Save(sql.NewEmptyContext(), idx2, it2)) + + indexes, err := d.LoadAll("db", "table") + require.NoError(err) + + require.Equal(2, len(indexes)) + i1, ok := idx1.(*pilosaIndex) + require.True(ok) + i2, ok := idx2.(*pilosaIndex) + require.True(ok) + + require.Equal(i1.index.Name(), i2.index.Name()) + + // Load index from another table. Previously this panicked as the same + // pilosa.Holder was used for all indexes. + + idx3, err := d.Create("db", "table2", "id1", makeExpressions("table2", "hash1"), nil) + require.NoError(err) + it3 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: idx3.Expressions(), + location: randLocation, + } + require.NoError(d.Save(sql.NewEmptyContext(), idx3, it3)) + + _, err = d.LoadAll("db", "table2") + require.NoError(err) +} + +func TestLoadAllWithMultipleDrivers(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + d1 := NewDriver(tmpDir) + idx1, err := d1.Create("db", "table", "id1", makeExpressions("table", "hash1"), nil) + require.NoError(err) + it1 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: idx1.Expressions(), + location: randLocation, + } + require.NoError(d1.Save(sql.NewEmptyContext(), idx1, it1)) + + d2 := NewDriver(tmpDir) + idx2, err := d2.Create("db", "table", "id2", makeExpressions("table", "hash1"), nil) + require.NoError(err) + it2 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: 
idx2.Expressions(), + location: randLocation, + } + require.NoError(d2.Save(sql.NewEmptyContext(), idx2, it2)) + + d := NewDriver(tmpDir) + indexes, err := d.LoadAll("db", "table") + require.NoError(err) + + require.Equal(2, len(indexes)) + i1, ok := idx1.(*pilosaIndex) + require.True(ok) + i2, ok := idx2.(*pilosaIndex) + require.True(ok) + + require.Equal(i1.index.Name(), i2.index.Name()) + + // Load index from another table. Previously this panicked as the same + // pilosa.Holder was used for all indexes. + + d3 := NewDriver(tmpDir) + idx3, err := d3.Create("db", "table2", "id1", makeExpressions("table2", "hash1"), nil) + require.NoError(err) + it3 := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: idx3.Expressions(), + location: randLocation, + } + require.NoError(d3.Save(sql.NewEmptyContext(), idx3, it3)) + + _, err = d.LoadAll("db", "table2") + require.NoError(err) +} + +func TestSaveAndLoad(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions(table, "lang", "hash") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 2, + expressions: sqlIdx.Expressions(), + location: offsetLocation, + } + + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.Background(), sql.WithTracer(tracer)) + err = d.Save(ctx, sqlIdx, it) + require.NoError(err) + + indexes, err := d.LoadAll(db, table) + require.NoError(err) + require.Equal(1, len(indexes)) + + var locations = make([][]string, len(it.records)) + + for partition, records := range it.records { + for _, r := range records { + var lookup sql.IndexLookup + lookup, err = sqlIdx.Get(r.values...) 
+ require.NoError(err) + + var lit sql.IndexValueIter + lit, err = lookup.Values(testPartition(partition)) + require.NoError(err) + + for { + var loc []byte + loc, err = lit.Next() + if err == io.EOF { + break + } + require.NoError(err) + + locations[partition] = append(locations[partition], string(loc)) + } + err = lit.Close() + require.NoError(err) + } + } + + expectedLocations := [][]string{ + {"0-0", "0-1"}, + {"1-0", "1-1"}, + } + + require.ElementsMatch(expectedLocations, locations) + + // test that not found values do not cause error + lookup, err := sqlIdx.Get("do not exist", "none") + require.NoError(err) + lit, err := lookup.Values(testPartition(0)) + require.NoError(err) + _, err = lit.Next() + require.Equal(io.EOF, err) + + found := false + for _, span := range tracer.Spans { + if span == "pilosa.Save.bitBatch" { + found = true + break + } + } + + require.True(found) +} + +func TestSaveAndGetAll(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions(table, "lang", "hash") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: sqlIdx.Expressions(), + location: randLocation, + } + + err = d.Save(sql.NewEmptyContext(), sqlIdx, it) + require.NoError(err) + + indexes, err := d.LoadAll(db, table) + require.NoError(err) + require.Equal(1, len(indexes)) + + _, err = sqlIdx.Get() + require.Error(err) + require.True(errInvalidKeys.Is(err)) +} + +func TestSaveAndGetAllWithMultipleDrivers(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions(table, "lang", "hash") + + d1 := NewDriver(tmpDir) + sqlIdx, err := d1.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &partitionKeyValueIter{ 
+ partitions: 2, + offset: 0, + total: 64, + expressions: sqlIdx.Expressions(), + location: randLocation, + } + + err = d1.Save(sql.NewEmptyContext(), sqlIdx, it) + require.NoError(err) + + d2 := NewDriver(tmpDir) + indexes, err := d2.LoadAll(db, table) + require.NoError(err) + require.Equal(1, len(indexes)) + + _, err = sqlIdx.Get() + require.Error(err) + require.True(errInvalidKeys.Is(err)) +} + +func TestLoadCorruptedIndex(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + d := NewDriver(tmpDir) + processingFile := d.processingFilePath("db", "table", "id") + + _, err := d.Create("db", "table", "id", nil, nil) + require.NoError(err) + + _, err = d.loadIndex("db", "table", "id") + require.Error(err) + require.True(errCorruptedIndex.Is(err)) + + _, err = os.Stat(processingFile) + require.Error(err) + require.True(os.IsNotExist(err)) +} + +func TestDelete(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + + expressions := []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, table, "lang", true), + expression.NewGetFieldWithTable(1, sql.Int64, table, "field", true), + } + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + err = d.Delete(sqlIdx, new(partitionIter)) + require.NoError(err) +} + +func TestDeleteWithMultipleDrivers(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + + expressions := []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, table, "lang", true), + expression.NewGetFieldWithTable(1, sql.Int64, table, "field", true), + } + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + d = NewDriver(tmpDir) + err = d.Delete(sqlIdx, new(partitionIter)) + require.NoError(err) +} + +func TestDeleteAndLoadAll(t *testing.T) { + 
require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions(table, "lang", "hash") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 64, + expressions: sqlIdx.Expressions(), + location: randLocation, + } + + err = d.Save(sql.NewEmptyContext(), sqlIdx, it) + require.NoError(err) + + err = d.Delete(sqlIdx, new(partitionIter)) + require.NoError(err) + + indexes, err := d.LoadAll(db, table) + require.NoError(err) + require.Equal(0, len(indexes)) +} + +func TestDeleteInProgress(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table, id := "db_name", "table_name", "index_id" + + expressions := []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, table, "lang", true), + expression.NewGetFieldWithTable(1, sql.Int64, table, "hash", true), + } + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 1024, + expressions: sqlIdx.Expressions(), + location: slowRandLocation, + } + + go func() { + if e := d.Save(sql.NewEmptyContext(), sqlIdx, it); e != nil { + t.Log(e) + } + }() + + time.Sleep(time.Second) + err = d.Delete(sqlIdx, new(partitionIter)) + require.NoError(err) +} + +func TestLoadAllDirectoryDoesNotExist(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + driver := NewDriver(tmpDir) + indexes, err := driver.LoadAll("foo", "bar") + require.NoError(err) + require.Len(indexes, 0) +} + +func TestAscendDescendIndex(t *testing.T) { + idx, cleanup := setupAscendDescend(t) + defer cleanup() + + must := func(lookup sql.IndexLookup, err error) sql.IndexLookup { + require.NoError(t, err) + return lookup + } + + testCases := []struct { + name string + lookup 
sql.IndexLookup + expected []string + }{ + { + "ascend range", + must(idx.AscendRange( + []interface{}{int64(1), int64(1)}, + []interface{}{int64(7), int64(10)}, + )), + []string{"1", "5", "6", "7", "8", "9"}, + }, + { + "ascend greater or equal", + must(idx.AscendGreaterOrEqual(int64(7), int64(6))), + []string{"2", "4"}, + }, + { + "ascend less than", + must(idx.AscendLessThan(int64(5), int64(3))), + []string{"1", "10"}, + }, + { + "descend range", + must(idx.DescendRange( + []interface{}{int64(6), int64(9)}, + []interface{}{int64(0), int64(0)}, + )), + []string{"9", "8", "7", "6", "5", "1"}, + }, + { + "descend less or equal", + must(idx.DescendLessOrEqual(int64(4), int64(2))), + []string{"10", "1"}, + }, + { + "descend greater", + must(idx.DescendGreater(int64(6), int64(5))), + []string{"4", "2"}, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + iter, err := tt.lookup.Values(testPartition(0)) + require.NoError(err) + + var result []string + for { + k, err := iter.Next() + if err == io.EOF { + break + } + require.NoError(err) + + result = append(result, string(k)) + } + + require.Equal(tt.expected, result) + }) + } +} + +func TestIntersection(t *testing.T) { + ctx := sql.NewContext(context.Background()) + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idxLang, expLang := "idx_lang", makeExpressions(table, "lang") + idxPath, expPath := "idx_path", makeExpressions(table, "path") + + d := NewDriver(tmpDir) + sqlIdxLang, err := d.Create(db, table, idxLang, expLang, nil) + require.NoError(err) + + sqlIdxPath, err := d.Create(db, table, idxPath, expPath, nil) + require.NoError(err) + + itLang := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxLang.Expressions(), + location: offsetLocation, + } + + itPath := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxPath.Expressions(), + 
location: offsetLocation, + } + + err = d.Save(ctx, sqlIdxLang, itLang) + require.NoError(err) + + err = d.Save(ctx, sqlIdxPath, itPath) + require.NoError(err) + + lookupLang, err := sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + + lookupPath, err := sqlIdxPath.Get(itPath.records[0][itPath.total-1].values...) + require.NoError(err) + + m, ok := lookupLang.(sql.Mergeable) + require.True(ok) + require.True(m.IsMergeable(lookupPath)) + + interLookup, ok := lookupLang.(sql.SetOperations) + require.True(ok) + interIt, err := interLookup.Intersection(lookupPath).Values(testPartition(0)) + require.NoError(err) + _, err = interIt.Next() + + require.True(err == io.EOF) + require.NoError(interIt.Close()) + + lookupLang, err = sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + lookupPath, err = sqlIdxPath.Get(itPath.records[0][0].values...) + require.NoError(err) + + interLookup, ok = lookupPath.(sql.SetOperations) + require.True(ok) + interIt, err = interLookup.Intersection(lookupLang).Values(testPartition(0)) + require.NoError(err) + loc, err := interIt.Next() + require.NoError(err) + require.Equal(loc, itPath.records[0][0].location) + _, err = interIt.Next() + require.True(err == io.EOF) + + require.NoError(interIt.Close()) +} + +func TestIntersectionWithMultipleDrivers(t *testing.T) { + ctx := sql.NewContext(context.Background()) + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idxLang, expLang := "idx_lang", makeExpressions(table, "lang") + idxPath, expPath := "idx_path", makeExpressions(table, "path") + + d1 := NewDriver(tmpDir) + sqlIdxLang, err := d1.Create(db, table, idxLang, expLang, nil) + require.NoError(err) + + d2 := NewDriver(tmpDir) + sqlIdxPath, err := d2.Create(db, table, idxPath, expPath, nil) + require.NoError(err) + + itLang := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxLang.Expressions(), + location: 
offsetLocation, + } + + itPath := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxPath.Expressions(), + location: offsetLocation, + } + + err = d1.Save(ctx, sqlIdxLang, itLang) + require.NoError(err) + + err = d2.Save(ctx, sqlIdxPath, itPath) + require.NoError(err) + + lookupLang, err := sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + lookupPath, err := sqlIdxPath.Get(itPath.records[0][itPath.total-1].values...) + require.NoError(err) + + m, ok := lookupLang.(sql.Mergeable) + require.True(ok) + require.True(m.IsMergeable(lookupPath)) + + interLookup, ok := lookupLang.(sql.SetOperations) + require.True(ok) + interIt, err := interLookup.Intersection(lookupPath).Values(testPartition(0)) + require.NoError(err) + _, err = interIt.Next() + + require.True(err == io.EOF) + require.NoError(interIt.Close()) + + lookupLang, err = sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + lookupPath, err = sqlIdxPath.Get(itPath.records[0][0].values...) 
+ require.NoError(err) + + interLookup, ok = lookupPath.(sql.SetOperations) + require.True(ok) + interIt, err = interLookup.Intersection(lookupLang).Values(testPartition(0)) + require.NoError(err) + loc, err := interIt.Next() + require.NoError(err) + require.Equal(loc, itPath.records[0][0].location) + _, err = interIt.Next() + require.True(err == io.EOF) + + require.NoError(interIt.Close()) +} + +func TestUnion(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idxLang, expLang := "idx_lang", makeExpressions(table, "lang") + idxPath, expPath := "idx_path", makeExpressions(table, "path") + + d := NewDriver(tmpDir) + sqlIdxLang, err := d.Create(db, table, idxLang, expLang, nil) + require.NoError(err) + + sqlIdxPath, err := d.Create(db, table, idxPath, expPath, nil) + require.NoError(err) + + itLang := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxLang.Expressions(), + location: offsetLocation, + } + + itPath := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxPath.Expressions(), + location: offsetLocation, + } + + ctx := sql.NewContext(context.Background()) + + err = d.Save(ctx, sqlIdxLang, itLang) + require.NoError(err) + + err = d.Save(ctx, sqlIdxPath, itPath) + require.NoError(err) + + lookupLang, err := sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + litLang, err := lookupLang.Values(testPartition(0)) + require.NoError(err) + + loc, err := litLang.Next() + require.NoError(err) + require.Equal(itLang.records[0][0].location, loc) + _, err = litLang.Next() + require.True(err == io.EOF) + err = litLang.Close() + require.NoError(err) + + lookupPath, err := sqlIdxPath.Get(itPath.records[0][itPath.total-1].values...) 
+ require.NoError(err) + litPath, err := lookupPath.Values(testPartition(0)) + require.NoError(err) + + loc, err = litPath.Next() + require.NoError(err) + require.Equal(itPath.records[0][itPath.total-1].location, loc) + _, err = litPath.Next() + require.True(err == io.EOF) + err = litPath.Close() + require.NoError(err) + + m, ok := lookupLang.(sql.Mergeable) + require.True(ok) + require.True(m.IsMergeable(lookupPath)) + + unionLookup, ok := lookupLang.(sql.SetOperations) + require.True(ok) + + lookupNonExisting, err := sqlIdxPath.Get(itPath.total) + require.NoError(err) + + unionLookup, ok = unionLookup.Union(lookupNonExisting).(sql.SetOperations) + require.True(ok) + + unionIt, err := unionLookup.Union(lookupPath).Values(testPartition(0)) + require.NoError(err) + // 0 + loc, err = unionIt.Next() + require.NoError(err) + require.Equal(itLang.records[0][0].location, loc) + + // total-1 + loc, err = unionIt.Next() + require.NoError(err) + require.Equal(itPath.records[0][itPath.total-1].location, loc) + + _, err = unionIt.Next() + require.True(err == io.EOF) + + require.NoError(unionIt.Close()) +} + +func TestDifference(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idxLang, expLang := "idx_lang", makeExpressions(table, "lang") + idxPath, expPath := "idx_path", makeExpressions(table, "path") + + d := NewDriver(tmpDir) + sqlIdxLang, err := d.Create(db, table, idxLang, expLang, nil) + require.NoError(err) + + sqlIdxPath, err := d.Create(db, table, idxPath, expPath, nil) + require.NoError(err) + + itLang := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxLang.Expressions(), + location: offsetLocation, + } + + itPath := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdxPath.Expressions(), + location: offsetLocation, + } + + ctx := sql.NewContext(context.Background()) + + err = d.Save(ctx, sqlIdxLang, itLang) + 
require.NoError(err) + + err = d.Save(ctx, sqlIdxPath, itPath) + require.NoError(err) + + lookupLang, err := sqlIdxLang.Get(itLang.records[0][0].values...) + require.NoError(err) + + lookupPath, err := sqlIdxPath.Get(itPath.records[0][itPath.total-1].values...) + require.NoError(err) + + m, ok := lookupLang.(sql.Mergeable) + require.True(ok) + require.True(m.IsMergeable(lookupPath)) + + unionOp, ok := lookupLang.(sql.SetOperations) + require.True(ok) + unionLookup, ok := unionOp.Union(lookupPath).(sql.SetOperations) + require.True(ok) + + diffLookup := unionLookup.Difference(lookupLang) + diffIt, err := diffLookup.Values(testPartition(0)) + require.NoError(err) + + // total-1 + loc, err := diffIt.Next() + require.NoError(err) + require.Equal(itPath.records[0][itPath.total-1].location, loc) + + _, err = diffIt.Next() + require.True(err == io.EOF) + + require.NoError(diffIt.Close()) +} + +func TestUnionDiffAsc(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idx, exp := "idx_lang", makeExpressions(table, "lang") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, idx, exp, nil) + require.NoError(err) + pilosaIdx, ok := sqlIdx.(*pilosaIndex) + require.True(ok) + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdx.Expressions(), + location: offsetLocation, + } + + ctx := sql.NewContext(context.Background()) + + err = d.Save(ctx, pilosaIdx, it) + require.NoError(err) + + sqlLookup, err := pilosaIdx.AscendLessThan(it.records[0][it.total-1].values...) + require.NoError(err) + ascLookup, ok := sqlLookup.(*ascendLookup) + require.True(ok) + + ls := make([][]*indexLookup, it.partitions) + for partition, records := range it.records { + ls[partition] = make([]*indexLookup, it.total) + for i, r := range records { + var l sql.IndexLookup + l, err = pilosaIdx.Get(r.values...) 
+ require.NoError(err) + ls[partition][i], _ = l.(*indexLookup) + } + } + + unionLookup := ls[0][0].Union(ls[0][2], ls[0][4], ls[0][6], ls[0][8]) + + diffLookup := ascLookup.Difference(unionLookup) + diffIt, err := diffLookup.Values(testPartition(0)) + require.NoError(err) + + for i := 1; i < it.total-1; i += 2 { + var loc []byte + loc, err = diffIt.Next() + require.NoError(err) + + require.Equal(it.records[0][i].location, loc) + } + + _, err = diffIt.Next() + require.True(err == io.EOF) + require.NoError(diffIt.Close()) +} + +func TestInterRanges(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + idx, exp := "idx_lang", makeExpressions(table, "lang") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, idx, exp, nil) + require.NoError(err) + pilosaIdx, ok := sqlIdx.(*pilosaIndex) + require.True(ok) + it := &partitionKeyValueIter{ + partitions: 2, + offset: 0, + total: 10, + expressions: sqlIdx.Expressions(), + location: offsetLocation, + } + + ctx := sql.NewContext(context.Background()) + + err = d.Save(ctx, pilosaIdx, it) + require.NoError(err) + + ranges := [2]int{3, 9} + sqlLookup, err := pilosaIdx.AscendLessThan(it.records[0][ranges[1]].values...) + require.NoError(err) + lessLookup, ok := sqlLookup.(*ascendLookup) + require.True(ok) + + sqlLookup, err = pilosaIdx.AscendGreaterOrEqual(it.records[0][ranges[0]].values...) 
+ require.NoError(err) + greaterLookup, ok := sqlLookup.(*ascendLookup) + require.True(ok) + + interLookup := lessLookup.Intersection(greaterLookup) + require.NotNil(interLookup) + interIt, err := interLookup.Values(testPartition(0)) + require.NoError(err) + + for i := ranges[0]; i < ranges[1]; i++ { + var loc []byte + loc, err = interIt.Next() + require.NoError(err) + require.Equal(it.records[0][i].location, loc) + } + + _, err = interIt.Next() + require.True(err == io.EOF) + require.NoError(interIt.Close()) +} + +func TestNegateIndex(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + db, table := "db_name", "table_name" + + d := NewDriver(tmpDir) + idx, err := d.Create(db, table, "index_id", makeExpressions(table, "a"), nil) + require.NoError(err) + + multiIdx, err := d.Create( + db, table, "multi_index_id", + makeExpressions(table, "a", "b"), + nil, + ) + require.NoError(err) + + it := &fixturePartitionKeyValueIter{ + fixtures: []partitionKeyValueFixture{ + { + testPartition(0), + []kvfixture{ + {"1", []interface{}{int64(2)}}, + {"2", []interface{}{int64(7)}}, + {"3", []interface{}{int64(1)}}, + {"4", []interface{}{int64(1)}}, + {"5", []interface{}{int64(7)}}, + }, + }, + { + testPartition(1), + []kvfixture{ + {"1", []interface{}{int64(2)}}, + {"2", []interface{}{int64(7)}}, + }, + }, + }, + } + + err = d.Save(sql.NewEmptyContext(), idx, it) + require.NoError(err) + + fixtures := []kvfixture{ + {"1", []interface{}{int64(2), int64(6)}}, + {"2", []interface{}{int64(7), int64(5)}}, + {"3", []interface{}{int64(1), int64(2)}}, + {"4", []interface{}{int64(1), int64(3)}}, + {"5", []interface{}{int64(7), int64(6)}}, + {"6", []interface{}{int64(10), int64(6)}}, + {"7", []interface{}{int64(5), int64(1)}}, + {"8", []interface{}{int64(6), int64(2)}}, + {"9", []interface{}{int64(4), int64(0)}}, + {"10", []interface{}{int64(3), int64(5)}}, + } + + multiIt := &fixturePartitionKeyValueIter{ + fixtures: []partitionKeyValueFixture{ + {testPartition(0), 
fixtures}, + {testPartition(1), fixtures[4:]}, + }, + } + + err = d.Save(sql.NewEmptyContext(), multiIdx, multiIt) + require.NoError(err) + + lookup, err := idx.(sql.NegateIndex).Not(int64(1)) + require.NoError(err) + + values, err := lookupValues(lookup) + require.NoError(err) + + expected := []string{"1", "2", "5"} + require.Equal(expected, values) + + // test non existing values + lookup, err = idx.(sql.NegateIndex).Not(int64(12739487)) + require.NoError(err) + + values, err = lookupValues(lookup) + require.NoError(err) + + expected = []string{"1", "2", "3", "4", "5"} + require.Equal(expected, values) + + lookup, err = multiIdx.(sql.NegateIndex).Not(int64(1), int64(6)) + require.NoError(err) + + values, err = lookupValues(lookup) + require.NoError(err) + + expected = []string{"2", "7", "8", "9", "10"} + require.Equal(expected, values) +} + +func TestEqualAndLessIndex(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + ctx := sql.NewContext(context.Background()) + db, table := "db_name", "table_name" + d := NewDriver(tmpDir) + + idxEqA, err := d.Create(db, table, "idx_eq_a", makeExpressions(table, "a"), nil) + require.NoError(err) + pilosaIdxEqA, ok := idxEqA.(*pilosaIndex) + require.True(ok) + itA := &fixturePartitionKeyValueIter{ + fixtures: []partitionKeyValueFixture{ + { + testPartition(0), + []kvfixture{ + {"1", []interface{}{int64(2)}}, + {"2", []interface{}{int64(7)}}, + {"3", []interface{}{int64(1)}}, + {"4", []interface{}{int64(1)}}, + {"5", []interface{}{int64(1)}}, + {"6", []interface{}{int64(10)}}, + {"7", []interface{}{int64(5)}}, + {"8", []interface{}{int64(6)}}, + {"9", []interface{}{int64(4)}}, + {"10", []interface{}{int64(1)}}, + }, + }, + }, + } + err = d.Save(ctx, pilosaIdxEqA, itA) + require.NoError(err) + eqALookup, err := pilosaIdxEqA.Get(int64(1)) + require.NoError(err) + + values, err := lookupValues(eqALookup) + require.NoError(err) + expected := []string{"3", "4", "5", "10"} + require.Equal(expected, values) + + 
idxLessB, err := d.Create(db, table, "idx_less_b", makeExpressions(table, "b"), nil) + require.NoError(err) + pilosaIdxLessB, ok := idxLessB.(*pilosaIndex) + require.True(ok) + itB := &fixturePartitionKeyValueIter{ + fixtures: []partitionKeyValueFixture{ + { + testPartition(0), + []kvfixture{ + {"1", []interface{}{int64(1)}}, + {"2", []interface{}{int64(2)}}, + {"3", []interface{}{int64(3)}}, + {"4", []interface{}{int64(4)}}, + {"5", []interface{}{int64(5)}}, + {"6", []interface{}{int64(6)}}, + {"7", []interface{}{int64(7)}}, + {"8", []interface{}{int64(8)}}, + {"9", []interface{}{int64(9)}}, + {"10", []interface{}{int64(10)}}, + }, + }, + }, + } + err = d.Save(ctx, pilosaIdxLessB, itB) + require.NoError(err) + lessB, err := pilosaIdxLessB.AscendLessThan(int64(5)) + require.NoError(err) + lessBLookup := lessB.(*ascendLookup) + + values, err = lookupValues(lessBLookup) + require.NoError(err) + expected = []string{"1", "2", "3", "4"} + require.Equal(expected, values) + + interLookup := eqALookup.(sql.SetOperations).Intersection(lessBLookup) + values, err = lookupValues(interLookup) + require.NoError(err) + expected = []string{"3", "4"} + require.Equal(expected, values) +} +func TestPilosaHolder(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + h := pilosa.NewHolder() + h.Path = tmpDir + err := h.Open() + require.NoError(err) + + idx1, err := h.CreateIndexIfNotExists("idx", pilosa.IndexOptions{}) + require.NoError(err) + err = idx1.Open() + require.NoError(err) + + f1, err := idx1.CreateFieldIfNotExists("f1", pilosa.OptFieldTypeDefault()) + require.NoError(err) + + _, err = f1.SetBit(0, 0, nil) + require.NoError(err) + _, err = f1.SetBit(0, 2, nil) + require.NoError(err) + r0, err := f1.Row(0) + require.NoError(err) + + _, err = f1.SetBit(1, 0, nil) + require.NoError(err) + _, err = f1.SetBit(1, 1, nil) + require.NoError(err) + r1, err := f1.Row(1) + require.NoError(err) + + _, err = f1.SetBit(2, 2, nil) + require.NoError(err) + _, err = 
f1.SetBit(2, 3, nil) + require.NoError(err) + r2, err := f1.Row(2) + require.NoError(err) + + row := r0.Intersect(r1).Union(r2) + cols := row.Columns() + require.Equal(3, len(cols)) + require.Equal(uint64(0), cols[0]) + require.Equal(uint64(2), cols[1]) + require.Equal(uint64(3), cols[2]) + + f2, err := idx1.CreateFieldIfNotExists("f2", pilosa.OptFieldTypeDefault()) + require.NoError(err) + + rowIDs := []uint64{0, 0, 1, 1} + colIDs := []uint64{1, 2, 0, 3} + err = f2.Import(rowIDs, colIDs, nil) + require.NoError(err) + + r0, err = f2.Row(0) + require.NoError(err) + + r1, err = f2.Row(1) + require.NoError(err) + + row = r0.Union(r1) + cols = row.Columns() + require.Equal(4, len(cols)) + require.Equal(uint64(0), cols[0]) + require.Equal(uint64(1), cols[1]) + require.Equal(uint64(2), cols[2]) + require.Equal(uint64(3), cols[3]) + + r1, err = f1.Row(1) + require.NoError(err) + r0, err = f2.Row(0) + require.NoError(err) + + row = r1.Intersect(r0) + cols = row.Columns() + require.Equal(1, len(cols)) + require.Equal(uint64(1), cols[0]) + + err = idx1.Close() + require.NoError(err) + // ------------------------------------------------------------------------- + + idx2, err := h.CreateIndexIfNotExists("idx", pilosa.IndexOptions{}) + require.NoError(err) + err = idx2.Open() + require.NoError(err) + + f1 = idx2.Field("f1") + + _, err = f1.Row(2) + require.NoError(err) + + f2 = idx2.Field("f2") + + r0, err = f2.Row(0) + require.NoError(err) + + r1, err = f2.Row(1) + require.NoError(err) + + row = r0.Union(r1) + cols = row.Columns() + require.Equal(4, len(cols)) + require.Equal(uint64(0), cols[0]) + require.Equal(uint64(1), cols[1]) + require.Equal(uint64(2), cols[2]) + require.Equal(uint64(3), cols[3]) + + err = idx2.Close() + require.NoError(err) + + err = h.Close() + require.NoError(err) +} + +func makeExpressions(table string, names ...string) []sql.Expression { + var expressions []sql.Expression + + for i, n := range names { + expressions = append(expressions, + 
expression.NewGetFieldWithTable(i, sql.Int64, table, n, true)) + } + + return expressions +} + +func randLocation(partition sql.Partition, offset int) string { + b := make([]byte, 1) + rand.Read(b) + return string(partition.Key()) + "-" + string(b) +} + +func slowRandLocation(partition sql.Partition, offset int) string { + defer time.Sleep(200 * time.Millisecond) + + return randLocation(partition, offset) +} + +func offsetLocation(partition sql.Partition, offset int) string { + return string(partition.Key()) + "-" + fmt.Sprint(offset) +} + +type testRecord struct { + values []interface{} + location []byte +} + +// test implementation of sql.IndexKeyValueIter interface +type testIndexKeyValueIter struct { + offset int + total int + expressions []string + location func(sql.Partition, int) string + partition sql.Partition + + records *[]testRecord +} + +func (it *testIndexKeyValueIter) Next() ([]interface{}, []byte, error) { + if it.offset >= it.total { + return nil, nil, io.EOF + } + + loc := it.location(it.partition, it.offset) + + values := make([]interface{}, len(it.expressions)) + for i, e := range it.expressions { + values[i] = e + "-" + loc + "-" + string(it.partition.Key()) + } + + (*it.records)[it.offset] = testRecord{ + values, + []byte(loc), + } + it.offset++ + + return values, []byte(loc), nil +} + +func (it *testIndexKeyValueIter) Close() error { + it.offset = 0 + it.records = nil + return nil +} + +func setupAscendDescend(t *testing.T) (*pilosaIndex, func()) { + t.Helper() + require := require.New(t) + setup(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions(table, "a", "b") + + d := NewDriver(tmpDir) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + fixtures := []kvfixture{ + {"9", []interface{}{int64(2), int64(6)}}, + {"3", []interface{}{int64(7), int64(5)}}, + {"1", []interface{}{int64(1), int64(2)}}, + {"7", []interface{}{int64(1), int64(3)}}, + {"4", 
[]interface{}{int64(7), int64(6)}}, + {"2", []interface{}{int64(10), int64(6)}}, + {"5", []interface{}{int64(5), int64(1)}}, + {"6", []interface{}{int64(6), int64(2)}}, + {"10", []interface{}{int64(4), int64(0)}}, + {"8", []interface{}{int64(3), int64(5)}}, + } + + it := &fixturePartitionKeyValueIter{ + fixtures: []partitionKeyValueFixture{ + {testPartition(0), fixtures}, + {testPartition(1), fixtures[4:]}, + }, + } + + err = d.Save(sql.NewEmptyContext(), sqlIdx, it) + require.NoError(err) + + return sqlIdx.(*pilosaIndex), func() { + cleanup(t) + } +} + +func lookupValues(lookup sql.IndexLookup) ([]string, error) { + iter, err := lookup.Values(testPartition(0)) + if err != nil { + return nil, err + } + + var result []string + for { + k, err := iter.Next() + if err == io.EOF { + break + } + + if err != nil { + return nil, err + } + + result = append(result, string(k)) + } + + return result, nil +} + +type partitionKeyValueFixture struct { + partition sql.Partition + kv []kvfixture +} + +type fixturePartitionKeyValueIter struct { + fixtures []partitionKeyValueFixture + pos int +} + +func (i *fixturePartitionKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { + if i.pos >= len(i.fixtures) { + return nil, nil, io.EOF + } + + f := i.fixtures[i.pos] + i.pos++ + return f.partition, &fixtureKeyValueIter{ + fixtures: f.kv, + }, nil +} + +func (i *fixturePartitionKeyValueIter) Close() error { + i.pos = len(i.fixtures) + return nil +} + +type kvfixture struct { + key string + values []interface{} +} + +type fixtureKeyValueIter struct { + fixtures []kvfixture + pos int +} + +func (i *fixtureKeyValueIter) Next() ([]interface{}, []byte, error) { + if i.pos >= len(i.fixtures) { + return nil, nil, io.EOF + } + + f := i.fixtures[i.pos] + i.pos++ + return f.values, []byte(f.key), nil +} + +func (i *fixtureKeyValueIter) Close() error { return nil } + +type partitionKeyValueIter struct { + partitions int + + offset int + total int + expressions []string + location 
func(sql.Partition, int) string + + pos int + records [][]testRecord +} + +func (i *partitionKeyValueIter) init() { + i.records = make([][]testRecord, i.partitions) + for j := 0; j < i.partitions; j++ { + i.records[j] = make([]testRecord, i.total) + } +} + +func (i *partitionKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { + if i.pos >= i.partitions { + return nil, nil, io.EOF + } + + if i.pos == 0 { + i.init() + } + + i.pos++ + return testPartition(i.pos - 1), &testIndexKeyValueIter{ + offset: i.offset, + total: i.total, + expressions: i.expressions, + location: i.location, + partition: testPartition(i.pos - 1), + records: &i.records[i.pos-1], + }, nil +} + +func (i *partitionKeyValueIter) Close() error { + i.pos = i.partitions + return nil +} + +type testPartition int + +func (p testPartition) Key() []byte { + return []byte(fmt.Sprint(p)) +} + +type partitionIter struct { + partitions int + pos int +} + +func (i *partitionIter) Next() (sql.Partition, error) { + if i.pos >= i.partitions { + return nil, io.EOF + } + + i.pos++ + return testPartition(i.pos), nil +} + +func (i *partitionIter) Close() error { + i.pos = i.partitions + return nil +} diff --git a/sql/index/pilosa/index.go b/sql/index/pilosa/index.go new file mode 100644 index 000000000..d29312b86 --- /dev/null +++ b/sql/index/pilosa/index.go @@ -0,0 +1,388 @@ +// +build !windows + +package pilosa + +import ( + "context" + "sync" + + "github.com/pilosa/pilosa" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/index" + errors "gopkg.in/src-d/go-errors.v1" +) + +// concurrentPilosaIndex is a wrapper of pilosa.Index that can be opened and closed +// concurrently. 
+type concurrentPilosaIndex struct { + *pilosa.Index + m sync.Mutex + rc int +} + +func newConcurrentPilosaIndex(idx *pilosa.Index) *concurrentPilosaIndex { + return &concurrentPilosaIndex{Index: idx} +} + +func (i *concurrentPilosaIndex) Open() error { + i.m.Lock() + defer i.m.Unlock() + + if i.rc == 0 { + if err := i.Index.Open(); err != nil { + return err + } + } + + i.rc++ + return nil +} + +func (i *concurrentPilosaIndex) Close() error { + i.m.Lock() + defer i.m.Unlock() + + i.rc-- + if i.rc < 0 { + i.rc = 0 + } + + if i.rc == 0 { + return i.Index.Close() + } + + return nil +} + +var ( + errInvalidKeys = errors.NewKind("expecting %d keys for index %q, got %d") +) + +// pilosaIndex is an pilosa implementation of sql.Index interface +type pilosaIndex struct { + index *concurrentPilosaIndex + mapping map[string]*mapping + cancel context.CancelFunc + wg sync.WaitGroup + + db string + table string + id string + expressions []string + checksum string +} + +func newPilosaIndex(idx *pilosa.Index, cfg *index.Config) *pilosaIndex { + var checksum string + for _, c := range cfg.Drivers { + if ch, ok := c[sql.ChecksumKey]; ok { + checksum = ch + } + break + } + + return &pilosaIndex{ + index: newConcurrentPilosaIndex(idx), + db: cfg.DB, + table: cfg.Table, + id: cfg.ID, + expressions: cfg.Expressions, + mapping: make(map[string]*mapping), + checksum: checksum, + } +} + +func (idx *pilosaIndex) Checksum() (string, error) { + return idx.checksum, nil +} + +// Get returns an IndexLookup for the given key in the index. +// If key parameter is not present then the returned iterator +// will go through all the locations on the index. 
+func (idx *pilosaIndex) Get(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &indexLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: idx.expressions, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, nil +} + +// Has checks if the given key is present in the index mapping +func (idx *pilosaIndex) Has(p sql.Partition, key ...interface{}) (bool, error) { + mk := mappingKey(p) + m, ok := idx.mapping[mk] + if !ok { + return false, errMappingNotFound.New(mk) + } + + if err := m.open(); err != nil { + return false, err + } + defer m.close() + + for i, expr := range idx.expressions { + name := fieldName(idx.ID(), expr, p) + + val, err := m.get(name, key[i]) + if err != nil || val == nil { + return false, err + } + } + + return true, nil +} + +// Database returns the database name this index belongs to. +func (idx *pilosaIndex) Database() string { + return idx.db +} + +// Table returns the table name this index belongs to. +func (idx *pilosaIndex) Table() string { + return idx.table +} + +// ID returns the identifier of the index. +func (idx *pilosaIndex) ID() string { + return idx.id +} + +// Expressions returns the indexed expressions. If the result is more than +// one expression, it means the index has multiple columns indexed. If it's +// just one, it means it may be an expression or a column. 
+func (idx *pilosaIndex) Expressions() []string { + return idx.expressions +} + +func (*pilosaIndex) Driver() string { return DriverID } + +func (idx *pilosaIndex) AscendGreaterOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return newAscendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: idx.expressions, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, keys, nil), nil +} + +func (idx *pilosaIndex) AscendLessThan(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return newAscendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: idx.expressions, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, nil, keys), nil +} + +func (idx *pilosaIndex) AscendRange(greaterOrEqual, lessThan []interface{}) (sql.IndexLookup, error) { + if len(greaterOrEqual) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(greaterOrEqual)) + } + + if len(lessThan) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(lessThan)) + } + + return newAscendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + expressions: idx.expressions, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, greaterOrEqual, lessThan), nil +} + +func (idx *pilosaIndex) DescendGreater(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return newDescendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: 
idx.expressions, + reverse: true, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, keys, nil), nil +} + +func (idx *pilosaIndex) DescendLessOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return newDescendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: idx.expressions, + reverse: true, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, nil, keys), nil +} + +func (idx *pilosaIndex) DescendRange(lessOrEqual, greaterThan []interface{}) (sql.IndexLookup, error) { + if len(lessOrEqual) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(lessOrEqual)) + } + + if len(greaterThan) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(greaterThan)) + } + + return newDescendLookup(&filteredLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + expressions: idx.expressions, + reverse: true, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, greaterThan, lessOrEqual), nil +} + +func (idx *pilosaIndex) Not(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &negateLookup{ + id: idx.ID(), + index: idx.index, + mapping: idx.mapping, + keys: keys, + expressions: idx.expressions, + indexes: map[string]struct{}{ + idx.ID(): struct{}{}, + }, + }, nil +} + +func newAscendLookup(f *filteredLookup, gte []interface{}, lt []interface{}) *ascendLookup { + l := &ascendLookup{filteredLookup: f, gte: gte, lt: lt} + if l.filter == nil { + l.filter = func(i int, value []byte) (bool, error) { + var v interface{} + var err error + if len(l.gte) > 0 { + v, err = decodeGob(value, l.gte[i]) + if err != nil { + return false, err + } + + 
var cmp int + cmp, err = compare(v, l.gte[i]) + if err != nil { + return false, err + } + + if cmp < 0 { + return false, nil + } + } + + if len(l.lt) > 0 { + if v == nil { + v, err = decodeGob(value, l.lt[i]) + if err != nil { + return false, err + } + } + + cmp, err := compare(v, l.lt[i]) + if err != nil { + return false, err + } + + if cmp >= 0 { + return false, nil + } + } + + return true, nil + } + } + return l +} + +func newDescendLookup(f *filteredLookup, gt []interface{}, lte []interface{}) *descendLookup { + l := &descendLookup{filteredLookup: f, gt: gt, lte: lte} + if l.filter == nil { + l.filter = func(i int, value []byte) (bool, error) { + var v interface{} + var err error + if len(l.gt) > 0 { + v, err = decodeGob(value, l.gt[i]) + if err != nil { + return false, err + } + + var cmp int + cmp, err = compare(v, l.gt[i]) + if err != nil { + return false, err + } + + if cmp <= 0 { + return false, nil + } + } + + if len(l.lte) > 0 { + if v == nil { + v, err = decodeGob(value, l.lte[i]) + if err != nil { + return false, err + } + } + + cmp, err := compare(v, l.lte[i]) + if err != nil { + return false, err + } + + if cmp > 0 { + return false, nil + } + } + + return true, nil + } + } + return l +} diff --git a/sql/index/pilosa/iterator.go b/sql/index/pilosa/iterator.go new file mode 100644 index 000000000..d5b8eed41 --- /dev/null +++ b/sql/index/pilosa/iterator.go @@ -0,0 +1,100 @@ +// +build !windows + +package pilosa + +import ( + "io" + + "github.com/sirupsen/logrus" + bolt "go.etcd.io/bbolt" +) + +type locationValueIter struct { + locations [][]byte + pos int +} + +func (i *locationValueIter) Next() ([]byte, error) { + if i.pos >= len(i.locations) { + return nil, io.EOF + } + + i.pos++ + return i.locations[i.pos-1], nil +} + +func (i *locationValueIter) Close() error { + i.locations = nil + return nil +} + +type indexValueIter struct { + offset uint64 + total uint64 + bits []uint64 + mapping *mapping + indexName string + + // share transaction and bucket on 
all getLocation calls + bucket *bolt.Bucket + tx *bolt.Tx + closed bool +} + +func (it *indexValueIter) Next() ([]byte, error) { + if it.bucket == nil { + if err := it.mapping.open(); err != nil { + return nil, err + } + + bucket, err := it.mapping.getBucket(it.indexName, false) + if err != nil { + _ = it.Close() + return nil, err + } + + it.bucket = bucket + it.tx = bucket.Tx() + } + + if it.offset >= it.total { + if err := it.Close(); err != nil { + logrus.WithField("err", err.Error()). + Error("unable to close the pilosa index value iterator") + } + + if it.tx != nil { + _ = it.tx.Rollback() + } + + return nil, io.EOF + } + + var colID uint64 + if it.bits == nil { + colID = it.offset + } else { + colID = it.bits[it.offset] + } + + it.offset++ + + return it.mapping.getLocationFromBucket(it.bucket, colID) +} + +func (it *indexValueIter) Close() error { + if it.closed { + return nil + } + + it.closed = true + if it.tx != nil { + _ = it.tx.Rollback() + } + + if it.bucket != nil { + return it.mapping.close() + } + + return nil +} diff --git a/sql/index/pilosa/lookup.go b/sql/index/pilosa/lookup.go new file mode 100644 index 000000000..29173921b --- /dev/null +++ b/sql/index/pilosa/lookup.go @@ -0,0 +1,902 @@ +// +build !windows + +package pilosa + +import ( + "bytes" + "encoding/gob" + "io" + "sort" + "strings" + "time" + + "github.com/pilosa/pilosa" + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var ( + errUnknownType = errors.NewKind("unknown type %T received as value") + errTypeMismatch = errors.NewKind("cannot compare type %T with type %T") + errUnmergeableType = errors.NewKind("unmergeable type %T") + errMappingNotFound = errors.NewKind("mapping not found for partition: %s") + + // operation functors + // r1 AND r2 + intersect = func(r1, r2 *pilosa.Row) *pilosa.Row { + if r1 == nil { + return r2 + } + if r2 == nil { + return nil + } + return r1.Intersect(r2) + } + // r1 OR r2 + union = func(r1, r2 *pilosa.Row) *pilosa.Row { 
+ if r1 == nil { + return r2 + } + if r2 == nil { + return r1 + } + + return r1.Union(r2) + } + // r1 AND NOT r2 + difference = func(r1, r2 *pilosa.Row) *pilosa.Row { + if r1 == nil { + return r2 + } + if r2 == nil { + return r1 + } + + return r1.Difference(r2) + } +) + +type ( + + // indexLookup implement following interfaces: + // sql.IndexLookup, sql.Mergeable, sql.SetOperations + indexLookup struct { + id string + index *concurrentPilosaIndex + mapping map[string]*mapping + keys []interface{} + expressions []string + operations []*lookupOperation + indexes map[string]struct{} + } + + lookupOperation struct { + lookup sql.IndexLookup + operation func(*pilosa.Row, *pilosa.Row) *pilosa.Row + } + + pilosaLookup interface { + indexName() string + values(sql.Partition) (*pilosa.Row, error) + } +) + +func (l *indexLookup) indexName() string { + return l.index.Name() +} + +func (l *indexLookup) intersectExpressions(p sql.Partition, m *mapping) (*pilosa.Row, error) { + var row *pilosa.Row + + for i, expr := range l.expressions { + field := l.index.Field(fieldName(l.id, expr, p)) + rowID, err := m.rowID(field.Name(), l.keys[i]) + if err == io.EOF { + continue + } + if err != nil { + return nil, err + } + + r, err := field.Row(rowID) + if err != nil { + return nil, err + } + + row = intersect(row, r) + } + return row, nil +} + +func (l *indexLookup) values(p sql.Partition) (*pilosa.Row, error) { + mk := mappingKey(p) + m, ok := l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + + if err := m.open(); err != nil { + return nil, err + } + defer m.close() + + if err := l.index.Open(); err != nil { + return nil, err + } + row, err := l.intersectExpressions(p, m) + if e := l.index.Close(); e != nil { + if err == nil { + err = e + } + } + if err != nil { + return nil, err + } + + // evaluate composition of operations + for _, op := range l.operations { + var ( + r *pilosa.Row + e error + ) + + il, ok := op.lookup.(pilosaLookup) + if !ok { + return nil, 
errUnmergeableType.New(op.lookup) + } + + r, e = il.values(p) + if e != nil { + return nil, e + } + + row = op.operation(row, r) + } + + return row, nil +} + +// Values implements sql.IndexLookup.Values +func (l *indexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + mk := mappingKey(p) + m, ok := l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + + row, err := l.values(p) + if err != nil { + return nil, err + } + + if row == nil { + return &indexValueIter{ + mapping: m, + indexName: l.index.Name(), + }, nil + } + + bits := row.Columns() + return &indexValueIter{ + total: uint64(len(bits)), + bits: bits, + mapping: m, + indexName: l.index.Name(), + }, nil +} + +func (l *indexLookup) Indexes() []string { + return sortedIndexes(l.indexes) +} + +// IsMergeable implements sql.Mergeable interface. +func (l *indexLookup) IsMergeable(lookup sql.IndexLookup) bool { + if il, ok := lookup.(pilosaLookup); ok { + return il.indexName() == l.indexName() + } + + return false +} + +// Intersection implements sql.SetOperations interface +func (l *indexLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, intersect}) + } + + return &lookup +} + +// Union implements sql.SetOperations interface +func (l *indexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, union}) + } + + return &lookup +} + +// Difference implements sql.SetOperations interface +func (l *indexLookup) Difference(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + 
lookup.operations = append(lookup.operations, &lookupOperation{li, difference}) + } + + return &lookup +} + +type filteredLookup struct { + id string + index *concurrentPilosaIndex + mapping map[string]*mapping + keys []interface{} + expressions []string + operations []*lookupOperation + indexes map[string]struct{} + + reverse bool + filter func(int, []byte) (bool, error) +} + +func (l *filteredLookup) indexName() string { + return l.index.Name() +} + +// evaluate Intersection of bitmaps +func (l *filteredLookup) intersectExpressions(p sql.Partition, m *mapping) (*pilosa.Row, error) { + var row *pilosa.Row + + for i, expr := range l.expressions { + field := l.index.Field(fieldName(l.id, expr, p)) + rows, err := m.filter(field.Name(), func(b []byte) (bool, error) { + return l.filter(i, b) + }) + if err != nil { + return nil, err + } + + var r *pilosa.Row + for _, ri := range rows { + rr, err := field.Row(ri) + if err != nil { + return nil, err + } + r = union(r, rr) + } + + row = intersect(row, r) + } + + return row, nil +} + +func (l *filteredLookup) values(p sql.Partition) (*pilosa.Row, error) { + mk := mappingKey(p) + m, ok := l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + + if err := m.open(); err != nil { + return nil, err + } + defer m.close() + + if err := l.index.Open(); err != nil { + return nil, err + } + row, err := l.intersectExpressions(p, m) + if e := l.index.Close(); e != nil { + if err == nil { + err = e + } + } + if err != nil { + return nil, err + } + + // evaluate composition of operations + for _, op := range l.operations { + var ( + r *pilosa.Row + e error + ) + + il, ok := op.lookup.(pilosaLookup) + if !ok { + return nil, errUnmergeableType.New(op.lookup) + } + + r, e = il.values(p) + if e != nil { + return nil, e + } + if r == nil { + continue + } + + row = op.operation(row, r) + } + + return row, nil +} + +func (l *filteredLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + mk := mappingKey(p) + m, ok := 
l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + row, err := l.values(p) + if err != nil { + return nil, err + } + + if row == nil { + return &indexValueIter{ + mapping: m, + indexName: l.index.Name(), + }, nil + } + + bits := row.Columns() + if err = m.open(); err != nil { + return nil, err + } + + defer m.close() + locations, err := m.sortedLocations(l.index.Name(), bits, l.reverse) + if err != nil { + return nil, err + } + + return &locationValueIter{locations: locations}, nil +} + +func (l *filteredLookup) Indexes() []string { + return sortedIndexes(l.indexes) +} + +// IsMergeable implements sql.Mergeable interface. +func (l *filteredLookup) IsMergeable(lookup sql.IndexLookup) bool { + if il, ok := lookup.(pilosaLookup); ok { + return il.indexName() == l.indexName() + } + return false +} + +// Intersection implements sql.SetOperations interface +func (l *filteredLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, intersect}) + } + + return &lookup +} + +// Union implements sql.SetOperations interface +func (l *filteredLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, union}) + } + + return &lookup +} + +// Difference implements sql.SetOperations interface +func (l *filteredLookup) Difference(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, difference}) + } + + return &lookup +} + +type ascendLookup struct { + *filteredLookup + gte []interface{} + lt 
[]interface{} +} + +type descendLookup struct { + *filteredLookup + gt []interface{} + lte []interface{} +} + +type negateLookup struct { + id string + index *concurrentPilosaIndex + mapping map[string]*mapping + keys []interface{} + expressions []string + indexes map[string]struct{} + operations []*lookupOperation +} + +func (l *negateLookup) indexName() string { return l.index.Name() } + +func (l *negateLookup) intersectExpressions(p sql.Partition, m *mapping) (*pilosa.Row, error) { + var row *pilosa.Row + for i, expr := range l.expressions { + field := l.index.Field(fieldName(l.id, expr, p)) + + maxRowID, err := m.getMaxRowID(field.Name()) + if err != nil { + return nil, err + } + + // Since Pilosa does not have a negation in PQL (see: + // https://github.com/pilosa/pilosa/issues/807), we have to get all the + // ones in all the rows and join them, and then make difference between + // them and the ones in the row of the given value. + var r *pilosa.Row + // rowIDs start with 1 + for ri := uint64(1); ri <= maxRowID; ri++ { + var rr *pilosa.Row + rr, err = field.Row(ri) + if err != nil { + return nil, err + } + r = union(r, rr) + } + + rowID, err := m.rowID(field.Name(), l.keys[i]) + if err != nil && err != io.EOF { + return nil, err + } + + rr, err := field.Row(rowID) + if err != nil { + return nil, err + } + r = difference(r, rr) + + row = intersect(row, r) + } + return row, nil +} + +func (l *negateLookup) values(p sql.Partition) (*pilosa.Row, error) { + mk := mappingKey(p) + m, ok := l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + + if err := m.open(); err != nil { + return nil, err + } + defer m.close() + + if err := l.index.Open(); err != nil { + return nil, err + } + row, err := l.intersectExpressions(p, m) + if e := l.index.Close(); e != nil { + if err == nil { + err = e + } + } + if err != nil { + return nil, err + } + + // evaluate composition of operations + for _, op := range l.operations { + var ( + r *pilosa.Row + e error + ) 
+ + il, ok := op.lookup.(pilosaLookup) + if !ok { + return nil, errUnmergeableType.New(op.lookup) + } + + r, e = il.values(p) + if e != nil { + return nil, e + } + + if r == nil { + continue + } + + row = op.operation(row, r) + } + + return row, nil +} + +// Values implements sql.IndexLookup.Values +func (l *negateLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + mk := mappingKey(p) + m, ok := l.mapping[mk] + if !ok { + return nil, errMappingNotFound.New(mk) + } + row, err := l.values(p) + if err != nil { + return nil, err + } + + if row == nil { + return &indexValueIter{ + mapping: m, + indexName: l.index.Name(), + }, nil + } + + bits := row.Columns() + return &indexValueIter{ + total: uint64(len(bits)), + bits: bits, + mapping: m, + indexName: l.index.Name(), + }, nil +} + +func (l *negateLookup) Indexes() []string { + return sortedIndexes(l.indexes) +} + +// IsMergeable implements sql.Mergeable interface. +func (l *negateLookup) IsMergeable(lookup sql.IndexLookup) bool { + if il, ok := lookup.(pilosaLookup); ok { + return il.indexName() == l.indexName() + } + + return false +} + +// Intersection implements sql.SetOperations interface +func (l *negateLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, intersect}) + } + + return &lookup +} + +// Union implements sql.SetOperations interface +func (l *negateLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, union}) + } + + return &lookup +} + +// Difference implements sql.SetOperations interface +func (l *negateLookup) Difference(lookups ...sql.IndexLookup) sql.IndexLookup { + lookup := *l + for _, li := 
range lookups { + for _, idx := range li.Indexes() { + lookup.indexes[idx] = struct{}{} + } + lookup.operations = append(lookup.operations, &lookupOperation{li, difference}) + } + + return &lookup +} + +func decodeGob(k []byte, value interface{}) (interface{}, error) { + decoder := gob.NewDecoder(bytes.NewBuffer(k)) + + switch value.(type) { + case string: + var v string + err := decoder.Decode(&v) + return v, err + case int8: + var v int8 + err := decoder.Decode(&v) + return v, err + case int16: + var v int16 + err := decoder.Decode(&v) + return v, err + case int32: + var v int32 + err := decoder.Decode(&v) + return v, err + case int64: + var v int64 + err := decoder.Decode(&v) + return v, err + case uint8: + var v uint8 + err := decoder.Decode(&v) + return v, err + case uint16: + var v uint16 + err := decoder.Decode(&v) + return v, err + case uint32: + var v uint32 + err := decoder.Decode(&v) + return v, err + case uint64: + var v uint64 + err := decoder.Decode(&v) + return v, err + case float64: + var v float64 + err := decoder.Decode(&v) + return v, err + case time.Time: + var v time.Time + err := decoder.Decode(&v) + return v, err + case []byte: + var v []byte + err := decoder.Decode(&v) + return v, err + case bool: + var v bool + err := decoder.Decode(&v) + return v, err + case []interface{}: + var v []interface{} + err := decoder.Decode(&v) + return v, err + default: + return nil, errUnknownType.New(value) + } +} + +// compare two values of the same underlying type. The values MUST be of the +// same type. 
+func compare(a, b interface{}) (int, error) { + switch a := a.(type) { + case bool: + v, ok := b.(bool) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a == false { + return -1, nil + } + + return 1, nil + case string: + v, ok := b.(string) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + return strings.Compare(a, v), nil + case int8: + v, ok := b.(int8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int16: + v, ok := b.(int16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int32: + v, ok := b.(int32) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int64: + v, ok := b.(int64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint8: + v, ok := b.(uint8) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint16: + v, ok := b.(uint16) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint32: + v, ok := b.(uint32) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint64: + v, ok := b.(uint64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case float64: + v, ok := b.(float64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case 
[]byte: + v, ok := b.([]byte) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + return bytes.Compare(a, v), nil + case []interface{}: + v, ok := b.([]interface{}) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if len(a) < len(v) { + return -1, nil + } + + if len(a) > len(v) { + return 1, nil + } + + for i := range a { + cmp, err := compare(a[i], v[i]) + if err != nil { + return 0, err + } + + if cmp != 0 { + return cmp, nil + } + } + + return 0, nil + case time.Time: + v, ok := b.(time.Time) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a.Equal(v) { + return 0, nil + } + + if a.Before(v) { + return -1, nil + } + + return 1, nil + default: + return 0, errUnknownType.New(a) + } +} + +func sortedIndexes(indexes map[string]struct{}) []string { + var result = make([]string, 0, len(indexes)) + for idx := range indexes { + result = append(result, idx) + } + + sort.Strings(result) + return result +} diff --git a/sql/index/pilosa/lookup_test.go b/sql/index/pilosa/lookup_test.go new file mode 100644 index 000000000..b93da3ca2 --- /dev/null +++ b/sql/index/pilosa/lookup_test.go @@ -0,0 +1,281 @@ +// +build !windows + +package pilosa + +import ( + "bytes" + "encoding/gob" + "fmt" + "os" + "testing" + "time" + + "github.com/pilosa/pilosa" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" + errors "gopkg.in/src-d/go-errors.v1" +) + +func TestCompare(t *testing.T) { + now := time.Now() + testCases := []struct { + a, b interface{} + err *errors.Kind + expected int + }{ + {true, true, nil, 0}, + {false, true, nil, -1}, + {true, false, nil, 1}, + {false, false, nil, 0}, + {true, 0, errTypeMismatch, 0}, + + {"a", "b", nil, -1}, + {"b", "a", nil, 1}, + {"a", "a", nil, 0}, + {"a", 1, errTypeMismatch, 0}, + + {int32(1), int32(2), nil, -1}, + {int32(2), int32(1), nil, 1}, + {int32(2), int32(2), nil, 0}, + {int32(1), "", errTypeMismatch, 0}, + + {int64(1), int64(2), nil, -1}, + {int64(2), int64(1), nil, 1}, + {int64(2), int64(2), 
nil, 0}, + {int64(1), "", errTypeMismatch, 0}, + + {uint32(1), uint32(2), nil, -1}, + {uint32(2), uint32(1), nil, 1}, + {uint32(2), uint32(2), nil, 0}, + {uint32(1), "", errTypeMismatch, 0}, + + {uint64(1), uint64(2), nil, -1}, + {uint64(2), uint64(1), nil, 1}, + {uint64(2), uint64(2), nil, 0}, + {uint64(1), "", errTypeMismatch, 0}, + + {float64(1), float64(2), nil, -1}, + {float64(2), float64(1), nil, 1}, + {float64(2), float64(2), nil, 0}, + {float64(1), "", errTypeMismatch, 0}, + + {now.Add(-1 * time.Hour), now, nil, -1}, + {now, now.Add(-1 * time.Hour), nil, 1}, + {now, now, nil, 0}, + {now, 1, errTypeMismatch, -1}, + + {[]interface{}{"a", "a"}, []interface{}{"a", "b"}, nil, -1}, + {[]interface{}{"a", "b"}, []interface{}{"a", "a"}, nil, 1}, + {[]interface{}{"a", "a"}, []interface{}{"a", "a"}, nil, 0}, + {[]interface{}{"b"}, []interface{}{"a", "b"}, nil, -1}, + {[]interface{}{"b"}, 1, errTypeMismatch, -1}, + + {[]byte{0, 1}, []byte{1, 1}, nil, -1}, + {[]byte{1, 1}, []byte{0, 1}, nil, 1}, + {[]byte{1, 1}, []byte{1, 1}, nil, 0}, + {[]byte{1}, []byte{0, 1}, nil, 1}, + {[]byte{0, 1}, 1, errTypeMismatch, -1}, + + {time.Duration(0), nil, errUnknownType, -1}, + } + + for _, tt := range testCases { + name := fmt.Sprintf("(%T)(%v) and (%T)(%v)", tt.a, tt.a, tt.b, tt.b) + t.Run(name, func(t *testing.T) { + require := require.New(t) + cmp, err := compare(tt.a, tt.b) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, cmp) + } + }) + } +} + +func TestDecodeGob(t *testing.T) { + testCases := []interface{}{ + "foo", + int8(1), + int16(1), + int32(1), + int64(1), + uint8(1), + uint16(1), + uint32(1), + uint64(1), + float64(1), + true, + time.Date(2018, time.August, 1, 1, 1, 1, 1, time.Local), + []byte("foo"), + []interface{}{1, 3, 3, 7}, + } + + for _, tt := range testCases { + name := fmt.Sprintf("(%T)(%v)", tt, tt) + t.Run(name, func(t *testing.T) { + require := require.New(t) + + var buf 
bytes.Buffer + require.NoError(gob.NewEncoder(&buf).Encode(tt)) + + result, err := decodeGob(buf.Bytes(), tt) + require.NoError(err) + require.Equal(tt, result) + }) + } +} + +func TestMergeable(t *testing.T) { + require := require.New(t) + h := pilosa.NewHolder() + h.Path = os.TempDir() + + i1, err := h.CreateIndexIfNotExists("i1", pilosa.IndexOptions{}) + require.NoError(err) + i2, err := h.CreateIndexIfNotExists("i2", pilosa.IndexOptions{}) + require.NoError(err) + + testCases := []struct { + i1 sql.IndexLookup + i2 sql.IndexLookup + expected bool + }{ + { + i1: &indexLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &indexLookup{index: newConcurrentPilosaIndex(i1)}, + expected: true, + }, + { + i1: &indexLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &indexLookup{index: newConcurrentPilosaIndex(i2)}, + expected: false, + }, + { + i1: &indexLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &ascendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + expected: true, + }, + { + i1: &descendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + i2: &ascendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + expected: true, + }, + { + i1: &descendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + i2: &indexLookup{index: newConcurrentPilosaIndex(i2)}, + expected: false, + }, + { + i1: &descendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + i2: &descendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i2)}}, + expected: false, + }, + { + i1: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + expected: true, + }, + { + i1: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &negateLookup{index: newConcurrentPilosaIndex(i2)}, + expected: false, + }, + { + i1: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &indexLookup{index: 
newConcurrentPilosaIndex(i1)}, + expected: true, + }, + { + i1: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &descendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + expected: true, + }, + { + i1: &negateLookup{index: newConcurrentPilosaIndex(i1)}, + i2: &ascendLookup{filteredLookup: &filteredLookup{index: newConcurrentPilosaIndex(i1)}}, + expected: true, + }, + } + + for _, tc := range testCases { + m1, ok := tc.i1.(sql.Mergeable) + require.True(ok) + + require.Equal(tc.expected, m1.IsMergeable(tc.i2)) + } +} + +func TestIndexes(t *testing.T) { + testCases := []sql.IndexLookup{ + &indexLookup{id: "foo", indexes: map[string]struct{}{"foo": struct{}{}}}, + &negateLookup{id: "foo", indexes: map[string]struct{}{"foo": struct{}{}}}, + &ascendLookup{ + filteredLookup: &filteredLookup{ + id: "foo", + indexes: map[string]struct{}{"foo": struct{}{}}, + }, + }, + &descendLookup{ + filteredLookup: &filteredLookup{ + id: "foo", + indexes: map[string]struct{}{"foo": struct{}{}}, + }, + }, + } + + for _, tt := range testCases { + t.Run(fmt.Sprintf("%T", tt), func(t *testing.T) { + require.Equal(t, []string{"foo"}, tt.Indexes()) + }) + } +} + +func TestLookupIndexes(t *testing.T) { + require := require.New(t) + + lookups := []sql.IndexLookup{ + &indexLookup{ + id: "1", + indexes: map[string]struct{}{"1": struct{}{}}, + }, + &negateLookup{ + id: "2", + indexes: map[string]struct{}{"2": struct{}{}}, + }, + &ascendLookup{filteredLookup: &filteredLookup{ + id: "3", + indexes: map[string]struct{}{"3": struct{}{}}, + }}, + &descendLookup{filteredLookup: &filteredLookup{ + id: "4", + indexes: map[string]struct{}{"4": struct{}{}}, + }}, + &filteredLookup{ + id: "5", + indexes: map[string]struct{}{"5": struct{}{}}, + }, + } + + expected := []string{"1", "2", "3", "4", "5"} + + // All possible permutations of operations between all the different kinds + // of lookups are tested. 
+ for i := 0; i < len(lookups); i++ { + var op sql.SetOperations + var others []sql.IndexLookup + for j := 0; j < len(lookups); j++ { + if i == j { + op = lookups[i].(sql.SetOperations) + } else { + others = append(others, lookups[j]) + } + } + + require.Equal(expected, op.Union(others...).Indexes()) + require.Equal(expected, op.Difference(others...).Indexes()) + require.Equal(expected, op.Intersection(others...).Indexes()) + } +} diff --git a/sql/index/pilosa/mapping.go b/sql/index/pilosa/mapping.go new file mode 100644 index 000000000..8e33d3f2f --- /dev/null +++ b/sql/index/pilosa/mapping.go @@ -0,0 +1,393 @@ +// +build !windows + +package pilosa + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "fmt" + "io" + "sort" + "sync" + + "github.com/src-d/go-mysql-server/sql" + bolt "go.etcd.io/bbolt" +) + +// mapping +// buckets: +// - index name: columndID uint64 -> location []byte +// - field name: value []byte (gob encoding) -> rowID uint64 +type mapping struct { + path string + + mut sync.RWMutex + db *bolt.DB + + // in create mode there's only one transaction closed explicitly by + // commit function + create bool + tx *bolt.Tx + + clientMut sync.Mutex + clients int +} + +func newMapping(path string) *mapping { + return &mapping{path: path} +} + +func (m *mapping) open() error { + return m.openCreate(false) +} + +// openCreate opens and sets creation mode in the database. 
+func (m *mapping) openCreate(create bool) error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + m.mut.Lock() + defer m.mut.Unlock() + + if m.clients == 0 && m.db == nil { + var err error + m.db, err = bolt.Open(m.path, 0640, nil) + if err != nil { + return err + } + } + + m.clients++ + m.create = create + return nil +} + +func (m *mapping) close() error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + m.mut.Lock() + defer m.mut.Unlock() + + if m.clients > 1 { + m.clients-- + return nil + } + + m.clients = 0 + + if m.db != nil { + if err := m.db.Close(); err != nil { + return err + } + m.db = nil + } + + return nil +} + +func (m *mapping) rowID(fieldName string, value interface{}) (uint64, error) { + val, err := m.get(fieldName, value) + if err != nil { + return 0, err + } + if val == nil { + return 0, io.EOF + } + + return binary.LittleEndian.Uint64(val), err +} + +// commit saves current transaction, if cont is true a new transaction will be +// created again in the next query. Only for create mode. 
+func (m *mapping) commit(cont bool) error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + + var err error + if m.create && m.tx != nil { + err = m.tx.Commit() + } + + m.create = cont + m.tx = nil + + return err +} + +func (m *mapping) rollback() error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + + var err error + if m.create && m.tx != nil { + err = m.tx.Rollback() + } + + m.create = false + m.tx = nil + + return err +} + +func (m *mapping) transaction(writable bool, f func(*bolt.Tx) error) error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + + var tx *bolt.Tx + var err error + if m.create { + if m.tx == nil { + m.tx, err = m.db.Begin(true) + if err != nil { + return err + } + } + + tx = m.tx + } else { + tx, err = m.db.Begin(writable) + if err != nil { + return err + } + } + + err = f(tx) + if m.create { + return err + } + + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + +func (m *mapping) getRowID(fieldName string, value interface{}) (uint64, error) { + var id uint64 + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(value) + if err != nil { + return 0, err + } + + err = m.transaction(true, func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte(fieldName)) + if err != nil { + return err + } + + key := buf.Bytes() + val := b.Get(key) + if val != nil { + id = binary.LittleEndian.Uint64(val) + return nil + } + + // the first NextSequence is 1 so the first id will be 1 + // this can only fail if the transaction is closed + id, _ = b.NextSequence() + + val = make([]byte, 8) + binary.LittleEndian.PutUint64(val, id) + err = b.Put(key, val) + return err + }) + + if err != nil { + return 0, err + } + + return id, err +} + +func (m *mapping) getMaxRowID(fieldName string) (uint64, error) { + var id uint64 + err := m.transaction(true, func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(fieldName)) + if b == nil { + return nil + } + + id = b.Sequence() + return nil + }) + + return id, 
err +} + +func (m *mapping) putLocation( + indexName string, + colID uint64, + location []byte, +) error { + return m.transaction(true, func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte(indexName)) + if err != nil { + return err + } + + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, colID) + + return b.Put(key, location) + }) +} + +func (m *mapping) sortedLocations( + indexName string, + cols []uint64, + reverse bool, +) ([][]byte, error) { + var result [][]byte + m.mut.RLock() + defer m.mut.RUnlock() + err := m.db.View(func(tx *bolt.Tx) error { + bucket := []byte(indexName) + b := tx.Bucket(bucket) + if b == nil { + return fmt.Errorf("bucket %s not found", bucket) + } + + for _, col := range cols { + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, col) + val := b.Get(key) + + // val will point to mmap addresses, so we need to copy the slice + dst := make([]byte, len(val)) + copy(dst, val) + result = append(result, dst) + } + + return nil + }) + + if err != nil { + return nil, err + } + + if reverse { + sort.Stable(sort.Reverse(byBytes(result))) + } else { + sort.Stable(byBytes(result)) + } + + return result, nil +} + +type byBytes [][]byte + +func (b byBytes) Len() int { return len(b) } +func (b byBytes) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b byBytes) Less(i, j int) bool { return bytes.Compare(b[i], b[j]) < 0 } + +func (m *mapping) getLocation( + indexName string, + colID uint64, +) ([]byte, error) { + var location []byte + + err := m.transaction(true, func(tx *bolt.Tx) error { + bucket := []byte(indexName) + b := tx.Bucket(bucket) + if b == nil { + return fmt.Errorf("bucket %s not found", bucket) + } + + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, colID) + + location = b.Get(key) + return nil + }) + + return location, err +} + +func (m *mapping) getLocationFromBucket( + bucket *bolt.Bucket, + colID uint64, +) ([]byte, error) { + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, 
colID) + return bucket.Get(key), nil +} + +func (m *mapping) getBucket( + indexName string, + writable bool, +) (*bolt.Bucket, error) { + var bucket *bolt.Bucket + + tx, err := m.db.Begin(writable) + if err != nil { + return nil, err + } + + bu := []byte(indexName) + bucket = tx.Bucket(bu) + if bucket == nil { + _ = tx.Rollback() + return nil, fmt.Errorf("bucket %s not found", bu) + } + + return bucket, err +} + +func (m *mapping) get(name string, key interface{}) ([]byte, error) { + var value []byte + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(key) + if err != nil { + return nil, err + } + + err = m.transaction(true, func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b != nil { + value = b.Get(buf.Bytes()) + return nil + } + + return fmt.Errorf("%s not found", name) + }) + + return value, err +} + +func (m *mapping) filter(name string, fn func([]byte) (bool, error)) ([]uint64, error) { + var result []uint64 + + m.mut.RLock() + defer m.mut.RUnlock() + err := m.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b == nil { + return nil + } + + return b.ForEach(func(k, v []byte) error { + ok, err := fn(k) + if err != nil { + return err + } + + if ok { + result = append(result, binary.LittleEndian.Uint64(v)) + } + + return nil + }) + }) + + return result, err +} + +func mappingKey(p sql.Partition) string { + return fmt.Sprintf("%x", p.Key()) +} diff --git a/sql/index/pilosa/mapping_test.go b/sql/index/pilosa/mapping_test.go new file mode 100644 index 000000000..e2a37c035 --- /dev/null +++ b/sql/index/pilosa/mapping_test.go @@ -0,0 +1,91 @@ +// +build !windows + +package pilosa + +import ( + "encoding/binary" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRowID(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + m := newMapping(filepath.Join(tmpDir, "id.map")) + require.NoError(m.open()) + defer m.close() + + cases := []int{0, 1, 2, 3, 4, 5, 5, 0, 
3, 2, 1, 5} + expected := []uint64{1, 2, 3, 4, 5, 6, 6, 1, 4, 3, 2, 6} + + for i, c := range cases { + rowID, err := m.getRowID("frame name", c) + require.NoError(err) + require.Equal(expected[i], rowID) + } + + maxRowID, err := m.getMaxRowID("frame name") + require.NoError(err) + require.Equal(uint64(6), maxRowID) +} + +func TestLocation(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + m := newMapping(filepath.Join(tmpDir, "id.map")) + require.NoError(m.open()) + defer m.close() + + cases := map[uint64]string{ + 0: "zero", + 1: "one", + 2: "two", + 3: "three", + 4: "four", + } + + for colID, loc := range cases { + err := m.putLocation("index name", colID, []byte(loc)) + require.NoError(err) + } + + for colID, loc := range cases { + b, err := m.getLocation("index name", colID) + require.NoError(err) + require.Equal(loc, string(b)) + } +} + +func TestGet(t *testing.T) { + require := require.New(t) + setup(t) + defer cleanup(t) + + m := newMapping(filepath.Join(tmpDir, "id.map")) + require.NoError(m.open()) + defer m.close() + + cases := []int{0, 1, 2, 3, 4, 5, 5, 0, 3, 2, 1, 5} + expected := []uint64{1, 2, 3, 4, 5, 6, 6, 1, 4, 3, 2, 6} + + for i, c := range cases { + m.getRowID("frame name", c) + + id, err := m.get("frame name", c) + val := binary.LittleEndian.Uint64(id) + + require.NoError(err) + require.Equal(expected[i], val) + } +} + +type mockPartition string + +func (m mockPartition) Key() []byte { + return []byte(m) +} diff --git a/sql/index_test.go b/sql/index_test.go index 687b8a9b4..1e631cba5 100644 --- a/sql/index_test.go +++ b/sql/index_test.go @@ -1,20 +1,93 @@ package sql import ( + "fmt" "testing" - "time" "github.com/stretchr/testify/require" ) +func TestIndexesByTable(t *testing.T) { + var require = require.New(t) + + var r = NewIndexRegistry() + r.indexOrder = []indexKey{ + {"foo", "bar_idx_1"}, + {"foo", "bar_idx_2"}, + {"foo", "bar_idx_3"}, + {"foo", "baz_idx_1"}, + {"oof", "rab_idx_1"}, + } + + r.indexes = 
map[indexKey]Index{ + indexKey{"foo", "bar_idx_1"}: &dummyIdx{ + database: "foo", + table: "bar", + id: "bar_idx_1", + expr: []Expression{dummyExpr{1, "2"}}, + }, + indexKey{"foo", "bar_idx_2"}: &dummyIdx{ + database: "foo", + table: "bar", + id: "bar_idx_2", + expr: []Expression{dummyExpr{2, "3"}}, + }, + indexKey{"foo", "bar_idx_3"}: &dummyIdx{ + database: "foo", + table: "bar", + id: "bar_idx_3", + expr: []Expression{dummyExpr{3, "4"}}, + }, + indexKey{"foo", "baz_idx_1"}: &dummyIdx{ + database: "foo", + table: "baz", + id: "baz_idx_1", + expr: []Expression{dummyExpr{4, "5"}}, + }, + indexKey{"oof", "rab_idx_1"}: &dummyIdx{ + database: "oof", + table: "rab", + id: "rab_idx_1", + expr: []Expression{dummyExpr{5, "6"}}, + }, + } + + r.statuses[indexKey{"foo", "bar_idx_1"}] = IndexReady + r.statuses[indexKey{"foo", "bar_idx_2"}] = IndexReady + r.statuses[indexKey{"foo", "bar_idx_3"}] = IndexNotReady + r.statuses[indexKey{"foo", "baz_idx_1"}] = IndexReady + r.statuses[indexKey{"oof", "rab_idx_1"}] = IndexReady + + indexes := r.IndexesByTable("foo", "bar") + require.Len(indexes, 3) + + for i, idx := range indexes { + expected := r.indexes[r.indexOrder[i]] + require.Equal(expected, idx) + r.ReleaseIndex(idx) + } +} + func TestIndexByExpression(t *testing.T) { require := require.New(t) r := NewIndexRegistry() - r.indexes[indexKey{"foo", ""}] = &dummyIdx{ - database: "foo", - expr: dummyExpr{1, "2"}, + r.indexOrder = []indexKey{ + {"foo", ""}, + {"foo", "bar"}, + } + r.indexes = map[indexKey]Index{ + indexKey{"foo", ""}: &dummyIdx{ + database: "foo", + expr: []Expression{dummyExpr{1, "2"}}, + }, + indexKey{"foo", "bar"}: &dummyIdx{ + database: "foo", + id: "bar", + expr: []Expression{dummyExpr{2, "3"}}, + }, } + r.statuses[indexKey{"foo", ""}] = IndexReady idx := r.IndexByExpression("bar", dummyExpr{1, "2"}) require.Nil(idx) @@ -22,7 +95,10 @@ func TestIndexByExpression(t *testing.T) { idx = r.IndexByExpression("foo", dummyExpr{1, "2"}) require.NotNil(idx) - idx = 
r.IndexByExpression("foo", dummyExpr{2, "2"}) + idx = r.IndexByExpression("foo", dummyExpr{2, "3"}) + require.Nil(idx) + + idx = r.IndexByExpression("foo", dummyExpr{3, "4"}) require.Nil(idx) } @@ -31,12 +107,12 @@ func TestAddIndex(t *testing.T) { r := NewIndexRegistry() idx := &dummyIdx{ id: "foo", - expr: new(dummyExpr), + expr: []Expression{new(dummyExpr)}, database: "foo", table: "foo", } - done, err := r.AddIndex(idx) + done, ready, err := r.AddIndex(idx) require.NoError(err) i := r.Index("foo", "foo") @@ -44,17 +120,17 @@ func TestAddIndex(t *testing.T) { done <- struct{}{} - <-time.After(25 * time.Millisecond) + <-ready i = r.Index("foo", "foo") require.True(r.CanUseIndex(i)) - _, err = r.AddIndex(idx) + _, _, err = r.AddIndex(idx) require.Error(err) require.True(ErrIndexIDAlreadyRegistered.Is(err)) - _, err = r.AddIndex(&dummyIdx{ + _, _, err = r.AddIndex(&dummyIdx{ id: "another", - expr: new(dummyExpr), + expr: []Expression{new(dummyExpr)}, database: "foo", table: "foo", }) @@ -67,15 +143,20 @@ func TestDeleteIndex(t *testing.T) { r := NewIndexRegistry() idx := &dummyIdx{"foo", nil, "foo", "foo"} + idx2 := &dummyIdx{"foo", nil, "foo", "bar"} r.indexes[indexKey{"foo", "foo"}] = idx + r.indexes[indexKey{"foo", "bar"}] = idx2 - _, err := r.DeleteIndex("foo", "foo") + _, err := r.DeleteIndex("foo", "foo", false) require.Error(err) require.True(ErrIndexDeleteInvalidStatus.Is(err)) - r.setStatus(idx, IndexReady) + _, err = r.DeleteIndex("foo", "foo", true) + require.NoError(err) + + r.setStatus(idx2, IndexReady) - _, err = r.DeleteIndex("foo", "foo") + _, err = r.DeleteIndex("foo", "foo", false) require.NoError(err) require.Len(r.indexes, 0) @@ -91,7 +172,7 @@ func TestDeleteIndex_InUse(t *testing.T) { r.setStatus(idx, IndexReady) r.retainIndex("foo", "foo") - done, err := r.DeleteIndex("foo", "foo") + done, err := r.DeleteIndex("foo", "foo", false) require.NoError(err) require.Len(r.indexes, 1) @@ -105,33 +186,246 @@ func TestDeleteIndex_InUse(t *testing.T) { 
require.Len(r.indexes, 0) } +func TestExpressionsWithIndexes(t *testing.T) { + require := require.New(t) + + r := NewIndexRegistry() + + var indexes = []*dummyIdx{ + { + "idx1", + []Expression{ + &dummyExpr{0, "foo"}, + &dummyExpr{1, "bar"}, + }, + "foo", + "foo", + }, + { + "idx2", + []Expression{ + &dummyExpr{0, "foo"}, + &dummyExpr{1, "bar"}, + &dummyExpr{3, "baz"}, + }, + "foo", + "foo", + }, + { + "idx3", + []Expression{ + &dummyExpr{0, "foo"}, + }, + "foo", + "foo", + }, + } + + for _, idx := range indexes { + done, ready, err := r.AddIndex(idx) + require.NoError(err) + close(done) + <-ready + } + + exprs := r.ExpressionsWithIndexes( + "foo", + &dummyExpr{0, "foo"}, + &dummyExpr{1, "bar"}, + &dummyExpr{3, "baz"}, + ) + + expected := [][]Expression{ + { + &dummyExpr{0, "foo"}, + &dummyExpr{1, "bar"}, + &dummyExpr{3, "baz"}, + }, + { + &dummyExpr{0, "foo"}, + &dummyExpr{1, "bar"}, + }, + } + + require.ElementsMatch(expected, exprs) +} + +func TestLoadIndexes(t *testing.T) { + require := require.New(t) + + d1 := &loadDriver{id: "d1", indexes: []Index{ + &dummyIdx{id: "idx1", database: "db1", table: "t1"}, + &dummyIdx{id: "idx2", database: "db2", table: "t3"}, + }} + + d2 := &loadDriver{id: "d2", indexes: []Index{ + &dummyIdx{id: "idx3", database: "db1", table: "t2"}, + &dummyIdx{id: "idx4", database: "db2", table: "t4"}, + }} + + registry := NewIndexRegistry() + registry.RegisterIndexDriver(d1) + registry.RegisterIndexDriver(d2) + + dbs := Databases{ + dummyDB{ + name: "db1", + tables: map[string]Table{ + "t1": &dummyTable{name: "t1"}, + "t2": &dummyTable{name: "t2"}, + }, + }, + dummyDB{ + name: "db2", + tables: map[string]Table{ + "t3": &dummyTable{name: "t3"}, + "t4": &dummyTable{name: "t4"}, + }, + }, + } + + require.NoError(registry.LoadIndexes(dbs)) + + expected := append(d1.indexes[:], d2.indexes...) 
+ var result []Index + for _, idx := range registry.indexes { + result = append(result, idx) + } + + require.ElementsMatch(expected, result) + + for _, idx := range expected { + require.Equal(registry.statuses[indexKey{idx.Database(), idx.ID()}], IndexReady) + } +} + +func TestLoadOutdatedIndexes(t *testing.T) { + require := require.New(t) + + d := &loadDriver{id: "d1", indexes: []Index{ + &checksumIndex{&dummyIdx{id: "idx1", database: "db1", table: "t1"}, "2"}, + &checksumIndex{&dummyIdx{id: "idx2", database: "db1", table: "t2"}, "2"}, + }} + + registry := NewIndexRegistry() + registry.RegisterIndexDriver(d) + + dbs := Databases{ + dummyDB{ + name: "db1", + tables: map[string]Table{ + "t1": &checksumTable{&dummyTable{name: "t1"}, "2"}, + "t2": &checksumTable{&dummyTable{name: "t2"}, "1"}, + }, + }, + } + + require.NoError(registry.LoadIndexes(dbs)) + + var result []Index + for _, idx := range registry.indexes { + result = append(result, idx) + } + + require.ElementsMatch(d.indexes, result) + + require.Equal(registry.statuses[indexKey{"db1", "idx1"}], IndexReady) + require.Equal(registry.statuses[indexKey{"db1", "idx2"}], IndexOutdated) +} + +type dummyDB struct { + name string + tables map[string]Table +} + +func (d dummyDB) Name() string { return d.name } +func (d dummyDB) Tables() map[string]Table { return d.tables } + +type dummyTable struct { + Table + name string +} + +func (t dummyTable) Name() string { return t.name } + +type loadDriver struct { + indexes []Index + id string +} + +func (d loadDriver) ID() string { return d.id } +func (loadDriver) Create(db, table, id string, expressions []Expression, config map[string]string) (Index, error) { + panic("create is a placeholder") +} +func (d loadDriver) LoadAll(db, table string) ([]Index, error) { + var result []Index + for _, i := range d.indexes { + if i.Table() == table && i.Database() == db { + result = append(result, i) + } + } + return result, nil +} +func (loadDriver) Save(ctx *Context, index Index, 
iter PartitionIndexKeyValueIter) error { return nil } +func (loadDriver) Delete(Index, PartitionIter) error { return nil } + type dummyIdx struct { id string - expr Expression + expr []Expression database string table string } var _ Index = (*dummyIdx)(nil) -func (i dummyIdx) Expressions() []Expression { return []Expression{i.expr} } -func (i dummyIdx) ID() string { return i.id } -func (i dummyIdx) Get(interface{}) (IndexLookup, error) { panic("not implemented") } -func (i dummyIdx) Has(interface{}) (bool, error) { panic("not implemented") } -func (i dummyIdx) Database() string { return i.database } -func (i dummyIdx) Table() string { return i.table } +func (i dummyIdx) Expressions() []string { + var exprs []string + for _, e := range i.expr { + exprs = append(exprs, e.String()) + } + return exprs +} +func (i dummyIdx) ID() string { return i.id } +func (i dummyIdx) Get(...interface{}) (IndexLookup, error) { panic("not implemented") } +func (i dummyIdx) Has(Partition, ...interface{}) (bool, error) { panic("not implemented") } +func (i dummyIdx) Database() string { return i.database } +func (i dummyIdx) Table() string { return i.table } +func (i dummyIdx) Driver() string { return "dummy" } type dummyExpr struct { - foo int - bar string + index int + colName string } var _ Expression = (*dummyExpr)(nil) -func (dummyExpr) Children() []Expression { return nil } -func (dummyExpr) Eval(*Context, Row) (interface{}, error) { panic("not implemented") } -func (dummyExpr) TransformUp(fn TransformExprFunc) (Expression, error) { panic("not implemented") } -func (dummyExpr) String() string { return "dummyExpr" } -func (dummyExpr) IsNullable() bool { return false } -func (dummyExpr) Resolved() bool { return false } -func (dummyExpr) Type() Type { panic("not implemented") } +func (dummyExpr) Children() []Expression { return nil } +func (dummyExpr) Eval(*Context, Row) (interface{}, error) { panic("not implemented") } +func (e dummyExpr) WithChildren(children ...Expression) 
(Expression, error) { + return e, nil +} +func (e dummyExpr) String() string { return fmt.Sprintf("dummyExpr{%d, %s}", e.index, e.colName) } +func (dummyExpr) IsNullable() bool { return false } +func (dummyExpr) Resolved() bool { return false } +func (dummyExpr) Type() Type { panic("not implemented") } +func (e dummyExpr) WithIndex(idx int) Expression { + return &dummyExpr{idx, e.colName} +} + +type checksumTable struct { + Table + checksum string +} + +func (t *checksumTable) Checksum() (string, error) { + return t.checksum, nil +} + +type checksumIndex struct { + Index + checksum string +} + +func (idx *checksumIndex) Checksum() (string, error) { + return idx.checksum, nil +} diff --git a/sql/information_schema.go b/sql/information_schema.go new file mode 100644 index 000000000..48147c7d5 --- /dev/null +++ b/sql/information_schema.go @@ -0,0 +1,380 @@ +package sql + +import ( + "bytes" + "fmt" + "io" + "strings" +) + +const ( + // InformationSchemaDatabaseName is the name of the information schema database. + InformationSchemaDatabaseName = "information_schema" + // FilesTableName is the name of the files table. + FilesTableName = "files" + // ColumnStatisticsTableName is the name of the column statistics table. + ColumnStatisticsTableName = "column_statistics" + // TablesTableName is the name of tables table. + TablesTableName = "tables" + // ColumnsTableName is the name of columns table. + ColumnsTableName = "columns" + // SchemataTableName is the name of the schemata table. 
+ SchemataTableName = "schemata" +) + +type informationSchemaDatabase struct { + name string + tables map[string]Table +} + +type informationSchemaTable struct { + name string + schema Schema + catalog *Catalog + rowIter func(*Catalog) RowIter +} + +type informationSchemaPartition struct { + key []byte +} + +type informationSchemaPartitionIter struct { + informationSchemaPartition + pos int +} + +var ( + _ Database = (*informationSchemaDatabase)(nil) + _ Table = (*informationSchemaTable)(nil) + _ Partition = (*informationSchemaPartition)(nil) + _ PartitionIter = (*informationSchemaPartitionIter)(nil) +) + +var filesSchema = Schema{ + {Name: "file_id", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "file_name", Type: Text, Source: FilesTableName, Nullable: true}, + {Name: "file_type", Type: Text, Source: FilesTableName, Nullable: true}, + {Name: "tablespace_name", Type: Text, Source: FilesTableName}, + {Name: "table_catalog", Type: Text, Source: FilesTableName}, + {Name: "table_schema", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "table_name", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "logfile_group_name", Type: Text, Source: FilesTableName, Nullable: true}, + {Name: "logfile_group_number", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "engine", Type: Text, Source: FilesTableName}, + {Name: "fulltext_keys", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "deleted_rows", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "update_count", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "free_extents", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "total_extents", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "extent_size", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "initial_size", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "maximum_size", Type: Int64, Source: FilesTableName, Nullable: 
true}, + {Name: "autoextend_size", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "creation_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "last_update_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "last_access_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "recover_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "transaction_counter", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "version", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "row_format", Type: Text, Source: FilesTableName, Nullable: true}, + {Name: "table_rows", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "avg_row_length", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "data_length", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "max_data_length", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "index_length", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "data_free", Type: Int64, Source: FilesTableName, Nullable: true}, + {Name: "create_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "update_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "check_time", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "checksum", Type: Blob, Source: FilesTableName, Nullable: true}, + {Name: "status", Type: Text, Source: FilesTableName, Nullable: true}, + {Name: "extra", Type: Blob, Source: FilesTableName, Nullable: true}, +} + +var columnStatisticsSchema = Schema{ + {Name: "schema_name", Type: Text, Source: ColumnStatisticsTableName}, + {Name: "table_name", Type: Text, Source: ColumnStatisticsTableName}, + {Name: "column_name", Type: Text, Source: ColumnStatisticsTableName}, + {Name: "histogram", Type: JSON, Source: ColumnStatisticsTableName}, +} + +var tablesSchema = Schema{ + {Name: "table_catalog", Type: Text, Default: "", Nullable: false, Source: 
TablesTableName}, + {Name: "table_schema", Type: Text, Default: "", Nullable: false, Source: TablesTableName}, + {Name: "table_name", Type: Text, Default: "", Nullable: false, Source: TablesTableName}, + {Name: "table_type", Type: Text, Default: "", Nullable: false, Source: TablesTableName}, + {Name: "engine", Type: Text, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "version", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "row_format", Type: Text, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "table_rows", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "avg_row_length", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "data_length", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "max_data_length", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "index_length", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "data_free", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "auto_increment", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "create_time", Type: Date, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "update_time", Type: Date, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "check_time", Type: Date, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "table_collation", Type: Text, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "checksum", Type: Uint64, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "create_options", Type: Text, Default: nil, Nullable: true, Source: TablesTableName}, + {Name: "table_comment", Type: Text, Default: "", Nullable: false, Source: TablesTableName}, +} + +var columnsSchema = Schema{ + {Name: "table_catalog", Type: Text, Default: "", Nullable: false, Source: 
ColumnsTableName}, + {Name: "table_schema", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "table_name", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "column_name", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "ordinal_position", Type: Uint64, Default: 0, Nullable: false, Source: ColumnsTableName}, + {Name: "column_default", Type: Text, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "is_nullable", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "data_type", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "character_maximum_length", Type: Uint64, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "character_octet_length", Type: Uint64, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "numeric_precision", Type: Uint64, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "numeric_scale", Type: Uint64, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "datetime_precision", Type: Uint64, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "character_set_name", Type: Text, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "collation_name", Type: Text, Default: nil, Nullable: true, Source: ColumnsTableName}, + {Name: "column_type", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "column_key", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "extra", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "privileges", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "column_comment", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, + {Name: "generation_expression", Type: Text, Default: "", Nullable: false, Source: ColumnsTableName}, +} + +var schemataSchema = Schema{ + {Name: 
"catalog_name", Type: Text, Default: nil, Nullable: false, Source: SchemataTableName}, + {Name: "schema_name", Type: Text, Default: nil, Nullable: false, Source: SchemataTableName}, + {Name: "default_character_set_name", Type: Text, Default: nil, Nullable: false, Source: SchemataTableName}, + {Name: "default_collation_name", Type: Text, Default: nil, Nullable: false, Source: SchemataTableName}, + {Name: "sql_path", Type: Text, Default: nil, Nullable: true, Source: SchemataTableName}, +} + +func tablesRowIter(cat *Catalog) RowIter { + var rows []Row + for _, db := range cat.AllDatabases() { + tableType := "BASE TABLE" + engine := "INNODB" + rowFormat := "Dynamic" + if db.Name() == InformationSchemaDatabaseName { + tableType = "SYSTEM VIEW" + engine = "MEMORY" + rowFormat = "Fixed" + } + for _, t := range db.Tables() { + rows = append(rows, Row{ + "def", //table_catalog + db.Name(), // table_schema + t.Name(), // table_name + tableType, // table_type + engine, // engine + 10, //version (protocol, always 10) + rowFormat, //row_format + nil, //table_rows + nil, //avg_row_length + nil, //data_length + nil, //max_data_length + nil, //index_length + nil, //data_free + nil, //auto_increment + nil, //create_time + nil, //update_time + nil, //check_time + "utf8_bin", //table_collation + nil, //checksum + nil, //create_options + "", //table_comment + }) + } + } + + return RowsToRowIter(rows...)
+} + +func columnsRowIter(cat *Catalog) RowIter { + var rows []Row + for _, db := range cat.AllDatabases() { + for _, t := range db.Tables() { + for i, c := range t.Schema() { + var ( + nullable string + charName interface{} + collName interface{} + ) + if c.Nullable { + nullable = "YES" + } else { + nullable = "NO" + } + if IsText(c.Type) { + charName = "utf8mb4" + collName = "utf8_bin" + } + rows = append(rows, Row{ + "def", // table_catalog + db.Name(), // table_schema + t.Name(), // table_name + c.Name, // column_name + uint64(i), // ordinal_position + c.Default, // column_default + nullable, // is_nullable + strings.ToLower(MySQLTypeName(c.Type)), // data_type + nil, // character_maximum_length + nil, // character_octet_length + nil, // numeric_precision + nil, // numeric_scale + nil, // datetime_precision + charName, // character_set_name + collName, // collation_name + strings.ToLower(MySQLTypeName(c.Type)), // column_type + "", // column_key + "", // extra + "select", // privileges + "", // column_comment + "", // generation_expression + }) + } + } + } + return RowsToRowIter(rows...) +} + +func schemataRowIter(c *Catalog) RowIter { + dbs := c.AllDatabases() + + var rows []Row + for _, db := range dbs { + if db.Name() == InformationSchemaDatabaseName { + continue + } + + rows = append(rows, Row{ + "def", + db.Name(), + "utf8mb4", + "utf8_bin", + nil, + }) + } + + return RowsToRowIter(rows...) +} + +// NewInformationSchemaDatabase creates a new INFORMATION_SCHEMA Database. 
+func NewInformationSchemaDatabase(cat *Catalog) Database { + return &informationSchemaDatabase{ + name: InformationSchemaDatabaseName, + tables: map[string]Table{ + FilesTableName: &informationSchemaTable{ + name: FilesTableName, + schema: filesSchema, + catalog: cat, + }, + ColumnStatisticsTableName: &informationSchemaTable{ + name: ColumnStatisticsTableName, + schema: columnStatisticsSchema, + catalog: cat, + }, + TablesTableName: &informationSchemaTable{ + name: TablesTableName, + schema: tablesSchema, + catalog: cat, + rowIter: tablesRowIter, + }, + ColumnsTableName: &informationSchemaTable{ + name: ColumnsTableName, + schema: columnsSchema, + catalog: cat, + rowIter: columnsRowIter, + }, + SchemataTableName: &informationSchemaTable{ + name: SchemataTableName, + schema: schemataSchema, + catalog: cat, + rowIter: schemataRowIter, + }, + }, + } +} + +// Name implements the sql.Database interface. +func (db *informationSchemaDatabase) Name() string { return db.name } + +// Tables implements the sql.Database interface. +func (db *informationSchemaDatabase) Tables() map[string]Table { return db.tables } + +// Name implements the sql.Table interface. +func (t *informationSchemaTable) Name() string { + return t.name +} + +// Schema implements the sql.Table interface. +func (t *informationSchemaTable) Schema() Schema { + return t.schema +} + +// Partitions implements the sql.Table interface. +func (t *informationSchemaTable) Partitions(ctx *Context) (PartitionIter, error) { + return &informationSchemaPartitionIter{informationSchemaPartition: informationSchemaPartition{partitionKey(t.Name())}}, nil +} + +// PartitionRows implements the sql.PartitionRows interface. 
+func (t *informationSchemaTable) PartitionRows(ctx *Context, partition Partition) (RowIter, error) { + if !bytes.Equal(partition.Key(), partitionKey(t.Name())) { + return nil, fmt.Errorf( + "partition not found: %q", partition.Key(), + ) + } + if t.rowIter == nil { + return RowsToRowIter(), nil + } + + return t.rowIter(t.catalog), nil +} + +// String implements the fmt.Stringer interface. +func (t *informationSchemaTable) String() string { + return printTable(t.Name(), t.Schema()) +} + +// Key implements Partition interface +func (p *informationSchemaPartition) Key() []byte { return p.key } + +// Next implements single PartitionIter interface +func (pit *informationSchemaPartitionIter) Next() (Partition, error) { + if pit.pos == 0 { + pit.pos++ + return pit, nil + } + return nil, io.EOF +} + +// Close implements single PartitionIter interface +func (pit *informationSchemaPartitionIter) Close() error { + pit.pos = 0 + return nil +} + +func printTable(name string, tableSchema Schema) string { + p := NewTreePrinter() + _ = p.WriteNode("Table(%s)", name) + var schema = make([]string, len(tableSchema)) + for i, col := range tableSchema { + schema[i] = fmt.Sprintf( + "Column(%s, %s, nullable=%v)", + col.Name, + col.Type.Type().String(), + col.Nullable, + ) + } + _ = p.WriteChildren(schema...) + return p.String() +} + +func partitionKey(tableName string) []byte { + return []byte(InformationSchemaDatabaseName + "." + tableName) +} diff --git a/sql/memory.go b/sql/memory.go new file mode 100644 index 000000000..35556ec99 --- /dev/null +++ b/sql/memory.go @@ -0,0 +1,194 @@ +package sql + +import ( + "os" + "runtime" + "strconv" + "sync" + + errors "gopkg.in/src-d/go-errors.v1" +) + +// Disposable objects can erase all their content when they're no longer in use. +// They should not be used again after they've been disposed. +type Disposable interface { + // Dispose the contents. + Dispose() +} + +// Freeable objects can free their memory.
+type Freeable interface { + // Free the memory. + Free() +} + +// KeyValueCache is a cache of key value pairs. +type KeyValueCache interface { + // Put a new value in the cache. + Put(uint64, interface{}) error + // Get the value with the given key. + Get(uint64) (interface{}, error) +} + +// RowsCache is a cache of rows. +type RowsCache interface { + // Add a new row to the cache. If there is no memory available, it will try to + // free some memory. If after that there is still no memory available, it + // will return an error and erase all the content of the cache. + Add(Row) error + // Get all rows. + Get() []Row +} + +// ErrNoMemoryAvailable is returned when there is no more available memory. +var ErrNoMemoryAvailable = errors.NewKind("no memory available") + +const maxMemoryKey = "MAX_MEMORY" + +const ( + b = 1 + kib = 1024 * b + mib = 1024 * kib +) + +var maxMemory = func() uint64 { + val := os.Getenv(maxMemoryKey) + var v uint64 + if val != "" { + var err error + v, err = strconv.ParseUint(val, 10, 64) + if err != nil { + panic("MAX_MEMORY environment variable must be a number, but got: " + val) + } + } + + return v * uint64(mib) +}() + +// Reporter is a component that gives information about the memory usage. +type Reporter interface { + // MaxMemory returns the maximum number of memory allowed in bytes. + MaxMemory() uint64 + // UsedMemory returns the memory in use in bytes. + UsedMemory() uint64 +} + +// ProcessMemory is a reporter for the memory used by the process and the +// maximum amount of memory allowed controlled by the MAX_MEMORY environment +// variable. 
+var ProcessMemory Reporter = new(processReporter) + +type processReporter struct{} + +func (processReporter) UsedMemory() uint64 { + var s runtime.MemStats + runtime.ReadMemStats(&s) + return s.HeapInuse + s.StackInuse +} + +func (processReporter) MaxMemory() uint64 { return maxMemory } + +// HasAvailableMemory reports whether more memory is available to the program if +// it hasn't reached the max memory limit. +func HasAvailableMemory(r Reporter) bool { + maxMemory := r.MaxMemory() + if maxMemory == 0 { + return true + } + + return r.UsedMemory() < maxMemory +} + +// MemoryManager is in charge of keeping track and managing all the components that operate +// in memory. There should only be one instance of a memory manager running at the +// same time in each process. +type MemoryManager struct { + mu sync.RWMutex + reporter Reporter + caches map[uint64]Disposable + token uint64 +} + +// NewMemoryManager creates a new manager with the given memory reporter. If nil is given, +// then the Process reporter will be used by default. +func NewMemoryManager(r Reporter) *MemoryManager { + if r == nil { + r = ProcessMemory + } + + return &MemoryManager{ + reporter: r, + caches: make(map[uint64]Disposable), + } +} + +// HasAvailable reports whether the memory manager has any available memory. +func (m *MemoryManager) HasAvailable() bool { + return HasAvailableMemory(m.reporter) +} + +// DisposeFunc is a function to completely erase a cache and remove it from the manager. +type DisposeFunc func() + +// NewLRUCache returns an empty LRU cache and a function to dispose it when it's +// no longer needed. +func (m *MemoryManager) NewLRUCache(size uint) (KeyValueCache, DisposeFunc) { + c := newLRUCache(m, m.reporter, size) + pos := m.addCache(c) + return c, func() { + c.Dispose() + m.removeCache(pos) + } +} + +// NewHistoryCache returns an empty history cache and a function to dispose it when it's +// no longer needed. 
+func (m *MemoryManager) NewHistoryCache() (KeyValueCache, DisposeFunc) { + c := newHistoryCache(m, m.reporter) + pos := m.addCache(c) + return c, func() { + c.Dispose() + m.removeCache(pos) + } +} + +// NewRowsCache returns an empty rows cache and a function to dispose it when it's +// no longer needed. +func (m *MemoryManager) NewRowsCache() (RowsCache, DisposeFunc) { + c := newRowsCache(m, m.reporter) + pos := m.addCache(c) + return c, func() { + c.Dispose() + m.removeCache(pos) + } +} + +func (m *MemoryManager) addCache(c Disposable) (pos uint64) { + m.mu.Lock() + defer m.mu.Unlock() + m.token++ + m.caches[m.token] = c + return m.token +} + +func (m *MemoryManager) removeCache(pos uint64) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.caches, pos) + + if len(m.caches) == 0 { + m.token = 0 + } +} + +// Free the memory of all freeable caches. +func (m *MemoryManager) Free() { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, c := range m.caches { + if f, ok := c.(Freeable); ok { + f.Free() + } + } +} diff --git a/sql/memory_test.go b/sql/memory_test.go new file mode 100644 index 000000000..c4dcf7d57 --- /dev/null +++ b/sql/memory_test.go @@ -0,0 +1,79 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestManager(t *testing.T) { + require := require.New(t) + m := NewMemoryManager(nil) + + kv, dispose := m.NewLRUCache(5) + _, ok := kv.(*lruCache) + require.True(ok) + require.Len(m.caches, 1) + dispose() + require.Len(m.caches, 0) + + kv, dispose = m.NewHistoryCache() + _, ok = kv.(*historyCache) + require.True(ok) + require.Len(m.caches, 1) + dispose() + require.Len(m.caches, 0) + + rc, dispose := m.NewRowsCache() + _, ok = rc.(*rowsCache) + require.True(ok) + require.Len(m.caches, 1) + dispose() + require.Len(m.caches, 0) + + m.addCache(disposableCache{}) + f := new(freeableCache) + m.addCache(f) + m.Free() + require.True(f.freed) +} + +type disposableCache struct{} + +func (d disposableCache) Dispose() {} + +type 
freeableCache struct { + disposableCache + freed bool +} + +func (f *freeableCache) Free() { f.freed = true } + +func TestHasAvailable(t *testing.T) { + require.True(t, HasAvailableMemory(fixedReporter(2, 5))) + require.False(t, HasAvailableMemory(fixedReporter(6, 5))) +} + +type mockReporter struct { + f func() uint64 + max uint64 +} + +func (m mockReporter) UsedMemory() uint64 { return m.f() } +func (m mockReporter) MaxMemory() uint64 { return m.max } + +func fixedReporter(v, max uint64) mockReporter { + return mockReporter{func() uint64 { + return v + }, max} +} + +type mockMemory struct { + f func() +} + +func (m mockMemory) Free() { + if m.f != nil { + m.f() + } +} diff --git a/sql/parse/describe.go b/sql/parse/describe.go new file mode 100644 index 000000000..e7993d49b --- /dev/null +++ b/sql/parse/describe.go @@ -0,0 +1,50 @@ +package parse + +import ( + "bufio" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" +) + +var ( + errInvalidDescribeFormat = errors.NewKind("invalid format %q for DESCRIBE, supported formats: %s") + describeSupportedFormats = []string{"tree"} +) + +func parseDescribeQuery(ctx *sql.Context, s string) (sql.Node, error) { + r := bufio.NewReader(strings.NewReader(s)) + + var format, query string + err := parseFuncs{ + oneOf("describe", "desc", "explain"), + skipSpaces, + expect("format"), + skipSpaces, + expectRune('='), + skipSpaces, + readIdent(&format), + skipSpaces, + readRemaining(&query), + }.exec(r) + + if err != nil { + return nil, err + } + + if format != "tree" { + return nil, errInvalidDescribeFormat.New( + format, + strings.Join(describeSupportedFormats, ", "), + ) + } + + child, err := Parse(ctx, query) + if err != nil { + return nil, err + } + + return plan.NewDescribeQuery(format, child), nil +} diff --git a/sql/parse/describe_test.go b/sql/parse/describe_test.go new file mode 100644 index 000000000..c15b6af00 --- /dev/null +++ 
b/sql/parse/describe_test.go @@ -0,0 +1,74 @@ +package parse + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" + errors "gopkg.in/src-d/go-errors.v1" +) + +func TestParseDescribeQuery(t *testing.T) { + testCases := []struct { + query string + result sql.Node + err *errors.Kind + }{ + { + "DESCRIBE TABLE foo", + nil, + errUnexpectedSyntax, + }, + { + "DESCRIBE something", + nil, + errUnexpectedSyntax, + }, + { + "DESCRIBE FORMAT=pretty SELECT * FROM foo", + nil, + errInvalidDescribeFormat, + }, + { + "DESCRIBE FORMAT=tree SELECT * FROM foo", + plan.NewDescribeQuery("tree", plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewUnresolvedTable("foo", "")), + ), + nil, + }, + { + "DESC FORMAT=tree SELECT * FROM foo", + plan.NewDescribeQuery("tree", plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewUnresolvedTable("foo", "")), + ), + nil, + }, + { + "EXPLAIN FORMAT=tree SELECT * FROM foo", + plan.NewDescribeQuery("tree", plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewUnresolvedTable("foo", "")), + ), + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.query, func(t *testing.T) { + require := require.New(t) + + result, err := parseDescribeQuery(sql.NewEmptyContext(), tt.query) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} diff --git a/sql/parse/create_index.go b/sql/parse/indexes.go similarity index 59% rename from sql/parse/create_index.go rename to sql/parse/indexes.go index e3b994e5c..a435f7888 100644 --- a/sql/parse/create_index.go +++ b/sql/parse/indexes.go @@ -7,27 +7,45 @@ import ( "strings" "unicode" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" + "github.com/src-d/go-mysql-server/sql/plan" - errors "gopkg.in/src-d/go-errors.v1" - 
"gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-vitess.v0/vt/sqlparser" + "github.com/src-d/go-mysql-server/sql" ) -var ( - errUnexpectedSyntax = errors.NewKind("expecting %q but got %q instead") - errInvalidIndexExpression = errors.NewKind("invalid expression to index: %s") -) +func parseShowIndex(s string) (sql.Node, error) { + r := bufio.NewReader(strings.NewReader(s)) -type parseFunc func(*bufio.Reader) error + var table string + err := parseFuncs{ + expect("show"), + skipSpaces, + oneOf("index", "indexes", "keys"), + skipSpaces, + oneOf("from", "in"), + skipSpaces, + readIdent(&table), + skipSpaces, + checkEOF, + }.exec(r) -func parseCreateIndex(s string) (sql.Node, error) { + if err != nil { + return nil, err + } + + return plan.NewShowIndexes( + sql.UnresolvedDatabase(""), + table, + nil, + ), nil +} + +func parseCreateIndex(ctx *sql.Context, s string) (sql.Node, error) { r := bufio.NewReader(strings.NewReader(s)) var name, table, driver string var exprs []string var config = make(map[string]string) - steps := []parseFunc{ + err := parseFuncs{ expect("create"), skipSpaces, expect("index"), @@ -38,12 +56,10 @@ func parseCreateIndex(s string) (sql.Node, error) { skipSpaces, readIdent(&table), skipSpaces, - optional( - expect("using"), - skipSpaces, - readIdent(&driver), - skipSpaces, - ), + expect("using"), + skipSpaces, + readIdent(&driver), + skipSpaces, readExprs(&exprs), skipSpaces, optional( @@ -52,18 +68,17 @@ func parseCreateIndex(s string) (sql.Node, error) { readKeyValue(config), skipSpaces, ), - } + checkEOF, + }.exec(r) - for _, step := range steps { - if err := step(r); err != nil { - return nil, err - } + if err != nil { + return nil, err } var indexExprs = make([]sql.Expression, len(exprs)) for i, e := range exprs { var err error - indexExprs[i], err = parseIndexExpr(e) + indexExprs[i], err = parseExpr(ctx, e) if err != nil { return nil, err } @@ -71,30 +86,13 @@ func parseCreateIndex(s string) (sql.Node, error) { return 
plan.NewCreateIndex( name, - plan.NewUnresolvedTable(table), + plan.NewUnresolvedTable(table, ""), indexExprs, driver, config, ), nil } -func optional(steps ...parseFunc) parseFunc { - return func(rd *bufio.Reader) error { - for _, step := range steps { - err := step(rd) - if err == io.EOF || errUnexpectedSyntax.Is(err) { - return nil - } - - if err != nil { - return err - } - } - - return nil - } -} - func readKeyValue(kv map[string]string) parseFunc { return func(rd *bufio.Reader) error { r, _, err := rd.ReadRune() @@ -108,7 +106,7 @@ func readKeyValue(kv map[string]string) parseFunc { for { var key, value string - steps := []parseFunc{ + err := parseFuncs{ skipSpaces, readIdent(&key), skipSpaces, @@ -116,12 +114,10 @@ func readKeyValue(kv map[string]string) parseFunc { skipSpaces, readValue(&value), skipSpaces, - } + }.exec(rd) - for _, step := range steps { - if err := step(rd); err != nil { - return err - } + if err != nil { + return err } r, _, err := rd.ReadRune() @@ -195,55 +191,32 @@ func readValue(val *string) parseFunc { } } -func parseIndexExpr(str string) (sql.Expression, error) { - stmt, err := sqlparser.Parse("SELECT " + str) - if err != nil { - return nil, err - } +func parseDropIndex(str string) (sql.Node, error) { + r := bufio.NewReader(strings.NewReader(str)) - selectStmt, ok := stmt.(*sqlparser.Select) - if !ok { - return nil, errInvalidIndexExpression.New(str) - } - - if len(selectStmt.SelectExprs) != 1 { - return nil, errInvalidIndexExpression.New(str) - } + var name, table string + err := parseFuncs{ + expect("drop"), + skipSpaces, + expect("index"), + skipSpaces, + readIdent(&name), + skipSpaces, + expect("on"), + skipSpaces, + readIdent(&table), + skipSpaces, + checkEOF, + }.exec(r) - selectExpr, ok := selectStmt.SelectExprs[0].(*sqlparser.AliasedExpr) - if !ok { - return nil, errInvalidIndexExpression.New(str) + if err != nil { + return nil, err } - return exprToExpression(selectExpr.Expr) -} - -func readIdent(ident *string) parseFunc { - 
return func(r *bufio.Reader) error { - var buf bytes.Buffer - for { - ru, _, err := r.ReadRune() - if err == io.EOF { - break - } - - if err != nil { - return err - } - - if !unicode.IsLetter(ru) && ru != '_' { - if err := r.UnreadRune(); err != nil { - return err - } - break - } - - buf.WriteRune(ru) - } - - *ident = strings.ToLower(buf.String()) - return nil - } + return plan.NewDropIndex( + name, + plan.NewUnresolvedTable(table, ""), + ), nil } func readExprs(exprs *[]string) parseFunc { @@ -320,51 +293,3 @@ func readExprs(exprs *[]string) parseFunc { } } } - -func expectRune(expected rune) parseFunc { - return func(rd *bufio.Reader) error { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - - if r != expected { - return errUnexpectedSyntax.New(expected, string(r)) - } - - return nil - } -} - -func expect(expected string) parseFunc { - return func(r *bufio.Reader) error { - var ident string - - if err := readIdent(&ident)(r); err != nil { - return err - } - - if ident == expected { - return nil - } - - return errUnexpectedSyntax.New(expected, ident) - } -} - -func skipSpaces(r *bufio.Reader) error { - for { - ru, _, err := r.ReadRune() - if err == io.EOF { - return nil - } - - if err != nil { - return err - } - - if !unicode.IsSpace(ru) { - return r.UnreadRune() - } - } -} diff --git a/sql/parse/create_index_test.go b/sql/parse/indexes_test.go similarity index 79% rename from sql/parse/create_index_test.go rename to sql/parse/indexes_test.go index 95ac7afad..a9d7df5a4 100644 --- a/sql/parse/create_index_test.go +++ b/sql/parse/indexes_test.go @@ -5,12 +5,12 @@ import ( "strings" "testing" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func 
TestParseCreateIndex(t *testing.T) { @@ -32,6 +32,11 @@ func TestParseCreateIndex(t *testing.T) { { "CREATE INDEX idx ON foo(*)", nil, + errUnexpectedSyntax, + }, + { + "CREATE INDEX idx ON foo USING foo (*)", + nil, errInvalidIndexExpression, }, { @@ -46,9 +51,14 @@ func TestParseCreateIndex(t *testing.T) { }, { "CREATE INDEX idx ON foo(fn(bar, baz))", + nil, + errUnexpectedSyntax, + }, + { + "CREATE INDEX idx ON foo USING foo (fn(bar, baz))", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{ expression.NewUnresolvedFunction( "fn", false, @@ -56,32 +66,32 @@ func TestParseCreateIndex(t *testing.T) { expression.NewUnresolvedColumn("baz"), ), }, - "", + "foo", make(map[string]string), ), nil, }, { - "CREATE INDEX idx ON foo(bar)", + "CREATE INDEX idx ON foo USING foo (bar)", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{expression.NewUnresolvedColumn("bar")}, - "", + "foo", make(map[string]string), ), nil, }, { - "CREATE INDEX idx ON foo(bar, baz)", + "CREATE INDEX idx ON foo USING foo (bar, baz)", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{ expression.NewUnresolvedColumn("bar"), expression.NewUnresolvedColumn("baz"), }, - "", + "foo", make(map[string]string), ), nil, @@ -90,7 +100,7 @@ func TestParseCreateIndex(t *testing.T) { "CREATE INDEX idx ON foo USING bar (baz)", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{ expression.NewUnresolvedColumn("baz"), }, @@ -108,7 +118,7 @@ func TestParseCreateIndex(t *testing.T) { "CREATE INDEX idx ON foo USING bar (baz) WITH (foo = bar)", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{ expression.NewUnresolvedColumn("baz"), }, @@ -121,7 +131,7 @@ func TestParseCreateIndex(t 
*testing.T) { "CREATE INDEX idx ON foo USING bar (baz) WITH (foo = bar, qux = 'mux')", plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{ expression.NewUnresolvedColumn("baz"), }, @@ -130,13 +140,26 @@ func TestParseCreateIndex(t *testing.T) { ), nil, }, + { + "CREATE INDEX idx_2 ON foo USING bar (baz)", + plan.NewCreateIndex( + "idx_2", + plan.NewUnresolvedTable("foo", ""), + []sql.Expression{ + expression.NewUnresolvedColumn("baz"), + }, + "bar", + make(map[string]string), + ), + nil, + }, } for _, tt := range testCases { t.Run(tt.query, func(t *testing.T) { require := require.New(t) - result, err := parseCreateIndex(strings.ToLower(tt.query)) + result, err := parseCreateIndex(sql.NewEmptyContext(), strings.ToLower(tt.query)) if tt.err != nil { require.Error(err) require.True(tt.err.Is(err)) diff --git a/sql/parse/lock.go b/sql/parse/lock.go new file mode 100644 index 000000000..b4d6d67b6 --- /dev/null +++ b/sql/parse/lock.go @@ -0,0 +1,164 @@ +package parse + +import ( + "bufio" + "io" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func parseLockTables(ctx *sql.Context, query string) (sql.Node, error) { + var r = bufio.NewReader(strings.NewReader(query)) + var tables []*plan.TableLock + err := parseFuncs{ + expect("lock"), + skipSpaces, + expect("tables"), + skipSpaces, + readTableLocks(&tables), + skipSpaces, + checkEOF, + }.exec(r) + + if err != nil { + return nil, err + } + + return plan.NewLockTables(tables), nil +} + +func readTableLocks(tables *[]*plan.TableLock) parseFunc { + return func(rd *bufio.Reader) error { + for { + t, err := readTableLock(rd) + if err != nil { + return err + } + + *tables = append(*tables, t) + + if err = skipSpaces(rd); err != nil { + return err + } + + var b []byte + b, err = rd.Peek(1) + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if string(b) != "," { + return nil + } 
+ + if _, err := rd.Discard(1); err != nil { + return err + } + + if err := skipSpaces(rd); err != nil { + return err + } + } + } +} + +func readTableLock(rd *bufio.Reader) (*plan.TableLock, error) { + var tableName string + var write bool + + err := parseFuncs{ + readQuotableIdent(&tableName), + skipSpaces, + maybeReadAlias, + skipSpaces, + readLockType(&write), + }.exec(rd) + if err != nil { + return nil, err + } + + return &plan.TableLock{ + Table: plan.NewUnresolvedTable(tableName, ""), + Write: write, + }, nil +} + +func maybeReadAlias(rd *bufio.Reader) error { + data, err := rd.Peek(2) + if err != nil { + return err + } + + if strings.ToLower(string(data)) == "as" { + _, err := rd.Discard(2) + if err != nil { + return err + } + + if err := skipSpaces(rd); err != nil { + return err + } + + var ignored string + if err := readIdent(&ignored)(rd); err != nil { + return err + } + + return nil + } + + var nextIdent string + if err := readIdent(&nextIdent)(rd); err != nil { + return err + } + + switch strings.ToLower(nextIdent) { + case "read", "low_priority", "write": + unreadString(rd, nextIdent) + } + + return nil +} + +func readLockType(write *bool) parseFunc { + return func(rd *bufio.Reader) error { + var nextIdent string + if err := readIdent(&nextIdent)(rd); err != nil { + return err + } + + switch strings.ToLower(nextIdent) { + case "low_priority": + err := parseFuncs{skipSpaces, expect("write")}.exec(rd) + if err != nil { + return err + } + + fallthrough + case "write": + *write = true + return nil + case "read": + var ident string + if err := skipSpaces(rd); err != nil { + return err + } + + if err := readIdent(&ident)(rd); err != nil { + return err + } + + if ident != "" && ident != "local" { + return errUnexpectedSyntax.New("LOCAL", ident) + } + + return nil + default: + return errUnexpectedSyntax.New("one of: READ, WRITE or LOW_PRIORITY", nextIdent) + } + } +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 6dc5e23ee..f4dfe6f0a 100644 --- 
a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1,17 +1,22 @@ package parse import ( + "bufio" "fmt" + "io" + "io/ioutil" "regexp" "strconv" "strings" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" - "gopkg.in/src-d/go-vitess.v0/vt/sqlparser" + "vitess.io/vitess/go/vt/sqlparser" ) var ( @@ -25,29 +30,83 @@ var ( ErrInvalidSQLValType = errors.NewKind("invalid SQLVal of type: %d") // ErrInvalidSortOrder is returned when a sort order is not valid. - ErrInvalidSortOrder = errors.NewKind("invalod sort order: %s") + ErrInvalidSortOrder = errors.NewKind("invalid sort order: %s") ) var ( - describeTablesRegex = regexp.MustCompile(`^describe\s+table\s+(.*)`) - createIndexRegex = regexp.MustCompile(`^create\s+index\s+`) + describeTablesRegex = regexp.MustCompile(`^(describe|desc)\s+table\s+(.*)`) + createIndexRegex = regexp.MustCompile(`^create\s+index\s+`) + dropIndexRegex = regexp.MustCompile(`^drop\s+index\s+`) + showIndexRegex = regexp.MustCompile(`^show\s+(index|indexes|keys)\s+(from|in)\s+\S+\s*`) + showCreateRegex = regexp.MustCompile(`^show create\s+\S+\s*`) + showVariablesRegex = regexp.MustCompile(`^show\s+(.*)?variables\s*`) + showWarningsRegex = regexp.MustCompile(`^show\s+warnings\s*`) + showCollationRegex = regexp.MustCompile(`^show\s+collation\s*`) + describeRegex = regexp.MustCompile(`^(describe|desc|explain)\s+(.*)\s+`) + fullProcessListRegex = regexp.MustCompile(`^show\s+(full\s+)?processlist$`) + unlockTablesRegex = regexp.MustCompile(`^unlock\s+tables$`) + 
lockTablesRegex = regexp.MustCompile(`^lock\s+tables\s`) + setRegex = regexp.MustCompile(`^set\s+`) + createViewRegex = regexp.MustCompile(`^create\s+view\s+`) +) + +// These constants aren't exported from vitess for some reason. This could be removed if we changed this. +const ( + colKeyNone sqlparser.ColumnKeyOption = iota + colKeyPrimary + colKeySpatialKey + colKeyUnique + colKeyUniqueKey + colKey ) // Parse parses the given SQL sentence and returns the corresponding node. -func Parse(ctx *sql.Context, s string) (sql.Node, error) { - span, ctx := ctx.Span("parse", opentracing.Tag{Key: "query", Value: s}) +func Parse(ctx *sql.Context, query string) (sql.Node, error) { + span, ctx := ctx.Span("parse", opentracing.Tag{Key: "query", Value: query}) defer span.Finish() + s := strings.TrimSpace(removeComments(query)) if strings.HasSuffix(s, ";") { s = s[:len(s)-1] } + if s == "" { + ctx.Warn(0, "query was empty after trimming comments, so it will be ignored") + return plan.Nothing, nil + } + lowerQuery := strings.ToLower(s) + switch true { case describeTablesRegex.MatchString(lowerQuery): return parseDescribeTables(lowerQuery) case createIndexRegex.MatchString(lowerQuery): - return parseCreateIndex(s) + return parseCreateIndex(ctx, s) + case dropIndexRegex.MatchString(lowerQuery): + return parseDropIndex(s) + case showIndexRegex.MatchString(lowerQuery): + return parseShowIndex(s) + case showCreateRegex.MatchString(lowerQuery): + return parseShowCreate(s) + case showVariablesRegex.MatchString(lowerQuery): + return parseShowVariables(ctx, s) + case showWarningsRegex.MatchString(lowerQuery): + return parseShowWarnings(ctx, s) + case showCollationRegex.MatchString(lowerQuery): + return parseShowCollation(ctx, s) + case describeRegex.MatchString(lowerQuery): + return parseDescribeQuery(ctx, s) + case fullProcessListRegex.MatchString(lowerQuery): + return plan.NewShowProcessList(), nil + case unlockTablesRegex.MatchString(lowerQuery): + return plan.NewUnlockTables(), nil + 
case lockTablesRegex.MatchString(lowerQuery): + return parseLockTables(ctx, s) + case setRegex.MatchString(lowerQuery): + s = fixSetQuery(s) + case createViewRegex.MatchString(lowerQuery): + // CREATE VIEW parses as a CREATE DDL statement with an empty table spec + return nil, ErrUnsupportedFeature.New("CREATE VIEW") } stmt, err := sqlparser.Parse(s) @@ -60,8 +119,23 @@ func Parse(ctx *sql.Context, s string) (sql.Node, error) { func parseDescribeTables(s string) (sql.Node, error) { t := describeTablesRegex.FindStringSubmatch(s) - if len(t) == 2 && t[1] != "" { - return plan.NewDescribe(plan.NewUnresolvedTable(t[1])), nil + if len(t) == 3 && t[2] != "" { + parts := strings.Split(t[2], ".") + var table, db string + switch len(parts) { + case 1: + table = parts[0] + case 2: + if parts[0] == "" || parts[1] == "" { + return nil, ErrUnsupportedSyntax.New(s) + } + db = parts[0] + table = parts[1] + default: + return nil, ErrUnsupportedSyntax.New(s) + } + + return plan.NewDescribe(plan.NewUnresolvedTable(table, db)), nil } return nil, ErrUnsupportedSyntax.New(s) @@ -72,23 +146,170 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node default: return nil, ErrUnsupportedSyntax.New(n) case *sqlparser.Show: - return convertShow(n) + // When a query is empty it means it comes from a subquery, as we don't + // have the query itself in a subquery. Hence, a SHOW could not be + // parsed. 
+ if query == "" { + return nil, ErrUnsupportedFeature.New("SHOW in subquery") + } + return convertShow(ctx, n, query) case *sqlparser.Select: return convertSelect(ctx, n) case *sqlparser.Insert: return convertInsert(ctx, n) case *sqlparser.DDL: - return convertDDL(n, query) + // unlike other statements, DDL statements have loose parsing by default + ddl, err := sqlparser.ParseStrictDDL(query) + if err != nil { + return nil, err + } + return convertDDL(ddl.(*sqlparser.DDL)) + case *sqlparser.Set: + return convertSet(ctx, n) + case *sqlparser.Use: + return convertUse(n) + case *sqlparser.Rollback: + return plan.NewRollback(), nil + case *sqlparser.Delete: + return convertDelete(ctx, n) + case *sqlparser.Update: + return convertUpdate(ctx, n) + } +} + +func convertUse(n *sqlparser.Use) (sql.Node, error) { + name := n.DBName.String() + return plan.NewUse(sql.UnresolvedDatabase(name)), nil +} + +func convertSet(ctx *sql.Context, n *sqlparser.Set) (sql.Node, error) { + if n.Scope == sqlparser.GlobalStr { + return nil, ErrUnsupportedFeature.New("SET global variables") + } + + var variables = make([]plan.SetVariable, len(n.Exprs)) + for i, e := range n.Exprs { + expr, err := exprToExpression(ctx, e.Expr) + if err != nil { + return nil, err + } + + name := strings.TrimSpace(e.Name.Lowered()) + if expr, err = expression.TransformUp(expr, func(e sql.Expression) (sql.Expression, error) { + if _, ok := e.(*expression.DefaultColumn); ok { + return e, nil + } + + if !e.Resolved() || e.Type() != sql.Text { + return e, nil + } + + txt, err := e.Eval(ctx, nil) + if err != nil { + return nil, err + } + + val, ok := txt.(string) + if !ok { + return nil, ErrUnsupportedFeature.New("invalid qualifiers in set variable names") + } + + switch strings.ToLower(val) { + case sqlparser.KeywordString(sqlparser.ON): + return expression.NewLiteral(int64(1), sql.Int64), nil + case sqlparser.KeywordString(sqlparser.TRUE): + return expression.NewLiteral(true, sql.Boolean), nil + case 
sqlparser.KeywordString(sqlparser.OFF): + return expression.NewLiteral(int64(0), sql.Int64), nil + case sqlparser.KeywordString(sqlparser.FALSE): + return expression.NewLiteral(false, sql.Boolean), nil + } + + return e, nil + }); err != nil { + return nil, err + } + + variables[i] = plan.SetVariable{ + Name: name, + Value: expr, + } } + + return plan.NewSet(variables...), nil } -func convertShow(s *sqlparser.Show) (sql.Node, error) { - if s.Type != sqlparser.KeywordString(sqlparser.TABLES) { +func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, error) { + switch s.Type { + case sqlparser.KeywordString(sqlparser.TABLES): + var dbName string + var filter sql.Expression + var full bool + if s.ShowTablesOpt != nil { + dbName = s.ShowTablesOpt.DbName + full = s.ShowTablesOpt.Full != "" + + if s.ShowTablesOpt.Filter != nil { + if s.ShowTablesOpt.Filter.Filter != nil { + var err error + filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + if err != nil { + return nil, err + } + } else if s.ShowTablesOpt.Filter.Like != "" { + filter = expression.NewLike( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral(s.ShowTablesOpt.Filter.Like, sql.Text), + ) + } + } + } + + var node sql.Node = plan.NewShowTables(sql.UnresolvedDatabase(dbName), full) + if filter != nil { + node = plan.NewFilter(filter, node) + } + + return node, nil + case sqlparser.KeywordString(sqlparser.DATABASES), sqlparser.KeywordString(sqlparser.SCHEMAS): + return plan.NewShowDatabases(), nil + case sqlparser.KeywordString(sqlparser.FIELDS), sqlparser.KeywordString(sqlparser.COLUMNS): + // TODO(erizocosmico): vitess parser does not support EXTENDED. 
+ table := plan.NewUnresolvedTable(s.OnTable.Name.String(), s.OnTable.Qualifier.String()) + full := s.ShowTablesOpt.Full != "" + + var node sql.Node = plan.NewShowColumns(full, table) + + if s.ShowTablesOpt != nil && s.ShowTablesOpt.Filter != nil { + if s.ShowTablesOpt.Filter.Like != "" { + pattern := expression.NewLiteral(s.ShowTablesOpt.Filter.Like, sql.Text) + + node = plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Field"), + pattern, + ), + node, + ) + } + + if s.ShowTablesOpt.Filter.Filter != nil { + filter, err := exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + if err != nil { + return nil, err + } + + node = plan.NewFilter(filter, node) + } + } + + return node, nil + case sqlparser.KeywordString(sqlparser.TABLE): + return parseShowTableStatus(ctx, query) + default: unsupportedShow := fmt.Sprintf("SHOW %s", s.Type) return nil, ErrUnsupportedFeature.New(unsupportedShow) } - - return plan.NewShowTables(&sql.UnresolvedDatabase{}), nil } func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { @@ -97,40 +318,37 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { return nil, err } - if s.Having != nil { - return nil, ErrUnsupportedFeature.New("HAVING") - } - if s.Where != nil { - node, err = whereToFilter(s.Where, node) + node, err = whereToFilter(ctx, s.Where, node) if err != nil { return nil, err } } - if len(s.OrderBy) != 0 { - node, err = orderByToSort(s.OrderBy, node) + node, err = selectToProjectOrGroupBy(ctx, s.SelectExprs, s.GroupBy, node) + if err != nil { + return nil, err + } + + if s.Having != nil { + node, err = havingToHaving(ctx, s.Having, node) if err != nil { return nil, err } } - node, err = selectToProjectOrGroupBy(s.SelectExprs, s.GroupBy, node) - if err != nil { - return nil, err - } - if s.Distinct != "" { node = plan.NewDistinct(node) } - if s.Limit != nil { - node, err = limitToLimit(ctx, s.Limit.Rowcount, node) + if len(s.OrderBy) != 0 { + node, err = 
orderByToSort(ctx, s.OrderBy, node) if err != nil { return nil, err } } + // Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count. if s.Limit != nil && s.Limit.Offset != nil { node, err = offsetToOffset(ctx, s.Limit.Offset, node) if err != nil { @@ -138,29 +356,46 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } } + if s.Limit != nil { + node, err = limitToLimit(ctx, s.Limit.Rowcount, node) + if err != nil { + return nil, err + } + } else if ok, val := sql.HasDefaultValue(ctx.Session, "sql_select_limit"); !ok { + limit := val.(int64) + node = plan.NewLimit(int64(limit), node) + } + return node, nil } -func convertDDL(c *sqlparser.DDL, query string) (sql.Node, error) { +func convertDDL(c *sqlparser.DDL) (sql.Node, error) { switch c.Action { case sqlparser.CreateStr: return convertCreateTable(c) + case sqlparser.DropStr: + return convertDropTable(c) default: return nil, ErrUnsupportedSyntax.New(c) } } +func convertDropTable(c *sqlparser.DDL) (sql.Node, error) { + tableNames := make([]string, len(c.FromTables)) + for i, t := range c.FromTables { + tableNames[i] = t.Name.String() + } + return plan.NewDropTable(sql.UnresolvedDatabase(""), c.IfExists, tableNames...), nil +} + func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { - schema, err := columnDefinitionToSchema(c.TableSpec.Columns) + schema, err := tableSpecToSchema(c.TableSpec) if err != nil { return nil, err } return plan.NewCreateTable( - &sql.UnresolvedDatabase{}, - c.NewName.Name.String(), - schema, - ), nil + sql.UnresolvedDatabase(""), c.Table.Name.String(), schema), nil } func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { @@ -172,39 +407,153 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { return nil, ErrUnsupportedSyntax.New(i) } + isReplace := i.Action == sqlparser.ReplaceStr + src, err := insertRowsToNode(ctx, i.Rows) if err != nil { return nil, err 
} return plan.NewInsertInto( - plan.NewUnresolvedTable(i.Table.Name.String()), + plan.NewUnresolvedTable(i.Table.Name.String(), i.Table.Qualifier.String()), src, + isReplace, columnsToStrings(i.Columns), ), nil } -func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) { +func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { + node, err := tableExprsToTable(ctx, d.TableExprs) + if err != nil { + return nil, err + } + + if d.Where != nil { + node, err = whereToFilter(ctx, d.Where, node) + if err != nil { + return nil, err + } + } + + if len(d.OrderBy) != 0 { + node, err = orderByToSort(ctx, d.OrderBy, node) + if err != nil { + return nil, err + } + } + + // Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count. + if d.Limit != nil && d.Limit.Offset != nil { + node, err = offsetToOffset(ctx, d.Limit.Offset, node) + if err != nil { + return nil, err + } + } + + if d.Limit != nil { + node, err = limitToLimit(ctx, d.Limit.Rowcount, node) + if err != nil { + return nil, err + } + } + + return plan.NewDeleteFrom(node), nil +} + +func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) { + node, err := tableExprsToTable(ctx, d.TableExprs) + if err != nil { + return nil, err + } + + updateExprs, err := updateExprsToExpressions(ctx, d.Exprs) + if err != nil { + return nil, err + } + + if d.Where != nil { + node, err = whereToFilter(ctx, d.Where, node) + if err != nil { + return nil, err + } + } + + if len(d.OrderBy) != 0 { + node, err = orderByToSort(ctx, d.OrderBy, node) + if err != nil { + return nil, err + } + } + + // Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count. 
+ if d.Limit != nil && d.Limit.Offset != nil { + node, err = offsetToOffset(ctx, d.Limit.Offset, node) + if err != nil { + return nil, err + } + + } + + if d.Limit != nil { + node, err = limitToLimit(ctx, d.Limit.Rowcount, node) + if err != nil { + return nil, err + } + } + + return plan.NewUpdate(node, updateExprs), nil +} + +func tableSpecToSchema(tableSpec *sqlparser.TableSpec) (sql.Schema, error) { var schema sql.Schema - for _, cd := range colDef { - typ := cd.Type - internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + for _, cd := range tableSpec.Columns { + column, err := getColumn(cd, tableSpec.Indexes) if err != nil { return nil, err } - schema = append(schema, &sql.Column{ - Nullable: !bool(typ.NotNull), - Type: internalTyp, - Name: cd.Name.String(), - // TODO - Default: nil, - }) + schema = append(schema, column) } return schema, nil } +// getColumn returns the sql.Column for the column definition given, as part of a create table statement. +func getColumn(cd *sqlparser.ColumnDefinition, indexes []*sqlparser.IndexDefinition) (*sql.Column, error) { + typ := cd.Type + internalTyp, err := sql.MysqlTypeToType(typ.SQLType()) + if err != nil { + return nil, err + } + + // Primary key info can either be specified in the column's type info (for in-line declarations), or in a slice of + // indexes attached to the table def. 
We have to check both places to find if a column is part of the primary key + isPkey := cd.Type.KeyOpt == colKeyPrimary + + if !isPkey { + OuterLoop: + for _, index := range indexes { + if index.Info.Primary { + for _, indexCol := range index.Columns { + if indexCol.Column.Equal(cd.Name) { + isPkey = true + break OuterLoop + } + } + } + } + } + + return &sql.Column{ + Nullable: !bool(typ.NotNull), + Type: internalTyp, + Name: cd.Name.String(), + PrimaryKey: isPkey, + // TODO + Default: nil, + }, nil +} + func columnsToStrings(cols sqlparser.Columns) []string { res := make([]string, len(cols)) for i, c := range cols { @@ -221,19 +570,19 @@ func insertRowsToNode(ctx *sql.Context, ir sqlparser.InsertRows) (sql.Node, erro case *sqlparser.Union: return nil, ErrUnsupportedFeature.New("UNION") case sqlparser.Values: - return valuesToValues(v) + return valuesToValues(ctx, v) default: return nil, ErrUnsupportedSyntax.New(ir) } } -func valuesToValues(v sqlparser.Values) (sql.Node, error) { +func valuesToValues(ctx *sql.Context, v sqlparser.Values) (sql.Node, error) { exprTuples := make([][]sql.Expression, len(v)) for i, vt := range v { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -286,11 +635,7 @@ func tableExprToTable( // TODO: Add support for qualifier. 
switch e := t.Expr.(type) { case sqlparser.TableName: - if !e.Qualifier.IsEmpty() { - return nil, ErrUnsupportedFeature.New("table name qualifiers") - } - - node := plan.NewUnresolvedTable(e.Name.String()) + node := plan.NewUnresolvedTable(e.Name.String(), e.Qualifier.String()) if !t.As.IsEmpty() { return plan.NewTableAlias(t.As.String(), node), nil } @@ -311,15 +656,10 @@ func tableExprToTable( return nil, ErrUnsupportedSyntax.New(te) } case *sqlparser.JoinTableExpr: - // TODO: add support for the rest of joins - if t.Join != sqlparser.JoinStr { - return nil, ErrUnsupportedFeature.New(t.Join) - } - // TODO: add support for using, once we have proper table // qualification of fields if len(t.Condition.Using) > 0 { - return nil, ErrUnsupportedFeature.New("using clause on join") + return nil, ErrUnsupportedFeature.New("USING clause on join") } left, err := tableExprToTable(ctx, t.LeftExpr) @@ -332,17 +672,34 @@ func tableExprToTable( return nil, err } - cond, err := exprToExpression(t.Condition.On) + if t.Join == sqlparser.NaturalJoinStr { + return plan.NewNaturalJoin(left, right), nil + } + + if t.Condition.On == nil { + return nil, ErrUnsupportedSyntax.New("missed ON clause for JOIN statement") + } + + cond, err := exprToExpression(ctx, t.Condition.On) if err != nil { return nil, err } - return plan.NewInnerJoin(left, right, cond), nil + switch t.Join { + case sqlparser.JoinStr: + return plan.NewInnerJoin(left, right, cond), nil + case sqlparser.LeftJoinStr: + return plan.NewLeftJoin(left, right, cond), nil + case sqlparser.RightJoinStr: + return plan.NewRightJoin(left, right, cond), nil + default: + return nil, ErrUnsupportedFeature.New(t.Join) + } } } -func whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { - c, err := exprToExpression(w.Expr) +func whereToFilter(ctx *sql.Context, w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { + c, err := exprToExpression(ctx, w.Expr) if err != nil { return nil, err } @@ -350,10 +707,10 @@ func 
whereToFilter(w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { return plan.NewFilter(c, child), nil } -func orderByToSort(ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { +func orderByToSort(ctx *sql.Context, ob sqlparser.OrderBy, child sql.Node) (*plan.Sort, error) { var sortFields []plan.SortField for _, o := range ob { - e, err := exprToExpression(o.Expr) + e, err := exprToExpression(ctx, o.Expr) if err != nil { return nil, err } @@ -380,21 +737,25 @@ func limitToLimit( limit sqlparser.Expr, child sql.Node, ) (*plan.Limit, error) { - e, err := exprToExpression(limit) + rowCount, err := getInt64Value(ctx, limit, "LIMIT with non-integer literal") if err != nil { return nil, err } - nl, ok := e.(*expression.Literal) - if !ok || nl.Type() != sql.Int64 { - return nil, ErrUnsupportedFeature.New("LIMIT with non-integer literal") + if rowCount < 0 { + return nil, ErrUnsupportedSyntax.New("LIMIT must be >= 0") } - n, err := nl.Eval(ctx, nil) + return plan.NewLimit(rowCount, child), nil +} + +func havingToHaving(ctx *sql.Context, having *sqlparser.Where, node sql.Node) (sql.Node, error) { + cond, err := exprToExpression(ctx, having.Expr) if err != nil { return nil, err } - return plan.NewLimit(n.(int64), child), nil + + return plan.NewHaving(cond, node), nil } func offsetToOffset( @@ -402,36 +763,78 @@ func offsetToOffset( offset sqlparser.Expr, child sql.Node, ) (*plan.Offset, error) { - e, err := exprToExpression(offset) + o, err := getInt64Value(ctx, offset, "OFFSET with non-integer literal") if err != nil { return nil, err } - nl, ok := e.(*expression.Literal) - if !ok || nl.Type() != sql.Int64 { - return nil, ErrUnsupportedFeature.New("OFFSET with non-integer literal") + if o < 0 { + return nil, ErrUnsupportedSyntax.New("OFFSET must be >= 0") } - n, err := nl.Eval(ctx, nil) + return plan.NewOffset(o, child), nil +} + +// getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string +// given if 
the expression doesn't represent an integer literal. +func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*expression.Literal, error) { + e, err := exprToExpression(ctx, expr) if err != nil { return nil, err } - return plan.NewOffset(n.(int64), child), nil + + nl, ok := e.(*expression.Literal) + if !ok || !sql.IsInteger(nl.Type()) { + return nil, ErrUnsupportedFeature.New(errStr) + } else { + i64, err := sql.Int64.Convert(nl.Value()) + if err != nil { + return nil, ErrUnsupportedFeature.New(errStr) + } + return expression.NewLiteral(i64, sql.Int64), nil + } + + return nl, nil } -func isAggregate(e sql.Expression) bool { - switch v := e.(type) { - case *expression.UnresolvedFunction: - return v.IsAggregate - case *expression.Alias: - return isAggregate(v.Child) - default: - return false +// getInt64Value returns the int64 literal value in the expression given, or an error with the errStr given if it +// cannot. +func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64, error) { + ie, err := getInt64Literal(ctx, expr, errStr) + if err != nil { + return 0, err } + + i, err := ie.Eval(ctx, nil) + if err != nil { + return 0, err + } + + return i.(int64), nil } -func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, child sql.Node) (sql.Node, error) { - selectExprs, err := selectExprsToExpressions(se) +func isAggregate(e sql.Expression) bool { + var isAgg bool + expression.Inspect(e, func(e sql.Expression) bool { + switch e := e.(type) { + case *expression.UnresolvedFunction: + isAgg = isAgg || e.IsAggregate + case *aggregation.CountDistinct: + isAgg = true + } + + return true + }) + return isAgg +} + +func selectToProjectOrGroupBy( + ctx *sql.Context, + se sqlparser.SelectExprs, + g sqlparser.GroupBy, + child sql.Node, +) (sql.Node, error) { + selectExprs, err := selectExprsToExpressions(ctx, se) if err != nil { return nil, err } @@ -447,21 +850,37 @@ func selectToProjectOrGroupBy(se 
sqlparser.SelectExprs, g sqlparser.GroupBy, chi } if isAgg { - groupingExprs, err := groupByToExpressions(g) + groupingExprs, err := groupByToExpressions(ctx, g) if err != nil { return nil, err } + agglen := int64(len(selectExprs)) + for i, ge := range groupingExprs { + // if GROUP BY index + if l, ok := ge.(*expression.Literal); ok && sql.IsNumber(l.Type()) { + if i64, err := sql.Int64.Convert(l.Value()); err == nil { + if idx, ok := i64.(int64); ok && idx > 0 && idx <= agglen { + aggexpr := selectExprs[idx-1] + if alias, ok := aggexpr.(*expression.Alias); ok { + aggexpr = expression.NewUnresolvedColumn(alias.Name()) + } + groupingExprs[i] = aggexpr + } + } + } + } + return plan.NewGroupBy(selectExprs, groupingExprs, child), nil } return plan.NewProject(selectExprs, child), nil } -func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error) { +func selectExprsToExpressions(ctx *sql.Context, se sqlparser.SelectExprs) ([]sql.Expression, error) { var exprs []sql.Expression for _, e := range se { - pe, err := selectExprToExpression(e) + pe, err := selectExprToExpression(ctx, e) if err != nil { return nil, err } @@ -472,16 +891,44 @@ func selectExprsToExpressions(se sqlparser.SelectExprs) ([]sql.Expression, error return exprs, nil } -func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { +func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(e) + case *sqlparser.Default: + return expression.NewDefaultColumn(v.ColName), nil + case *sqlparser.SubstrExpr: + var ( + name sql.Expression + err error + ) + if v.Name != nil { + name, err = exprToExpression(ctx, v.Name) + } else { + name, err = exprToExpression(ctx, v.StrVal) + } + if err != nil { + return nil, err + } + from, err := exprToExpression(ctx, v.From) + if err != nil { + return nil, err + } + + if v.To == nil { + return function.NewSubstring(name, from) + } + to, err := 
exprToExpression(ctx, v.To) + if err != nil { + return nil, err + } + return function.NewSubstring(name, from, to) case *sqlparser.ComparisonExpr: - return comparisonExprToExpression(v) + return comparisonExprToExpression(ctx, v) case *sqlparser.IsExpr: - return isExprToExpression(v) + return isExprToExpression(ctx, v) case *sqlparser.NotExpr: - c, err := exprToExpression(v.Expr) + c, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } @@ -494,67 +941,78 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { case *sqlparser.NullVal: return expression.NewLiteral(nil, sql.Null), nil case *sqlparser.ColName: - //TODO: add handling of case sensitiveness. if !v.Qualifier.IsEmpty() { return expression.NewUnresolvedQualifiedColumn( v.Qualifier.Name.String(), - v.Name.Lowered(), + v.Name.String(), ), nil } - return expression.NewUnresolvedColumn(v.Name.Lowered()), nil + return expression.NewUnresolvedColumn(v.Name.String()), nil case *sqlparser.FuncExpr: - exprs, err := selectExprsToExpressions(v.Exprs) + exprs, err := selectExprsToExpressions(ctx, v.Exprs) if err != nil { return nil, err } + if v.Distinct { + if v.Name.Lowered() != "count" { + return nil, ErrUnsupportedSyntax.New("DISTINCT on non-COUNT aggregations") + } + + if len(exprs) != 1 { + return nil, ErrUnsupportedSyntax.New("more than one expression in COUNT") + } + + return aggregation.NewCountDistinct(exprs[0]), nil + } + return expression.NewUnresolvedFunction(v.Name.Lowered(), - v.IsAggregate(), exprs...), nil + isAggregateFunc(v), exprs...), nil case *sqlparser.ParenExpr: - return exprToExpression(v.Expr) + return exprToExpression(ctx, v.Expr) case *sqlparser.AndExpr: - lhs, err := exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewAnd(lhs, rhs), nil case *sqlparser.OrExpr: - lhs, err 
:= exprToExpression(v.Left) + lhs, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(v.Right) + rhs, err := exprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewOr(lhs, rhs), nil case *sqlparser.ConvertExpr: - expr, err := exprToExpression(v.Expr) + expr, err := exprToExpression(ctx, v.Expr) if err != nil { return nil, err } return expression.NewConvert(expr, v.Type.Type), nil case *sqlparser.RangeCond: - val, err := exprToExpression(v.Left) + val, err := exprToExpression(ctx, v.Left) if err != nil { return nil, err } - lower, err := exprToExpression(v.From) + lower, err := exprToExpression(ctx, v.From) if err != nil { return nil, err } - upper, err := exprToExpression(v.To) + upper, err := exprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -570,7 +1028,7 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { case sqlparser.ValTuple: var exprs = make([]sql.Expression, len(v)) for i, e := range v { - expr, err := exprToExpression(e) + expr, err := exprToExpression(ctx, e) if err != nil { return nil, err } @@ -579,21 +1037,71 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { return expression.NewTuple(exprs...), nil case *sqlparser.BinaryExpr: - return binaryExprToExpression(v) + return binaryExprToExpression(ctx, v) + case *sqlparser.UnaryExpr: + return unaryExprToExpression(ctx, v) + case *sqlparser.Subquery: + node, err := convert(ctx, v.Select, "") + if err != nil { + return nil, err + } + return expression.NewSubquery(node), nil + case *sqlparser.CaseExpr: + return caseExprToExpression(ctx, v) + case *sqlparser.IntervalExpr: + return intervalExprToExpression(ctx, v) } } +func isAggregateFunc(v *sqlparser.FuncExpr) bool { + switch v.Name.Lowered() { + case "first", "last": + return true + } + + return v.IsAggregate() +} + +// Convert an integer, represented by the specified string in the specified +// base, to its smallest 
representation possible, out of: +// int8, uint8, int16, uint16, int32, uint32, int64 and uint64 +func convertInt(value string, base int) (sql.Expression, error) { + if i8, err := strconv.ParseInt(value, base, 8); err == nil { + return expression.NewLiteral(int8(i8), sql.Int8), nil + } + if ui8, err := strconv.ParseUint(value, base, 8); err == nil { + return expression.NewLiteral(uint8(ui8), sql.Uint8), nil + } + if i16, err := strconv.ParseInt(value, base, 16); err == nil { + return expression.NewLiteral(int16(i16), sql.Int16), nil + } + if ui16, err := strconv.ParseUint(value, base, 16); err == nil { + return expression.NewLiteral(uint16(ui16), sql.Uint16), nil + } + if i32, err := strconv.ParseInt(value, base, 32); err == nil { + return expression.NewLiteral(int32(i32), sql.Int32), nil + } + if ui32, err := strconv.ParseUint(value, base, 32); err == nil { + return expression.NewLiteral(uint32(ui32), sql.Uint32), nil + } + if i64, err := strconv.ParseInt(value, base, 64); err == nil { + return expression.NewLiteral(int64(i64), sql.Int64), nil + } + + ui64, err := strconv.ParseUint(value, base, 64) + if err != nil { + return nil, err + } + + return expression.NewLiteral(uint64(ui64), sql.Uint64), nil +} + func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { switch v.Type { case sqlparser.StrVal: return expression.NewLiteral(string(v.Val), sql.Text), nil case sqlparser.IntVal: - //TODO: Use smallest integer representation and widen later. 
- val, err := strconv.ParseInt(string(v.Val), 10, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(string(v.Val), 10) case sqlparser.FloatVal: val, err := strconv.ParseFloat(string(v.Val), 64) if err != nil { @@ -608,11 +1116,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { v = strings.Trim(v[1:], "'") } - val, err := strconv.ParseInt(v, 16, 64) - if err != nil { - return nil, err - } - return expression.NewLiteral(val, sql.Int64), nil + return convertInt(v, 16) case sqlparser.HexVal: val, err := v.HexDecode() if err != nil { @@ -628,8 +1132,8 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { return nil, ErrInvalidSQLValType.New(v.Type) } -func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { - e, err := exprToExpression(c.Expr) +func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, error) { + e, err := exprToExpression(ctx, c.Expr) if err != nil { return nil, err } @@ -639,18 +1143,26 @@ func isExprToExpression(c *sqlparser.IsExpr) (sql.Expression, error) { return expression.NewIsNull(e), nil case sqlparser.IsNotNullStr: return expression.NewNot(expression.NewIsNull(e)), nil + case sqlparser.IsTrueStr: + return expression.NewIsTrue(e), nil + case sqlparser.IsFalseStr: + return expression.NewIsFalse(e), nil + case sqlparser.IsNotTrueStr: + return expression.NewNot(expression.NewIsTrue(e)), nil + case sqlparser.IsNotFalseStr: + return expression.NewNot(expression.NewIsFalse(e)), nil default: return nil, ErrUnsupportedSyntax.New(c) } } -func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, error) { - left, err := exprToExpression(c.Left) +func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) (sql.Expression, error) { + left, err := exprToExpression(ctx, c.Left) if err != nil { return nil, err } - right, err := exprToExpression(c.Right) + right, err := exprToExpression(ctx, 
c.Right) if err != nil { return nil, err } @@ -680,13 +1192,17 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression, er return expression.NewIn(left, right), nil case sqlparser.NotInStr: return expression.NewNotIn(left, right), nil + case sqlparser.LikeStr: + return expression.NewLike(left, right), nil + case sqlparser.NotLikeStr: + return expression.NewNot(expression.NewLike(left, right)), nil } } -func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { +func groupByToExpressions(ctx *sql.Context, g sqlparser.GroupBy) ([]sql.Expression, error) { es := make([]sql.Expression, len(g)) for i, ve := range g { - e, err := exprToExpression(ve) + e, err := exprToExpression(ctx, ve) if err != nil { return nil, err } @@ -697,7 +1213,7 @@ func groupByToExpressions(g sqlparser.GroupBy) ([]sql.Expression, error) { return es, nil } -func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { +func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expression, error) { switch e := se.(type) { default: return nil, ErrUnsupportedSyntax.New(e) @@ -707,7 +1223,7 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } return expression.NewQualifiedStar(e.TableName.Name.String()), nil case *sqlparser.AliasedExpr: - expr, err := exprToExpression(e.Expr) + expr, err := exprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -721,7 +1237,22 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } } -func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { +func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expression, error) { + switch e.Operator { + case sqlparser.MinusStr: + expr, err := exprToExpression(ctx, e.Expr) + if err != nil { + return nil, err + } + + return expression.NewUnaryMinus(expr), nil + + default: + return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator) + } 
+} + +func binaryExprToExpression(ctx *sql.Context, be *sqlparser.BinaryExpr) (sql.Expression, error) { switch be.Operator { case sqlparser.PlusStr, @@ -736,19 +1267,327 @@ func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { sqlparser.IntDivStr, sqlparser.ModStr: - l, err := exprToExpression(be.Left) + l, err := exprToExpression(ctx, be.Left) if err != nil { return nil, err } - r, err := exprToExpression(be.Right) + r, err := exprToExpression(ctx, be.Right) if err != nil { return nil, err } + _, lok := l.(*expression.Interval) + _, rok := r.(*expression.Interval) + if lok && be.Operator == "-" { + return nil, ErrUnsupportedSyntax.New("subtracting from an interval") + } else if (lok || rok) && be.Operator != "+" && be.Operator != "-" { + return nil, ErrUnsupportedSyntax.New("only + and - can be used to add of subtract intervals from dates") + } else if lok && rok { + return nil, ErrUnsupportedSyntax.New("intervals cannot be added or subtracted from other intervals") + } + return expression.NewArithmetic(l, r, be.Operator), nil default: return nil, ErrUnsupportedFeature.New(be.Operator) } } + +func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expression, error) { + var expr sql.Expression + var err error + + if e.Expr != nil { + expr, err = exprToExpression(ctx, e.Expr) + if err != nil { + return nil, err + } + } + + var branches []expression.CaseBranch + for _, w := range e.Whens { + var cond sql.Expression + cond, err = exprToExpression(ctx, w.Cond) + if err != nil { + return nil, err + } + + var val sql.Expression + val, err = exprToExpression(ctx, w.Val) + if err != nil { + return nil, err + } + + branches = append(branches, expression.CaseBranch{ + Cond: cond, + Value: val, + }) + } + + var elseExpr sql.Expression + if e.Else != nil { + elseExpr, err = exprToExpression(ctx, e.Else) + if err != nil { + return nil, err + } + } + + return expression.NewCase(expr, branches, elseExpr), nil +} + +func 
intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.Expression, error) { + expr, err := exprToExpression(ctx, e.Expr) + if err != nil { + return nil, err + } + + return expression.NewInterval(expr, e.Unit), nil +} + +func updateExprsToExpressions(ctx *sql.Context, e sqlparser.UpdateExprs) ([]sql.Expression, error) { + res := make([]sql.Expression, len(e)) + for i, updateExpr := range e { + colName, err := exprToExpression(ctx, updateExpr.Name) + if err != nil { + return nil, err + } + innerExpr, err := exprToExpression(ctx, updateExpr.Expr) + if err != nil { + return nil, err + } + res[i] = expression.NewSetField(colName, innerExpr) + } + return res, nil +} + +func removeComments(s string) string { + r := bufio.NewReader(strings.NewReader(s)) + var result []rune + for { + ru, _, err := r.ReadRune() + if err == io.EOF { + break + } + if err != nil { + continue + } + switch ru { + case '\'', '"': + result = append(result, ru) + result = append(result, readString(r, ru == '\'')...) 
+ case '-': + peeked, err := r.Peek(2) + if err == nil && + len(peeked) == 2 && + rune(peeked[0]) == '-' && + rune(peeked[1]) == ' ' { + discardUntilEOL(r) + } else { + result = append(result, ru) + } + case '/': + peeked, err := r.Peek(1) + if err == nil && + len(peeked) == 1 && + rune(peeked[0]) == '*' { + // read the char we peeked + _, _, _ = r.ReadRune() + discardMultilineComment(r) + } else { + result = append(result, ru) + } + default: + result = append(result, ru) + } + } + return string(result) +} +func discardUntilEOL(r *bufio.Reader) { + for { + ru, _, err := r.ReadRune() + if err == io.EOF { + break + } + if err != nil { + continue + } + if ru == '\n' { + break + } + } +} +func discardMultilineComment(r *bufio.Reader) { + for { + ru, _, err := r.ReadRune() + if err == io.EOF { + break + } + if err != nil { + continue + } + if ru == '*' { + peeked, err := r.Peek(1) + if err == nil && len(peeked) == 1 && rune(peeked[0]) == '/' { + // read the rune we just peeked + _, _, _ = r.ReadRune() + break + } + } + } +} +func readString(r *bufio.Reader, single bool) []rune { + var result []rune + var escaped bool + for { + ru, _, err := r.ReadRune() + if err == io.EOF { + break + } + if err != nil { + continue + } + result = append(result, ru) + if (!single && ru == '"' && !escaped) || + (single && ru == '\'' && !escaped) { + break + } + escaped = false + if ru == '\\' { + escaped = true + } + } + return result +} + +func parseShowTableStatus(ctx *sql.Context, query string) (sql.Node, error) { + buf := bufio.NewReader(strings.NewReader(query)) + err := parseFuncs{ + expect("show"), + skipSpaces, + expect("table"), + skipSpaces, + expect("status"), + skipSpaces, + }.exec(buf) + + if err != nil { + return nil, err + } + + if _, err = buf.Peek(1); err == io.EOF { + return plan.NewShowTableStatus(), nil + } + + var clause string + if err := readIdent(&clause)(buf); err != nil { + return nil, err + } + + if err := skipSpaces(buf); err != nil { + return nil, err + } + + 
switch strings.ToUpper(clause) { + case "FROM", "IN": + var db string + if err := readQuotableIdent(&db)(buf); err != nil { + return nil, err + } + + return plan.NewShowTableStatus(db), nil + case "WHERE", "LIKE": + bs, err := ioutil.ReadAll(buf) + if err != nil { + return nil, err + } + + expr, err := parseExpr(ctx, string(bs)) + if err != nil { + return nil, err + } + + var filter sql.Expression + if strings.ToUpper(clause) == "LIKE" { + filter = expression.NewLike( + expression.NewUnresolvedColumn("Name"), + expr, + ) + } else { + filter = expr + } + + return plan.NewFilter( + filter, + plan.NewShowTableStatus(), + ), nil + default: + return nil, errUnexpectedSyntax.New("one of: FROM, IN, LIKE or WHERE", clause) + } +} + +func parseShowCollation(ctx *sql.Context, query string) (sql.Node, error) { + buf := bufio.NewReader(strings.NewReader(query)) + err := parseFuncs{ + expect("show"), + skipSpaces, + expect("collation"), + skipSpaces, + }.exec(buf) + + if err != nil { + return nil, err + } + + if _, err = buf.Peek(1); err == io.EOF { + return plan.NewShowCollation(), nil + } + + var clause string + if err := readIdent(&clause)(buf); err != nil { + return nil, err + } + + if err := skipSpaces(buf); err != nil { + return nil, err + } + + switch strings.ToUpper(clause) { + case "WHERE", "LIKE": + bs, err := ioutil.ReadAll(buf) + if err != nil { + return nil, err + } + + expr, err := parseExpr(ctx, string(bs)) + if err != nil { + return nil, err + } + + var filter sql.Expression + if strings.ToUpper(clause) == "LIKE" { + filter = expression.NewLike( + expression.NewUnresolvedColumn("collation"), + expr, + ) + } else { + filter = expr + } + + return plan.NewFilter( + filter, + plan.NewShowCollation(), + ), nil + default: + return nil, errUnexpectedSyntax.New("one of: LIKE or WHERE", clause) + } +} + +var fixSessionRegex = regexp.MustCompile(`(,\s*|(set|SET)\s+)(SESSION|session)\s+([a-zA-Z0-9_]+)\s*=`) +var fixGlobalRegex = 
regexp.MustCompile(`(,\s*|(set|SET)\s+)(GLOBAL|global)\s+([a-zA-Z0-9_]+)\s*=`) + +func fixSetQuery(s string) string { + s = fixSessionRegex.ReplaceAllString(s, `$1@@session.$4 =`) + s = fixGlobalRegex.ReplaceAllString(s, `$1@@global.$4 =`) + return s +} diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 960e7857c..b66dca943 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1,17 +1,21 @@ package parse import ( + "math" "testing" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" + + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) var fixtures = map[string]sql.Node{ - `CREATE TABLE t1(a INTEGER, b TEXT, c DATE, d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL)`: plan.NewCreateTable( - &sql.UnresolvedDatabase{}, + `CREATE TABLE t1(a INTEGER, b TEXT, c DATE, d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL, g DATETIME, h CHAR(40))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), "t1", sql.Schema{{ Name: "a", @@ -37,24 +41,96 @@ var fixtures = map[string]sql.Node{ Name: "f", Type: sql.Blob, Nullable: false, + }, { + Name: "g", + Type: sql.Datetime, + Nullable: true, + }, { + Name: "h", + Type: sql.Text, + Nullable: true, + }}, + ), + `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + 
}, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + ), + `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: true, }}, ), + `DROP TABLE foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), false, "foo", + ), + `DROP TABLE IF EXISTS foo;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", + ), + `DROP TABLE IF EXISTS foo, bar, baz;`: plan.NewDropTable( + sql.UnresolvedDatabase(""), true, "foo", "bar", "baz", + ), `DESCRIBE TABLE foo;`: plan.NewDescribe( - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), + ), + `DESC TABLE foo;`: plan.NewDescribe( + plan.NewUnresolvedTable("foo", ""), ), `SELECT foo, bar FROM foo;`: plan.NewProject( []sql.Expression{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT foo IS NULL, bar IS NOT NULL FROM foo;`: plan.NewProject( []sql.Expression{ expression.NewIsNull(expression.NewUnresolvedColumn("foo")), expression.NewNot(expression.NewIsNull(expression.NewUnresolvedColumn("bar"))), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), + ), + `SELECT foo IS TRUE, bar IS NOT FALSE FROM foo;`: plan.NewProject( + []sql.Expression{ + expression.NewIsTrue(expression.NewUnresolvedColumn("foo")), + expression.NewNot(expression.NewIsFalse(expression.NewUnresolvedColumn("bar"))), + }, + plan.NewUnresolvedTable("foo", ""), ), `SELECT foo AS bar FROM foo;`: plan.NewProject( []sql.Expression{ @@ -63,7 +139,7 @@ var fixtures = map[string]sql.Node{ "bar", ), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT foo, bar FROM foo WHERE foo = bar;`: plan.NewProject( []sql.Expression{ @@ -75,7 +151,7 @@ var 
fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT foo, bar FROM foo WHERE foo = 'bar';`: plan.NewProject( @@ -88,7 +164,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), expression.NewLiteral("bar", sql.Text), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo WHERE foo != 'bar';`: plan.NewProject( @@ -100,7 +176,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), expression.NewLiteral("bar", sql.Text), )), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT foo, bar FROM foo LIMIT 10;`: plan.NewLimit(10, @@ -109,17 +185,17 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), - `SELECT foo, bar FROM foo ORDER BY baz DESC;`: plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, - plan.NewUnresolvedTable("foo"), + `SELECT foo, bar FROM foo ORDER BY baz DESC;`: plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT foo, bar FROM foo WHERE foo = bar LIMIT 10;`: plan.NewLimit(10, @@ -133,36 +209,36 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", 
""), ), ), ), `SELECT foo, bar FROM foo ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(1, - plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, - plan.NewUnresolvedTable("foo"), + plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + plan.NewUnresolvedTable("foo", ""), ), ), ), `SELECT foo, bar FROM foo WHERE qux = 1 ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(1, - plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("qux"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), ), @@ -173,8 +249,8 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("bar"), }, plan.NewCrossJoin( - plan.NewUnresolvedTable("t1"), - plan.NewUnresolvedTable("t2"), + plan.NewUnresolvedTable("t1", ""), + plan.NewUnresolvedTable("t2", ""), ), ), `SELECT foo, bar FROM t1 GROUP BY foo, bar;`: plan.NewGroupBy( @@ -186,7 +262,18 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("foo"), 
expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("t1"), + plan.NewUnresolvedTable("t1", ""), + ), + `SELECT foo, bar FROM t1 GROUP BY 1, 2;`: plan.NewGroupBy( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + plan.NewUnresolvedTable("t1", ""), ), `SELECT COUNT(*) FROM t1;`: plan.NewGroupBy( []sql.Expression{ @@ -194,7 +281,7 @@ var fixtures = map[string]sql.Node{ expression.NewStar()), }, []sql.Expression{}, - plan.NewUnresolvedTable("t1"), + plan.NewUnresolvedTable("t1", ""), ), `SELECT a FROM t1 where a regexp '.*test.*';`: plan.NewProject( []sql.Expression{ @@ -205,7 +292,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("a"), expression.NewLiteral(".*test.*", sql.Text), ), - plan.NewUnresolvedTable("t1"), + plan.NewUnresolvedTable("t1", ""), ), ), `SELECT a FROM t1 where a not regexp '.*test.*';`: plan.NewProject( @@ -219,40 +306,106 @@ var fixtures = map[string]sql.Node{ expression.NewLiteral(".*test.*", sql.Text), ), ), - plan.NewUnresolvedTable("t1"), + plan.NewUnresolvedTable("t1", ""), ), ), `INSERT INTO t1 (col1, col2) VALUES ('a', 1)`: plan.NewInsertInto( - plan.NewUnresolvedTable("t1"), + plan.NewUnresolvedTable("t1", ""), + plan.NewValues([][]sql.Expression{{ + expression.NewLiteral("a", sql.Text), + expression.NewLiteral(int8(1), sql.Int8), + }}), + false, + []string{"col1", "col2"}, + ), + `REPLACE INTO t1 (col1, col2) VALUES ('a', 1)`: plan.NewInsertInto( + plan.NewUnresolvedTable("t1", ""), plan.NewValues([][]sql.Expression{{ expression.NewLiteral("a", sql.Text), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), }}), + true, []string{"col1", "col2"}, ), - `SHOW TABLES`: plan.NewShowTables(&sql.UnresolvedDatabase{}), + `SHOW TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), false), + `SHOW FULL 
TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), true), + `SHOW TABLES FROM foo`: plan.NewShowTables(sql.UnresolvedDatabase("foo"), false), + `SHOW TABLES IN foo`: plan.NewShowTables(sql.UnresolvedDatabase("foo"), false), + `SHOW FULL TABLES FROM foo`: plan.NewShowTables(sql.UnresolvedDatabase("foo"), true), + `SHOW FULL TABLES IN foo`: plan.NewShowTables(sql.UnresolvedDatabase("foo"), true), + `SHOW TABLES LIKE 'foo'`: plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase(""), false), + ), + "SHOW TABLES WHERE `Table` = 'foo'": plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase(""), false), + ), + `SHOW FULL TABLES LIKE 'foo'`: plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase(""), true), + ), + "SHOW FULL TABLES WHERE `Table` = 'foo'": plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase(""), true), + ), + `SHOW FULL TABLES FROM bar LIKE 'foo'`: plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase("bar"), true), + ), + "SHOW FULL TABLES FROM bar WHERE `Table` = 'foo'": plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("Table"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTables(sql.UnresolvedDatabase("bar"), true), + ), `SELECT DISTINCT foo, bar FROM foo;`: plan.NewDistinct( plan.NewProject( []sql.Expression{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("foo"), + 
plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo`: plan.NewProject( []sql.Expression{ expression.NewStar(), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), + ), + `SELECT foo, bar FROM foo LIMIT 2 OFFSET 5;`: plan.NewLimit(2, + plan.NewOffset(5, plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, + plan.NewUnresolvedTable("foo", ""), + )), ), - `SELECT foo, bar FROM foo LIMIT 2 OFFSET 5;`: plan.NewOffset(5, - plan.NewLimit(2, plan.NewProject( + `SELECT foo, bar FROM foo LIMIT 5,2;`: plan.NewLimit(2, + plan.NewOffset(5, plan.NewProject( []sql.Expression{ expression.NewUnresolvedColumn("foo"), expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), )), ), `SELECT * FROM foo WHERE (a = 1)`: plan.NewProject( @@ -262,9 +415,9 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo, bar, baz, qux`: plan.NewProject( @@ -272,12 +425,12 @@ var fixtures = map[string]sql.Node{ plan.NewCrossJoin( plan.NewCrossJoin( plan.NewCrossJoin( - plan.NewUnresolvedTable("foo"), - plan.NewUnresolvedTable("bar"), + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), ), - plan.NewUnresolvedTable("baz"), + plan.NewUnresolvedTable("baz", ""), ), - plan.NewUnresolvedTable("qux"), + plan.NewUnresolvedTable("qux", ""), ), ), `SELECT * FROM foo WHERE a = b AND c = d`: plan.NewProject( @@ -293,7 +446,7 @@ var fixtures = map[string]sql.Node{ expression.NewUnresolvedColumn("d"), ), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo WHERE a = b OR c = d`: plan.NewProject( @@ -309,14 +462,14 @@ var fixtures = 
map[string]sql.Node{ expression.NewUnresolvedColumn("d"), ), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo as bar`: plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewTableAlias( "bar", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM (SELECT * FROM foo) AS bar`: plan.NewProject( @@ -325,7 +478,7 @@ var fixtures = map[string]sql.Node{ "bar", plan.NewProject( []sql.Expression{expression.NewStar()}, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), ), @@ -334,36 +487,36 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewNot( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo WHERE 1 BETWEEN 2 AND 5`: plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewBetween( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), - expression.NewLiteral(int64(5), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), + expression.NewLiteral(int8(5), sql.Int8), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT 0x01AF`: plan.NewProject( []sql.Expression{ - expression.NewLiteral(int64(431), sql.Int64), + expression.NewLiteral(int16(431), sql.Int16), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT X'41'`: plan.NewProject( []sql.Expression{ expression.NewLiteral([]byte{'A'}, sql.Blob), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT * FROM b WHERE SOMEFUNC((1, 
2), (3, 4))`: plan.NewProject( []sql.Expression{expression.NewStar()}, @@ -372,15 +525,15 @@ var fixtures = map[string]sql.Node{ "somefunc", false, expression.NewTuple( - expression.NewLiteral(int64(1), sql.Int64), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(2), sql.Int8), ), expression.NewTuple( - expression.NewLiteral(int64(3), sql.Int64), - expression.NewLiteral(int64(4), sql.Int64), + expression.NewLiteral(int8(3), sql.Int8), + expression.NewLiteral(int8(4), sql.Int8), ), ), - plan.NewUnresolvedTable("b"), + plan.NewUnresolvedTable("b", ""), ), ), `SELECT * FROM foo WHERE :foo_id = 2`: plan.NewProject( @@ -388,16 +541,16 @@ var fixtures = map[string]sql.Node{ plan.NewFilter( expression.NewEquals( expression.NewLiteral(":foo_id", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo INNER JOIN bar ON a = b`: plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewInnerJoin( - plan.NewUnresolvedTable("foo"), - plan.NewUnresolvedTable("bar"), + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), expression.NewEquals( expression.NewUnresolvedColumn("a"), expression.NewUnresolvedColumn("b"), @@ -408,40 +561,40 @@ var fixtures = map[string]sql.Node{ []sql.Expression{ expression.NewUnresolvedQualifiedColumn("foo", "a"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT CAST(-3 AS UNSIGNED) FROM foo`: plan.NewProject( []sql.Expression{ - expression.NewConvert(expression.NewLiteral(int64(-3), sql.Int64), expression.ConvertToUnsigned), + expression.NewConvert(expression.NewLiteral(int8(-3), sql.Int8), expression.ConvertToUnsigned), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT 2 = 2 FROM foo`: plan.NewProject( []sql.Expression{ - 
expression.NewEquals(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(2), sql.Int64)), + expression.NewEquals(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(2), sql.Int8)), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT *, bar FROM foo`: plan.NewProject( []sql.Expression{ expression.NewStar(), expression.NewUnresolvedColumn("bar"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT *, foo.* FROM foo`: plan.NewProject( []sql.Expression{ expression.NewStar(), expression.NewQualifiedStar("foo"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT bar, foo.* FROM foo`: plan.NewProject( []sql.Expression{ expression.NewUnresolvedColumn("bar"), expression.NewQualifiedStar("foo"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT bar, *, foo.* FROM foo`: plan.NewProject( []sql.Expression{ @@ -449,91 +602,91 @@ var fixtures = map[string]sql.Node{ expression.NewStar(), expression.NewQualifiedStar("foo"), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT *, * FROM foo`: plan.NewProject( []sql.Expression{ expression.NewStar(), expression.NewStar(), }, - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), `SELECT * FROM foo WHERE 1 IN ('1', 2)`: plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewIn( - expression.NewLiteral(int64(1), sql.Int64), + expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), `SELECT * FROM foo WHERE 1 NOT IN ('1', 2)`: plan.NewProject( []sql.Expression{expression.NewStar()}, plan.NewFilter( expression.NewNotIn( - expression.NewLiteral(int64(1), sql.Int64), + 
expression.NewLiteral(int8(1), sql.Int8), expression.NewTuple( expression.NewLiteral("1", sql.Text), - expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral(int8(2), sql.Int8), ), ), - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), ), ), - `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("a"), - expression.NewUnresolvedColumn("b"), + `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewSort( + []plan.SortField{ + { + Column: expression.NewLiteral(int8(2), sql.Int8), + Order: plan.Ascending, + NullOrdering: plan.NullsFirst, + }, + { + Column: expression.NewLiteral(int8(1), sql.Int8), + Order: plan.Ascending, + NullOrdering: plan.NullsFirst, + }, }, - plan.NewSort( - []plan.SortField{ - { - Column: expression.NewLiteral(int64(2), sql.Int64), - Order: plan.Ascending, - NullOrdering: plan.NullsFirst, - }, - { - Column: expression.NewLiteral(int64(1), sql.Int64), - Order: plan.Ascending, - NullOrdering: plan.NullsFirst, - }, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + expression.NewUnresolvedColumn("b"), }, - plan.NewUnresolvedTable("t"), + plan.NewUnresolvedTable("t", ""), ), ), `SELECT 1 + 1;`: plan.NewProject( []sql.Expression{ - expression.NewPlus(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewPlus(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT 1 * (2 + 1);`: plan.NewProject( []sql.Expression{ - expression.NewMult(expression.NewLiteral(int64(1), sql.Int64), - expression.NewPlus(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewMult(expression.NewLiteral(int8(1), sql.Int8), + expression.NewPlus(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, - 
plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT (0 - 1) * (1 | 1);`: plan.NewProject( []sql.Expression{ expression.NewMult( - expression.NewMinus(expression.NewLiteral(int64(0), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), - expression.NewBitOr(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(1), sql.Int64)), + expression.NewMinus(expression.NewLiteral(int8(0), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), + expression.NewBitOr(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(1), sql.Int8)), ), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT (1 << 3) % (2 div 1);`: plan.NewProject( []sql.Expression{ expression.NewMod( - expression.NewShiftLeft(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(3), sql.Int64)), - expression.NewIntDiv(expression.NewLiteral(int64(2), sql.Int64), expression.NewLiteral(int64(1), sql.Int64))), + expression.NewShiftLeft(expression.NewLiteral(int8(1), sql.Int8), expression.NewLiteral(int8(3), sql.Int8)), + expression.NewIntDiv(expression.NewLiteral(int8(2), sql.Int8), expression.NewLiteral(int8(1), sql.Int8))), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT 1.0 * a + 2.0 * b FROM t;`: plan.NewProject( []sql.Expression{ @@ -542,15 +695,15 @@ var fixtures = map[string]sql.Node{ expression.NewMult(expression.NewLiteral(float64(2.0), sql.Float64), expression.NewUnresolvedColumn("b")), ), }, - plan.NewUnresolvedTable("t"), + plan.NewUnresolvedTable("t", ""), ), `SELECT '1.0' + 2;`: plan.NewProject( []sql.Expression{ expression.NewPlus( - expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int64(2), sql.Int64), + expression.NewLiteral("1.0", sql.Text), expression.NewLiteral(int8(2), sql.Int8), ), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), ), `SELECT '1' + '2';`: plan.NewProject( []sql.Expression{ @@ 
-558,19 +711,558 @@ var fixtures = map[string]sql.Node{ expression.NewLiteral("1", sql.Text), expression.NewLiteral("2", sql.Text), ), }, - plan.NewUnresolvedTable("dual"), + plan.NewUnresolvedTable("dual", ""), + ), + `CREATE INDEX idx ON foo USING bar (fn(bar, baz))`: plan.NewCreateIndex( + "idx", + plan.NewUnresolvedTable("foo", ""), + []sql.Expression{expression.NewUnresolvedFunction( + "fn", false, + expression.NewUnresolvedColumn("bar"), + expression.NewUnresolvedColumn("baz"), + )}, + "bar", + make(map[string]string), ), - `CREATE INDEX idx ON foo(fn(bar, baz))`: plan.NewCreateIndex( + ` CREATE INDEX idx ON foo USING bar (fn(bar, baz))`: plan.NewCreateIndex( "idx", - plan.NewUnresolvedTable("foo"), + plan.NewUnresolvedTable("foo", ""), []sql.Expression{expression.NewUnresolvedFunction( "fn", false, expression.NewUnresolvedColumn("bar"), expression.NewUnresolvedColumn("baz"), )}, - "", + "bar", make(map[string]string), ), + `SELECT * FROM foo NATURAL JOIN bar`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewNaturalJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + ), + ), + `SELECT * FROM foo NATURAL JOIN bar NATURAL JOIN baz`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewNaturalJoin( + plan.NewNaturalJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + ), + plan.NewUnresolvedTable("baz", ""), + ), + ), + `DROP INDEX foo ON bar`: plan.NewDropIndex( + "foo", + plan.NewUnresolvedTable("bar", ""), + ), + `DESCRIBE FORMAT=TREE SELECT * FROM foo`: plan.NewDescribeQuery( + "tree", + plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewUnresolvedTable("foo", ""), + ), + ), + `SELECT MAX(i)/2 FROM foo`: plan.NewGroupBy( + []sql.Expression{ + expression.NewArithmetic( + expression.NewUnresolvedFunction( + "max", true, expression.NewUnresolvedColumn("i"), + ), + expression.NewLiteral(int8(2), sql.Int8), + "/", + ), + }, + []sql.Expression{}, + 
plan.NewUnresolvedTable("foo", ""), + ), + `SHOW INDEXES FROM foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `SHOW INDEX FROM foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `SHOW KEYS FROM foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `SHOW INDEXES IN foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `SHOW INDEX IN foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `SHOW KEYS IN foo`: plan.NewShowIndexes(sql.UnresolvedDatabase(""), "foo", nil), + `create index foo on bar using qux (baz)`: plan.NewCreateIndex( + "foo", + plan.NewUnresolvedTable("bar", ""), + []sql.Expression{expression.NewUnresolvedColumn("baz")}, + "qux", + make(map[string]string), + ), + `SHOW FULL PROCESSLIST`: plan.NewShowProcessList(), + `SHOW PROCESSLIST`: plan.NewShowProcessList(), + `SELECT @@allowed_max_packet`: plan.NewProject([]sql.Expression{ + expression.NewUnresolvedColumn("@@allowed_max_packet"), + }, plan.NewUnresolvedTable("dual", "")), + `SET autocommit=1, foo="bar"`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(int8(1), sql.Int8), + }, + plan.SetVariable{ + Name: "foo", + Value: expression.NewLiteral("bar", sql.Text), + }, + ), + `SET @@session.autocommit=1, foo="bar"`: plan.NewSet( + plan.SetVariable{ + Name: "@@session.autocommit", + Value: expression.NewLiteral(int8(1), sql.Int8), + }, + plan.SetVariable{ + Name: "foo", + Value: expression.NewLiteral("bar", sql.Text), + }, + ), + `SET autocommit=ON, on="1"`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(int64(1), sql.Int64), + }, + plan.SetVariable{ + Name: "on", + Value: expression.NewLiteral("1", sql.Text), + }, + ), + `SET @@session.autocommit=OFF, off="0"`: plan.NewSet( + plan.SetVariable{ + Name: "@@session.autocommit", + Value: expression.NewLiteral(int64(0), sql.Int64), + }, + plan.SetVariable{ + Name: "off", + Value: 
expression.NewLiteral("0", sql.Text), + }, + ), + `SET @@session.autocommit=ON`: plan.NewSet( + plan.SetVariable{ + Name: "@@session.autocommit", + Value: expression.NewLiteral(int64(1), sql.Int64), + }, + ), + `SET autocommit=off`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(int64(0), sql.Int64), + }, + ), + `SET autocommit=true`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(true, sql.Boolean), + }, + ), + `SET autocommit="true"`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(true, sql.Boolean), + }, + ), + `SET autocommit=false`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(false, sql.Boolean), + }, + ), + `SET autocommit="false"`: plan.NewSet( + plan.SetVariable{ + Name: "autocommit", + Value: expression.NewLiteral(false, sql.Boolean), + }, + ), + `SET SESSION NET_READ_TIMEOUT= 700, SESSION NET_WRITE_TIMEOUT= 700`: plan.NewSet( + plan.SetVariable{ + Name: "@@session.net_read_timeout", + Value: expression.NewLiteral(int16(700), sql.Int16), + }, + plan.SetVariable{ + Name: "@@session.net_write_timeout", + Value: expression.NewLiteral(int16(700), sql.Int16), + }, + ), + `SET gtid_mode=DEFAULT`: plan.NewSet( + plan.SetVariable{ + Name: "gtid_mode", + Value: expression.NewDefaultColumn(""), + }, + ), + `SET @@sql_select_limit=default`: plan.NewSet( + plan.SetVariable{ + Name: "@@sql_select_limit", + Value: expression.NewDefaultColumn(""), + }, + ), + `/*!40101 SET NAMES utf8 */`: plan.Nothing, + `SELECT /*!40101 SET NAMES utf8 */ * FROM foo`: plan.NewProject( + []sql.Expression{ + expression.NewStar(), + }, + plan.NewUnresolvedTable("foo", ""), + ), + `SHOW DATABASES`: plan.NewShowDatabases(), + `SELECT * FROM foo WHERE i LIKE 'foo'`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral("foo", 
sql.Text), + ), + plan.NewUnresolvedTable("foo", ""), + ), + ), + `SELECT * FROM foo WHERE i NOT LIKE 'foo'`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewFilter( + expression.NewNot(expression.NewLike( + expression.NewUnresolvedColumn("i"), + expression.NewLiteral("foo", sql.Text), + )), + plan.NewUnresolvedTable("foo", ""), + ), + ), + `SHOW FIELDS FROM foo`: plan.NewShowColumns(false, plan.NewUnresolvedTable("foo", "")), + `SHOW FULL COLUMNS FROM foo`: plan.NewShowColumns(true, plan.NewUnresolvedTable("foo", "")), + `SHOW FIELDS FROM foo WHERE Field = 'bar'`: plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("Field"), + expression.NewLiteral("bar", sql.Text), + ), + plan.NewShowColumns(false, plan.NewUnresolvedTable("foo", "")), + ), + `SHOW FIELDS FROM foo LIKE 'bar'`: plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Field"), + expression.NewLiteral("bar", sql.Text), + ), + plan.NewShowColumns(false, plan.NewUnresolvedTable("foo", "")), + ), + `SHOW TABLE STATUS LIKE 'foo'`: plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("Name"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTableStatus(), + ), + `SHOW TABLE STATUS FROM foo`: plan.NewShowTableStatus("foo"), + `SHOW TABLE STATUS IN foo`: plan.NewShowTableStatus("foo"), + `SHOW TABLE STATUS`: plan.NewShowTableStatus(), + `SHOW TABLE STATUS WHERE Name = 'foo'`: plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("Name"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowTableStatus(), + ), + `USE foo`: plan.NewUse(sql.UnresolvedDatabase("foo")), + `DESCRIBE TABLE foo.bar`: plan.NewDescribe( + plan.NewUnresolvedTable("bar", "foo"), + ), + `DESC TABLE foo.bar`: plan.NewDescribe( + plan.NewUnresolvedTable("bar", "foo"), + ), + `SELECT * FROM foo.bar`: plan.NewProject( + []sql.Expression{ + expression.NewStar(), + }, + plan.NewUnresolvedTable("bar", "foo"), + ), + `SHOW 
VARIABLES`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), ""), + `SHOW GLOBAL VARIABLES`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), ""), + `SHOW SESSION VARIABLES`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), ""), + `SHOW VARIABLES LIKE 'gtid_mode'`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), "gtid_mode"), + `SHOW SESSION VARIABLES LIKE 'autocommit'`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), "autocommit"), + `UNLOCK TABLES`: plan.NewUnlockTables(), + `LOCK TABLES foo READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", "")}, + }), + `LOCK TABLES foo123 READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo123", "")}, + }), + `LOCK TABLES foo f READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", "")}, + }), + `LOCK TABLES foo AS f READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", "")}, + }), + `LOCK TABLES foo READ LOCAL`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", "")}, + }), + `LOCK TABLES foo WRITE`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", ""), Write: true}, + }), + `LOCK TABLES foo LOW_PRIORITY WRITE`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", ""), Write: true}, + }), + `LOCK TABLES foo WRITE, bar READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", ""), Write: true}, + {Table: plan.NewUnresolvedTable("bar", "")}, + }), + "LOCK TABLES `foo` WRITE, `bar` READ": plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", ""), Write: true}, + {Table: plan.NewUnresolvedTable("bar", "")}, + }), + `LOCK TABLES foo READ, bar WRITE, baz READ`: plan.NewLockTables([]*plan.TableLock{ + {Table: plan.NewUnresolvedTable("foo", "")}, + {Table: plan.NewUnresolvedTable("bar", ""), Write: true}, + {Table: plan.NewUnresolvedTable("baz", 
"")}, + }), + `SHOW CREATE DATABASE foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + `SHOW CREATE SCHEMA foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + `SHOW CREATE DATABASE IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), + `SHOW CREATE SCHEMA IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), + `SELECT -i FROM mytable`: plan.NewProject( + []sql.Expression{ + expression.NewUnaryMinus( + expression.NewUnresolvedColumn("i"), + ), + }, + plan.NewUnresolvedTable("mytable", ""), + ), + `SHOW WARNINGS`: plan.NewOffset(0, plan.ShowWarnings(sql.NewEmptyContext().Warnings())), + `SHOW WARNINGS LIMIT 10`: plan.NewLimit(10, plan.NewOffset(0, plan.ShowWarnings(sql.NewEmptyContext().Warnings()))), + `SHOW WARNINGS LIMIT 5,10`: plan.NewLimit(10, plan.NewOffset(5, plan.ShowWarnings(sql.NewEmptyContext().Warnings()))), + "SHOW CREATE DATABASE `foo`": plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + "SHOW CREATE SCHEMA `foo`": plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + "SHOW CREATE DATABASE IF NOT EXISTS `foo`": plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), + "SHOW CREATE SCHEMA IF NOT EXISTS `foo`": plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), + "SELECT CASE foo WHEN 1 THEN 'foo' WHEN 2 THEN 'bar' ELSE 'baz' END": plan.NewProject( + []sql.Expression{expression.NewCase( + expression.NewUnresolvedColumn("foo"), + []expression.CaseBranch{ + { + Cond: expression.NewLiteral(int8(1), sql.Int8), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewLiteral(int8(2), sql.Int8), + Value: expression.NewLiteral("bar", sql.Text), + }, + }, + expression.NewLiteral("baz", sql.Text), + )}, + plan.NewUnresolvedTable("dual", ""), + ), + "SELECT CASE foo WHEN 1 THEN 'foo' WHEN 2 THEN 'bar' END": plan.NewProject( + []sql.Expression{expression.NewCase( + 
expression.NewUnresolvedColumn("foo"), + []expression.CaseBranch{ + { + Cond: expression.NewLiteral(int8(1), sql.Int8), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewLiteral(int8(2), sql.Int8), + Value: expression.NewLiteral("bar", sql.Text), + }, + }, + nil, + )}, + plan.NewUnresolvedTable("dual", ""), + ), + "SELECT CASE WHEN foo = 1 THEN 'foo' WHEN foo = 2 THEN 'bar' ELSE 'baz' END": plan.NewProject( + []sql.Expression{expression.NewCase( + nil, + []expression.CaseBranch{ + { + Cond: expression.NewEquals( + expression.NewUnresolvedColumn("foo"), + expression.NewLiteral(int8(1), sql.Int8), + ), + Value: expression.NewLiteral("foo", sql.Text), + }, + { + Cond: expression.NewEquals( + expression.NewUnresolvedColumn("foo"), + expression.NewLiteral(int8(2), sql.Int8), + ), + Value: expression.NewLiteral("bar", sql.Text), + }, + }, + expression.NewLiteral("baz", sql.Text), + )}, + plan.NewUnresolvedTable("dual", ""), + ), + "SHOW COLLATION": plan.NewShowCollation(), + "SHOW COLLATION LIKE 'foo'": plan.NewFilter( + expression.NewLike( + expression.NewUnresolvedColumn("collation"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowCollation(), + ), + "SHOW COLLATION WHERE Charset = 'foo'": plan.NewFilter( + expression.NewEquals( + expression.NewUnresolvedColumn("charset"), + expression.NewLiteral("foo", sql.Text), + ), + plan.NewShowCollation(), + ), + `ROLLBACK`: plan.NewRollback(), + "SHOW CREATE TABLE `mytable`": plan.NewShowCreateTable("", nil, "mytable"), + "SHOW CREATE TABLE `mydb`.`mytable`": plan.NewShowCreateTable("mydb", nil, "mytable"), + "SHOW CREATE TABLE `my.table`": plan.NewShowCreateTable("", nil, "my.table"), + "SHOW CREATE TABLE `my.db`.`my.table`": plan.NewShowCreateTable("my.db", nil, "my.table"), + "SHOW CREATE TABLE `my``table`": plan.NewShowCreateTable("", nil, "my`table"), + "SHOW CREATE TABLE `my``db`.`my``table`": plan.NewShowCreateTable("my`db", nil, "my`table"), + "SHOW CREATE TABLE ````": 
plan.NewShowCreateTable("", nil, "`"), + "SHOW CREATE TABLE `.`": plan.NewShowCreateTable("", nil, "."), + `SELECT '2018-05-01' + INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int8(1), sql.Int8), + "DAY", + ), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT '2018-05-01' - INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int8(1), sql.Int8), + "DAY", + ), + "-", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT INTERVAL 1 DAY + '2018-05-01'`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewInterval( + expression.NewLiteral(int8(1), sql.Int8), + "DAY", + ), + expression.NewLiteral("2018-05-01", sql.Text), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT '2018-05-01' + INTERVAL 1 DAY + INTERVAL 1 DAY`: plan.NewProject( + []sql.Expression{expression.NewArithmetic( + expression.NewArithmetic( + expression.NewLiteral("2018-05-01", sql.Text), + expression.NewInterval( + expression.NewLiteral(int8(1), sql.Int8), + "DAY", + ), + "+", + ), + expression.NewInterval( + expression.NewLiteral(int8(1), sql.Int8), + "DAY", + ), + "+", + )}, + plan.NewUnresolvedTable("dual", ""), + ), + `SELECT COUNT(*) FROM foo GROUP BY a HAVING COUNT(*) > 5`: plan.NewHaving( + expression.NewGreaterThan( + expression.NewUnresolvedFunction("count", true, expression.NewStar()), + expression.NewLiteral(int8(5), sql.Int8), + ), + plan.NewGroupBy( + []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, + []sql.Expression{expression.NewUnresolvedColumn("a")}, + plan.NewUnresolvedTable("foo", ""), + ), + ), + `SELECT DISTINCT COUNT(*) FROM foo GROUP BY a HAVING COUNT(*) > 5`: plan.NewDistinct( + plan.NewHaving( + expression.NewGreaterThan( 
+ expression.NewUnresolvedFunction("count", true, expression.NewStar()), + expression.NewLiteral(int8(5), sql.Int8), + ), + plan.NewGroupBy( + []sql.Expression{expression.NewUnresolvedFunction("count", true, expression.NewStar())}, + []sql.Expression{expression.NewUnresolvedColumn("a")}, + plan.NewUnresolvedTable("foo", ""), + ), + ), + ), + `SELECT * FROM foo LEFT JOIN bar ON 1=1`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewLeftJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + expression.NewEquals( + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + ), + ), + ), + `SELECT * FROM foo LEFT OUTER JOIN bar ON 1=1`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewLeftJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + expression.NewEquals( + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + ), + ), + ), + `SELECT * FROM foo RIGHT JOIN bar ON 1=1`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewRightJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + expression.NewEquals( + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + ), + ), + ), + `SELECT * FROM foo RIGHT OUTER JOIN bar ON 1=1`: plan.NewProject( + []sql.Expression{expression.NewStar()}, + plan.NewRightJoin( + plan.NewUnresolvedTable("foo", ""), + plan.NewUnresolvedTable("bar", ""), + expression.NewEquals( + expression.NewLiteral(int8(1), sql.Int8), + expression.NewLiteral(int8(1), sql.Int8), + ), + ), + ), + `SELECT FIRST(i) FROM foo`: plan.NewGroupBy( + []sql.Expression{ + expression.NewUnresolvedFunction("first", true, expression.NewUnresolvedColumn("i")), + }, + []sql.Expression{}, + plan.NewUnresolvedTable("foo", ""), + ), + `SELECT LAST(i) FROM foo`: plan.NewGroupBy( + []sql.Expression{ + expression.NewUnresolvedFunction("last", 
true, expression.NewUnresolvedColumn("i")), + }, + []sql.Expression{}, + plan.NewUnresolvedTable("foo", ""), + ), + `SELECT COUNT(DISTINCT i) FROM foo`: plan.NewGroupBy( + []sql.Expression{ + aggregation.NewCountDistinct(expression.NewUnresolvedColumn("i")), + }, + []sql.Expression{}, + plan.NewUnresolvedTable("foo", ""), + ), + `SELECT -128, 127, 255, -32768, 32767, 65535, -2147483648, 2147483647, 4294967295, -9223372036854775808, 9223372036854775807, 18446744073709551615`: plan.NewProject( + []sql.Expression{ + expression.NewLiteral(int8(math.MinInt8), sql.Int8), + expression.NewLiteral(int8(math.MaxInt8), sql.Int8), + expression.NewLiteral(uint8(math.MaxUint8), sql.Uint8), + expression.NewLiteral(int16(math.MinInt16), sql.Int16), + expression.NewLiteral(int16(math.MaxInt16), sql.Int16), + expression.NewLiteral(uint16(math.MaxUint16), sql.Uint16), + expression.NewLiteral(int32(math.MinInt32), sql.Int32), + expression.NewLiteral(int32(math.MaxInt32), sql.Int32), + expression.NewLiteral(uint32(math.MaxUint32), sql.Uint32), + expression.NewLiteral(int64(math.MinInt64), sql.Int64), + expression.NewLiteral(int64(math.MaxInt64), sql.Int64), + expression.NewLiteral(uint64(math.MaxUint64), sql.Uint64), + }, + plan.NewUnresolvedTable("dual", ""), + ), } func TestParse(t *testing.T) { @@ -579,7 +1271,7 @@ func TestParse(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() p, err := Parse(ctx, query) - require.Nil(err, "error for query '%s'", query) + require.NoError(err) require.Exactly(expectedPlan, p, "plans do not match for query '%s'", query) }) @@ -587,8 +1279,24 @@ func TestParse(t *testing.T) { } } -var fixturesErrors = map[string]error{ - `SHOW METHEMONEY`: ErrUnsupportedFeature.New(`SHOW METHEMONEY`), +var fixturesErrors = map[string]*errors.Kind{ + `SHOW METHEMONEY`: ErrUnsupportedFeature, + `LOCK TABLES foo AS READ`: errUnexpectedSyntax, + `LOCK TABLES foo LOW_PRIORITY READ`: errUnexpectedSyntax, + `SELECT * FROM mytable LIMIT -100`: 
ErrUnsupportedSyntax, + `SELECT * FROM mytable LIMIT 100 OFFSET -1`: ErrUnsupportedSyntax, + `SELECT * FROM files + JOIN commit_files + JOIN refs + `: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY - '2018-05-01'`: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY * '2018-05-01'`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' * INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' / INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax, + `SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax, + `SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax, + `CREATE VIEW view1 AS SELECT x FROM t1 WHERE x>0`: ErrUnsupportedFeature, } func TestParseErrors(t *testing.T) { @@ -598,7 +1306,72 @@ func TestParseErrors(t *testing.T) { ctx := sql.NewEmptyContext() _, err := Parse(ctx, query) require.Error(err) - require.Equal(expectedError.Error(), err.Error()) + require.True(expectedError.Is(err)) + }) + } +} + +func TestRemoveComments(t *testing.T) { + testCases := []struct { + input string + output string + }{ + { + `/* FOO BAR BAZ */`, + ``, + }, + { + `SELECT 1 -- something`, + `SELECT 1 `, + }, + { + `SELECT 1 --something`, + `SELECT 1 --something`, + }, + { + `SELECT ' -- something'`, + `SELECT ' -- something'`, + }, + { + `SELECT /* FOO */ 1;`, + `SELECT 1;`, + }, + { + `SELECT '/* FOO */ 1';`, + `SELECT '/* FOO */ 1';`, + }, + { + `SELECT "\"/* FOO */ 1\"";`, + `SELECT "\"/* FOO */ 1\"";`, + }, + { + `SELECT '\'/* FOO */ 1\'';`, + `SELECT '\'/* FOO */ 1\'';`, + }, + } + for _, tt := range testCases { + t.Run(tt.input, func(t *testing.T) { + require.Equal( + t, + tt.output, + removeComments(tt.input), + ) + }) + } +} + +func TestFixSetQuery(t *testing.T) { + testCases := []struct { + in, out string + }{ + {"set session foo = 1, session bar = 2", "set @@session.foo = 1, @@session.bar = 2"}, + {"set global foo = 1, session bar = 2", "set @@global.foo = 1, @@session.bar = 2"}, + {"set SESSION foo 
= 1, GLOBAL bar = 2", "set @@session.foo = 1, @@global.bar = 2"}, + } + + for _, tt := range testCases { + t.Run(tt.in, func(t *testing.T) { + require.Equal(t, tt.out, fixSetQuery(tt.in)) }) } } diff --git a/sql/parse/show_create.go b/sql/parse/show_create.go new file mode 100644 index 000000000..028b6024f --- /dev/null +++ b/sql/parse/show_create.go @@ -0,0 +1,113 @@ +package parse + +import ( + "bufio" + "io" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" + "gopkg.in/src-d/go-errors.v1" +) + +var errUnsupportedShowCreateQuery = errors.NewKind("Unsupported query: SHOW CREATE %s") + +func parseShowCreate(s string) (sql.Node, error) { + r := bufio.NewReader(strings.NewReader(s)) + + var thingToShow string + err := parseFuncs{ + expect("show"), + skipSpaces, + expect("create"), + skipSpaces, + readIdent(&thingToShow), + skipSpaces, + }.exec(r) + if err != nil { + return nil, err + } + + switch strings.ToLower(thingToShow) { + case "table": + var db, table string + + if err := readQuotableIdent(&table)(r); err != nil { + return nil, err + } + + ru, _, err := r.ReadRune() + if err != nil && err != io.EOF { + return nil, err + } else if err == nil && ru == '.' { + db = table + + if err = readQuotableIdent(&table)(r); err != nil { + return nil, err + } + } + + err = parseFuncs{ + skipSpaces, + checkEOF, + }.exec(r) + if err != nil { + return nil, err + } + + return plan.NewShowCreateTable( + db, + nil, + table), nil + case "database", "schema": + var ifNotExists bool + var next string + + nextByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + // If ` is the next character, it's a db name. Otherwise it may be + // a table name or IF NOT EXISTS. 
+ if nextByte[0] == '`' { + if err = readQuotableIdent(&next)(r); err != nil { + return nil, err + } + } else { + if err = readIdent(&next)(r); err != nil { + return nil, err + } + + if next == "if" { + ifNotExists = true + err = parseFuncs{ + skipSpaces, + expect("not"), + skipSpaces, + expect("exists"), + skipSpaces, + readQuotableIdent(&next), + }.exec(r) + if err != nil { + return nil, err + } + } + } + + err = parseFuncs{ + skipSpaces, + checkEOF, + }.exec(r) + if err != nil { + return nil, err + } + + return plan.NewShowCreateDatabase( + sql.UnresolvedDatabase(next), + ifNotExists, + ), nil + default: + return nil, errUnsupportedShowCreateQuery.New(thingToShow) + } +} diff --git a/sql/parse/show_create_test.go b/sql/parse/show_create_test.go new file mode 100644 index 000000000..3831fc0fb --- /dev/null +++ b/sql/parse/show_create_test.go @@ -0,0 +1,74 @@ +package parse + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" +) + +func TestParseShowCreateTableQuery(t *testing.T) { + testCases := []struct { + query string + result sql.Node + err *errors.Kind + }{ + { + "SHOW CREATE", + nil, + errUnsupportedShowCreateQuery, + }, + { + "SHOW CREATE ANYTHING", + nil, + errUnsupportedShowCreateQuery, + }, + { + "SHOW CREATE ASDF foo", + nil, + errUnsupportedShowCreateQuery, + }, + { + "SHOW CREATE TABLE mytable", + plan.NewShowCreateTable("", nil, "mytable"), + nil, + }, + { + "SHOW CREATE TABLE `mytable`", + plan.NewShowCreateTable("", nil, "mytable"), + nil, + }, + { + "SHOW CREATE TABLE mydb.`mytable`", + plan.NewShowCreateTable("mydb", nil, "mytable"), + nil, + }, + { + "SHOW CREATE TABLE `mydb`.mytable", + plan.NewShowCreateTable("mydb", nil, "mytable"), + nil, + }, + { + "SHOW CREATE TABLE `mydb`.`mytable`", + plan.NewShowCreateTable("mydb", nil, "mytable"), + nil, + }, + } + + for _, tt := range testCases { + t.Run(tt.query, func(t 
*testing.T) { + require := require.New(t) + + result, err := parseShowCreate(tt.query) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.result, result) + } + }) + } +} diff --git a/sql/parse/util.go b/sql/parse/util.go new file mode 100644 index 000000000..bfb358b1d --- /dev/null +++ b/sql/parse/util.go @@ -0,0 +1,310 @@ +package parse + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "strings" + "unicode" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" + "vitess.io/vitess/go/vt/sqlparser" +) + +var ( + errUnexpectedSyntax = errors.NewKind("expecting %q but got %q instead") + errInvalidIndexExpression = errors.NewKind("invalid expression to index: %s") +) + +type parseFunc func(*bufio.Reader) error + +type parseFuncs []parseFunc + +func (f parseFuncs) exec(r *bufio.Reader) error { + for _, fn := range f { + if err := fn(r); err != nil { + return err + } + } + return nil +} + +func expectRune(expected rune) parseFunc { + return func(rd *bufio.Reader) error { + r, _, err := rd.ReadRune() + if err != nil { + return err + } + + if r != expected { + return errUnexpectedSyntax.New(expected, string(r)) + } + + return nil + } +} + +func expect(expected string) parseFunc { + return func(r *bufio.Reader) error { + var ident string + + if err := readIdent(&ident)(r); err != nil { + return err + } + + if ident == expected { + return nil + } + + return errUnexpectedSyntax.New(expected, ident) + } +} + +func skipSpaces(r *bufio.Reader) error { + for { + ru, _, err := r.ReadRune() + if err == io.EOF { + return nil + } + + if err != nil { + return err + } + + if !unicode.IsSpace(ru) { + return r.UnreadRune() + } + } +} + +func checkEOF(rd *bufio.Reader) error { + r, _, err := rd.ReadRune() + if err == io.EOF { + return nil + } + + return errUnexpectedSyntax.New("EOF", r) +} + +func optional(steps ...parseFunc) parseFunc { + return func(rd *bufio.Reader) error 
{ + for _, step := range steps { + err := step(rd) + if err == io.EOF || errUnexpectedSyntax.Is(err) { + return nil + } + + if err != nil { + return err + } + } + + return nil + } +} + +func readLetter(r *bufio.Reader, buf *bytes.Buffer) error { + ru, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + return nil + } + + return err + } + + if !unicode.IsLetter(ru) { + if err := r.UnreadRune(); err != nil { + return err + } + return nil + } + + buf.WriteRune(ru) + return nil +} + +func readValidIdentRune(r *bufio.Reader, buf *bytes.Buffer) error { + ru, _, err := r.ReadRune() + if err != nil { + return err + } + + if !unicode.IsLetter(ru) && !unicode.IsDigit(ru) && ru != '_' { + if err := r.UnreadRune(); err != nil { + return err + } + return io.EOF + } + + buf.WriteRune(ru) + return nil +} + +func readValidQuotedIdentRune(r *bufio.Reader, buf *bytes.Buffer) error { + bs, err := r.Peek(2) + if err != nil { + return err + } + + if bs[0] == '`' && bs[1] == '`' { + if _, _, err := r.ReadRune(); err != nil { + return err + } + if _, _, err := r.ReadRune(); err != nil { + return err + } + buf.WriteRune('`') + return nil + } + + if bs[0] == '`' && bs[1] != '`' { + return io.EOF + } + + if _, _, err := r.ReadRune(); err != nil { + return err + } + + buf.WriteByte(bs[0]) + + return nil +} + +func unreadString(r *bufio.Reader, str string) { + nr := *r + r.Reset(io.MultiReader(strings.NewReader(str), &nr)) +} + +func readIdent(ident *string) parseFunc { + return func(r *bufio.Reader) error { + var buf bytes.Buffer + if err := readLetter(r, &buf); err != nil { + return err + } + + for { + if err := readValidIdentRune(r, &buf); err == io.EOF { + break + } else if err != nil { + return err + } + } + + *ident = strings.ToLower(buf.String()) + return nil + } +} + +func readQuotedIdent(ident *string) parseFunc { + return func(r *bufio.Reader) error { + var buf bytes.Buffer + if err := readValidQuotedIdentRune(r, &buf); err != nil { + return err + } + + for { + if err := 
readValidQuotedIdentRune(r, &buf); err == io.EOF { + break + } else if err != nil { + return err + } + } + + *ident = strings.ToLower(buf.String()) + return nil + } +} + +func oneOf(options ...string) parseFunc { + return func(r *bufio.Reader) error { + var ident string + if err := readIdent(&ident)(r); err != nil { + return err + } + + for _, opt := range options { + if strings.ToLower(opt) == ident { + return nil + } + } + + return errUnexpectedSyntax.New( + fmt.Sprintf("one of: %s", strings.Join(options, ", ")), + ident, + ) + } +} + +func readRemaining(val *string) parseFunc { + return func(r *bufio.Reader) error { + bytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + *val = string(bytes) + return nil + } +} + +func parseExpr(ctx *sql.Context, str string) (sql.Expression, error) { + stmt, err := sqlparser.Parse("SELECT " + str) + if err != nil { + return nil, err + } + + selectStmt, ok := stmt.(*sqlparser.Select) + if !ok { + return nil, errInvalidIndexExpression.New(str) + } + + if len(selectStmt.SelectExprs) != 1 { + return nil, errInvalidIndexExpression.New(str) + } + + selectExpr, ok := selectStmt.SelectExprs[0].(*sqlparser.AliasedExpr) + if !ok { + return nil, errInvalidIndexExpression.New(str) + } + + return exprToExpression(ctx, selectExpr.Expr) +} + +func readQuotableIdent(ident *string) parseFunc { + return func(r *bufio.Reader) error { + nextChar, err := r.Peek(1) + if err != nil { + return err + } + + var steps parseFuncs + if nextChar[0] == '`' { + steps = parseFuncs{ + expectQuote, + readQuotedIdent(ident), + expectQuote, + } + } else { + steps = parseFuncs{readIdent(ident)} + } + + return steps.exec(r) + } +} + +func expectQuote(r *bufio.Reader) error { + ru, _, err := r.ReadRune() + if err != nil { + return err + } + + if ru != '`' { + return errUnexpectedSyntax.New("`", string(ru)) + } + + return nil +} diff --git a/sql/parse/variables.go b/sql/parse/variables.go new file mode 100644 index 000000000..df4844341 --- /dev/null 
+++ b/sql/parse/variables.go @@ -0,0 +1,58 @@ +package parse + +import ( + "bufio" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" +) + +func parseShowVariables(ctx *sql.Context, s string) (sql.Node, error) { + var pattern string + + r := bufio.NewReader(strings.NewReader(s)) + for _, fn := range []parseFunc{ + expect("show"), + skipSpaces, + func(in *bufio.Reader) error { + var s string + if err := readIdent(&s)(in); err != nil { + return err + } + + switch s { + case "global", "session": + if err := skipSpaces(in); err != nil { + return err + } + + return expect("variables")(in) + case "variables": + return nil + } + return errUnexpectedSyntax.New("show [global | session] variables", s) + }, + skipSpaces, + func(in *bufio.Reader) error { + if expect("like")(in) == nil { + if err := skipSpaces(in); err != nil { + return err + } + + if err := readValue(&pattern)(in); err != nil { + return err + } + } + return nil + }, + skipSpaces, + checkEOF, + } { + if err := fn(r); err != nil { + return nil, err + } + } + + return plan.NewShowVariables(ctx.Session.GetAll(), pattern), nil +} diff --git a/sql/parse/warnings.go b/sql/parse/warnings.go new file mode 100644 index 000000000..9c88a3fd8 --- /dev/null +++ b/sql/parse/warnings.go @@ -0,0 +1,77 @@ +package parse + +import ( + "bufio" + "strconv" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/plan" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errInvalidIndex = errors.NewKind("invalid %s index %d (index must be non-negative)") + +func parseShowWarnings(ctx *sql.Context, s string) (sql.Node, error) { + var ( + offstr string + cntstr string + ) + + r := bufio.NewReader(strings.NewReader(s)) + for _, fn := range []parseFunc{ + expect("show"), + skipSpaces, + expect("warnings"), + skipSpaces, + func(in *bufio.Reader) error { + if expect("limit")(in) == nil { + skipSpaces(in) + readValue(&cntstr)(in) + skipSpaces(in) + if 
expectRune(',')(in) == nil { + if readValue(&offstr)(in) == nil { + offstr, cntstr = cntstr, offstr + } + } + + } + return nil + }, + skipSpaces, + checkEOF, + } { + if err := fn(r); err != nil { + return nil, err + } + } + + var ( + node sql.Node = plan.ShowWarnings(ctx.Session.Warnings()) + offset int + count int + err error + ) + if offstr != "" { + if offset, err = strconv.Atoi(offstr); err != nil { + return nil, err + } + if offset < 0 { + return nil, errInvalidIndex.New("offset", offset) + } + } + node = plan.NewOffset(int64(offset), node) + if cntstr != "" { + if count, err = strconv.Atoi(cntstr); err != nil { + return nil, err + } + if count < 0 { + return nil, errInvalidIndex.New("count", count) + } + if count > 0 { + node = plan.NewLimit(int64(count), node) + } + } + + return node, nil +} diff --git a/sql/plan/common.go b/sql/plan/common.go index 193272781..28e0dd935 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -1,6 +1,16 @@ package plan -import "gopkg.in/src-d/go-mysql-server.v0/sql" +import "github.com/src-d/go-mysql-server/sql" + +// IsUnary returns whether the node is unary or not. +func IsUnary(node sql.Node) bool { + return len(node.Children()) == 1 +} + +// IsBinary returns whether the node is binary or not. +func IsBinary(node sql.Node) bool { + return len(node.Children()) == 2 +} // UnaryNode is a node that has only one children. 
type UnaryNode struct { @@ -47,20 +57,3 @@ func expressionsResolved(exprs ...sql.Expression) bool { return true } - -func transformExpressionsUp( - f sql.TransformExprFunc, - exprs []sql.Expression, -) ([]sql.Expression, error) { - - var es []sql.Expression - for _, e := range exprs { - te, err := e.TransformUp(f) - if err != nil { - return nil, err - } - es = append(es, te) - } - - return es, nil -} diff --git a/sql/plan/common_test.go b/sql/plan/common_test.go index b7ad07066..483a0324c 100644 --- a/sql/plan/common_test.go +++ b/sql/plan/common_test.go @@ -6,12 +6,12 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) -var benchtable = func() *mem.Table { +var benchtable = func() *memory.Table { schema := sql.Schema{ {Name: "strfield", Type: sql.Text, Nullable: true}, {Name: "floatfield", Type: sql.Float64, Nullable: true}, @@ -20,31 +20,37 @@ var benchtable = func() *mem.Table { {Name: "bigintfield", Type: sql.Int64, Nullable: false}, {Name: "blobfield", Type: sql.Blob, Nullable: false}, } - t := mem.NewTable("test", schema) + t := memory.NewTable("test", schema) for i := 0; i < 100; i++ { n := fmt.Sprint(i) - err := t.Insert(sql.NewRow( - repeatStr(n, i%10+1), - float64(i), - i%2 == 0, - int32(i), - int64(i), - []byte(repeatStr(n, 100+(i%100))), - )) - if err != nil { - panic(err) - } - - if i%2 == 0 { - err := t.Insert(sql.NewRow( + err := t.Insert( + sql.NewEmptyContext(), + sql.NewRow( repeatStr(n, i%10+1), float64(i), i%2 == 0, int32(i), int64(i), []byte(repeatStr(n, 100+(i%100))), - )) + ), + ) + if err != nil { + panic(err) + } + + if i%2 == 0 { + err := t.Insert( + sql.NewEmptyContext(), + sql.NewRow( + repeatStr(n, i%10+1), + float64(i), + i%2 == 0, + int32(i), + int64(i), + []byte(repeatStr(n, 100+(i%100))), + ), + ) if err != nil { panic(err) } @@ -100,3 +106,25 @@ func 
collectRows(t *testing.T, node sql.Node) []sql.Row { rows = append(rows, row) } } + +func TestIsUnary(t *testing.T) { + require := require.New(t) + table := memory.NewTable("foo", nil) + + require.True(IsUnary(NewFilter(nil, NewResolvedTable(table)))) + require.False(IsUnary(NewCrossJoin( + NewResolvedTable(table), + NewResolvedTable(table), + ))) +} + +func TestIsBinary(t *testing.T) { + require := require.New(t) + table := memory.NewTable("foo", nil) + + require.False(IsBinary(NewFilter(nil, NewResolvedTable(table)))) + require.True(IsBinary(NewCrossJoin( + NewResolvedTable(table), + NewResolvedTable(table), + ))) +} diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index 6bd43a374..00455a127 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -3,11 +3,14 @@ package plan import ( "fmt" "strings" + "time" + opentracing "github.com/opentracing/opentracing-go" + otlog "github.com/opentracing/opentracing-go/log" "github.com/sirupsen/logrus" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) var ( @@ -17,8 +20,9 @@ var ( // ErrInvalidIndexDriver is returned when the index driver can't be found. ErrInvalidIndexDriver = errors.NewKind("invalid driver index %q") - // ErrTableNotNameable is returned when the table name can't be obtained. - ErrTableNotNameable = errors.NewKind("can't get the name from the table") + // ErrExprTypeNotIndexable is returned when the expression type cannot be + // indexed, such as BLOB or JSON. + ErrExprTypeNotIndexable = errors.NewKind("expression %q with type %s cannot be indexed") ) // CreateIndex is a node to create an index. @@ -30,6 +34,7 @@ type CreateIndex struct { Config map[string]string Catalog *sql.Catalog CurrentDatabase string + Async bool } // NewCreateIndex creates a new CreateIndex node. 
@@ -40,12 +45,14 @@ func NewCreateIndex( driver string, config map[string]string, ) *CreateIndex { + async, ok := config["async"] return &CreateIndex{ Name: name, Table: table, Exprs: exprs, Driver: driver, Config: config, + Async: async != "false" || !ok, } } @@ -67,19 +74,47 @@ func (c *CreateIndex) Resolved() bool { return true } +func getIndexableTable(t sql.Table) (sql.IndexableTable, error) { + switch t := t.(type) { + case sql.IndexableTable: + return t, nil + case sql.TableWrapper: + return getIndexableTable(t.Underlying()) + default: + return nil, ErrNotIndexable.New() + } +} + +func getChecksumable(t sql.Table) sql.Checksumable { + switch t := t.(type) { + case sql.Checksumable: + return t + case sql.TableWrapper: + return getChecksumable(t.Underlying()) + default: + return nil + } +} + // RowIter implements the Node interface. func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { - table, ok := c.Table.(sql.Indexable) + table, ok := c.Table.(*ResolvedTable) if !ok { return nil, ErrNotIndexable.New() } - nameable, ok := c.Table.(sql.Nameable) - if !ok { - return nil, ErrTableNotNameable.New() + indexable, err := getIndexableTable(table.Table) + if err != nil { + return nil, err + } + + var driver sql.IndexDriver + if c.Driver == "" { + driver = c.Catalog.DefaultIndexDriver() + } else { + driver = c.Catalog.IndexDriver(c.Driver) } - driver := c.Catalog.IndexDriver(c.Driver) if driver == nil { return nil, ErrInvalidIndexDriver.New(c.Driver) } @@ -89,10 +124,22 @@ func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, err } + for _, e := range exprs { + if e.Type() == sql.Blob || e.Type() == sql.JSON { + return nil, ErrExprTypeNotIndexable.New(e, e.Type()) + } + } + + if ch := getChecksumable(table.Table); ch != nil { + c.Config[sql.ChecksumKey], err = ch.Checksum() + if err != nil { + return nil, err + } + } + index, err := driver.Create( - c.Catalog.IndexRegistry.Root, - nameable.Name(), c.CurrentDatabase, + 
table.Name(), c.Name, exprs, c.Config, @@ -101,33 +148,94 @@ func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, err } - iter, err := table.IndexKeyValueIter(columns) + iter, err := indexable.IndexKeyValues(ctx, columns) if err != nil { return nil, err } - done, err := c.Catalog.AddIndex(index) + iter = &evalPartitionKeyValueIter{ + ctx: ctx, + columns: columns, + exprs: exprs, + iter: iter, + } + + created, ready, err := c.Catalog.AddIndex(index) if err != nil { return nil, err } - go func() { - err := driver.Save(ctx, c.Catalog.IndexRegistry.Root, index, iter) - close(done) - if err != nil { - logrus.WithField("err", err).Error("unable to save the index") - deleted, err := c.Catalog.DeleteIndex(index.Database(), index.ID()) - if err != nil { - logrus.WithField("err", err).Error("unable to delete the index") - } else { - <-deleted - } - } - }() + log := logrus.WithFields(logrus.Fields{ + "id": index.ID(), + "driver": index.Driver(), + }) + + createIndex := func() { + c.createIndex(ctx, log, driver, index, iter, created, ready) + c.Catalog.ProcessList.Done(ctx.Pid()) + } + + log.WithField("async", c.Async).Info("starting to save the index") + + if c.Async { + go createIndex() + } else { + createIndex() + } return sql.RowsToRowIter(), nil } +func (c *CreateIndex) createIndex( + ctx *sql.Context, + log *logrus.Entry, + driver sql.IndexDriver, + index sql.Index, + iter sql.PartitionIndexKeyValueIter, + done chan<- struct{}, + ready <-chan struct{}, +) { + span, ctx := ctx.Span("plan.createIndex", + opentracing.Tags{ + "index": index.ID(), + "table": index.Table(), + "driver": index.Driver(), + }) + + l := log.WithField("id", index.ID()) + + err := driver.Save(ctx, index, newLoggingPartitionKeyValueIter(ctx, l, iter)) + close(done) + + if err != nil { + span.FinishWithOptions(opentracing.FinishOptions{ + LogRecords: []opentracing.LogRecord{ + { + Timestamp: time.Now(), + Fields: []otlog.Field{ + otlog.String("error", err.Error()), + }, + 
}, + }, + }) + + ctx.Error(0, "unable to save the index: %s", err) + logrus.WithField("err", err).Error("unable to save the index") + + deleted, err := c.Catalog.DeleteIndex(index.Database(), index.ID(), true) + if err != nil { + ctx.Error(0, "unable to delete index: %s", err) + logrus.WithField("err", err).Error("unable to delete the index") + } else { + <-deleted + } + } else { + <-ready + span.Finish() + log.Info("index successfully created") + } +} + // Schema implements the Node interface. func (c *CreateIndex) Schema() sql.Schema { return nil } @@ -147,39 +255,36 @@ func (c *CreateIndex) String() string { return pr.String() } -// TransformExpressionsUp implements the Node interface. -func (c *CreateIndex) TransformExpressionsUp(fn sql.TransformExprFunc) (sql.Node, error) { - table, err := c.Table.TransformExpressionsUp(fn) - if err != nil { - return nil, err - } +// Expressions implements the Expressioner interface. +func (c *CreateIndex) Expressions() []sql.Expression { + return c.Exprs +} - var exprs = make([]sql.Expression, len(c.Exprs)) - for i, e := range c.Exprs { - exprs[i], err = e.TransformUp(fn) - if err != nil { - return nil, err - } +// WithExpressions implements the Expressioner interface. +func (c *CreateIndex) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(c.Exprs) { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(exprs), len(c.Exprs)) } nc := *c - nc.Table = table nc.Exprs = exprs - return &nc, nil } -// TransformUp implements the Node interface. -func (c *CreateIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { - table, err := c.Table.TransformUp(fn) - if err != nil { - return nil, err +// WithChildren implements the Node interface. 
+func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } nc := *c - nc.Table = table + nc.Table = children[0] + return &nc, nil +} - return fn(&nc) +// IsAsync implements the AsyncNode interface. +func (c *CreateIndex) IsAsync() bool { + return c.Async } // getColumnsAndPrepareExpressions extracts the unique columns required by all @@ -193,16 +298,15 @@ func getColumnsAndPrepareExpressions( var expressions = make([]sql.Expression, len(exprs)) for i, e := range exprs { - var err error - expressions[i], err = e.TransformUp(func(e sql.Expression) (sql.Expression, error) { + ex, err := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) { gf, ok := e.(*expression.GetField) if !ok { return e, nil } var idx int - if i, ok := seen[gf.Name()]; ok { - idx = i + if j, ok := seen[gf.Name()]; ok { + idx = j } else { idx = len(columns) columns = append(columns, gf.Name()) @@ -221,7 +325,162 @@ func getColumnsAndPrepareExpressions( if err != nil { return nil, nil, err } + + expressions[i] = ex } return columns, expressions, nil } + +type evalPartitionKeyValueIter struct { + iter sql.PartitionIndexKeyValueIter + columns []string + exprs []sql.Expression + ctx *sql.Context +} + +func (i *evalPartitionKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { + p, iter, err := i.iter.Next() + if err != nil { + return nil, nil, err + } + + return p, &evalKeyValueIter{ + ctx: i.ctx, + columns: i.columns, + exprs: i.exprs, + iter: iter, + }, nil +} + +func (i *evalPartitionKeyValueIter) Close() error { + return i.iter.Close() +} + +type evalKeyValueIter struct { + ctx *sql.Context + iter sql.IndexKeyValueIter + columns []string + exprs []sql.Expression +} + +func (i *evalKeyValueIter) Next() ([]interface{}, []byte, error) { + vals, loc, err := i.iter.Next() + if err != nil { + return nil, nil, err + } + + row := sql.NewRow(vals...) 
+ evals := make([]interface{}, len(i.exprs)) + for j, ex := range i.exprs { + eval, err := ex.Eval(i.ctx, row) + if err != nil { + return nil, nil, err + } + + evals[j] = eval + } + + return evals, loc, nil +} + +func (i *evalKeyValueIter) Close() error { + return i.iter.Close() +} + +type loggingPartitionKeyValueIter struct { + ctx *sql.Context + log *logrus.Entry + iter sql.PartitionIndexKeyValueIter + rows uint64 +} + +func newLoggingPartitionKeyValueIter( + ctx *sql.Context, + log *logrus.Entry, + iter sql.PartitionIndexKeyValueIter, +) *loggingPartitionKeyValueIter { + return &loggingPartitionKeyValueIter{ + ctx: ctx, + log: log, + iter: iter, + } +} + +func (i *loggingPartitionKeyValueIter) Next() (sql.Partition, sql.IndexKeyValueIter, error) { + p, iter, err := i.iter.Next() + if err != nil { + return nil, nil, err + } + + return p, newLoggingKeyValueIter(i.ctx, i.log, iter, &i.rows), nil +} + +func (i *loggingPartitionKeyValueIter) Close() error { + return i.iter.Close() +} + +type loggingKeyValueIter struct { + ctx *sql.Context + span opentracing.Span + log *logrus.Entry + iter sql.IndexKeyValueIter + rows *uint64 + start time.Time +} + +func newLoggingKeyValueIter( + ctx *sql.Context, + log *logrus.Entry, + iter sql.IndexKeyValueIter, + rows *uint64, +) *loggingKeyValueIter { + return &loggingKeyValueIter{ + ctx: ctx, + log: log, + iter: iter, + start: time.Now(), + rows: rows, + } +} + +func (i *loggingKeyValueIter) Next() ([]interface{}, []byte, error) { + if i.span == nil { + i.span, _ = i.ctx.Span("plan.createIndex.iterator", + opentracing.Tags{ + "start": i.rows, + }, + ) + } + + (*i.rows)++ + if *i.rows%sql.IndexBatchSize == 0 { + duration := time.Since(i.start) + + i.log.WithFields(logrus.Fields{ + "duration": duration, + "rows": *i.rows, + }).Debugf("still creating index") + + if i.span != nil { + i.span.LogKV("duration", duration.String()) + i.span.Finish() + i.span = nil + } + + i.start = time.Now() + } + + val, loc, err := i.iter.Next() + if 
err != nil { + i.span.LogKV("error", err) + i.span.Finish() + i.span = nil + } + + return val, loc, err +} + +func (i *loggingKeyValueIter) Close() error { + return i.iter.Close() +} diff --git a/sql/plan/create_index_test.go b/sql/plan/create_index_test.go index 6094f750d..0cb509583 100644 --- a/sql/plan/create_index_test.go +++ b/sql/plan/create_index_test.go @@ -2,42 +2,49 @@ package plan import ( "context" + "io" + "math" "testing" "time" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/test" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) -func TestCreateIndex(t *testing.T) { +func TestCreateIndexAsync(t *testing.T) { require := require.New(t) - table := &indexableTable{mem.NewTable("foo", sql.Schema{ + table := memory.NewTable("foo", sql.Schema{ {Name: "a", Source: "foo"}, {Name: "b", Source: "foo"}, {Name: "c", Source: "foo"}, - })} + }) driver := new(mockDriver) catalog := sql.NewCatalog() catalog.RegisterIndexDriver(driver) - db := mem.NewDatabase("foo") + db := memory.NewDatabase("foo") db.AddTable("foo", table) - catalog.Databases = append(catalog.Databases, db) + catalog.AddDatabase(db) exprs := []sql.Expression{ expression.NewGetFieldWithTable(2, sql.Int64, "foo", "c", true), expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", true), } - ci := NewCreateIndex("idx", table, exprs, "mock", make(map[string]string)) + ci := NewCreateIndex("idx", NewResolvedTable(table), exprs, "mock", map[string]string{ + "async": "true", + }) ci.Catalog = catalog ci.CurrentDatabase = "foo" - _, err := ci.RowIter(sql.NewEmptyContext()) + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.Background(), sql.WithTracer(tracer)) + _, err := ci.RowIter(ctx) require.NoError(err) time.Sleep(50 * time.Millisecond) @@ 
-46,33 +53,317 @@ func TestCreateIndex(t *testing.T) { require.Equal([]string{"idx"}, driver.saved) idx := catalog.IndexRegistry.Index("foo", "idx") require.NotNil(idx) - require.Equal(&mockIndex{"idx", "foo", "foo", []sql.Expression{ + require.Equal(&mockIndex{"foo", "foo", "idx", []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true), expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true), }}, idx) + + found := false + for _, span := range tracer.Spans { + if span == "plan.createIndex" { + found = true + break + } + } + + require.True(found) +} + +func TestCreateIndexNotIndexableExprs(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Blob}, + {Name: "b", Source: "foo", Type: sql.JSON}, + {Name: "c", Source: "foo", Type: sql.Text}, + }) + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + ci := NewCreateIndex( + "idx", + NewResolvedTable(table), + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Blob, "foo", "a", true), + }, + "mock", + make(map[string]string), + ) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + _, err := ci.RowIter(sql.NewEmptyContext()) + require.Error(err) + require.True(ErrExprTypeNotIndexable.Is(err)) + + ci = NewCreateIndex( + "idx", + NewResolvedTable(table), + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.JSON, "foo", "a", true), + }, + "mock", + make(map[string]string), + ) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + _, err = ci.RowIter(sql.NewEmptyContext()) + require.Error(err) + require.True(ErrExprTypeNotIndexable.Is(err)) +} + +func TestCreateIndexSync(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }) + + 
driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + exprs := []sql.Expression{ + expression.NewGetFieldWithTable(2, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", true), + } + + ci := NewCreateIndex( + "idx", NewResolvedTable(table), exprs, "mock", + map[string]string{"async": "false"}, + ) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.Background(), sql.WithTracer(tracer)) + _, err := ci.RowIter(ctx) + require.NoError(err) + + require.Len(driver.deleted, 0) + require.Equal([]string{"idx"}, driver.saved) + idx := catalog.IndexRegistry.Index("foo", "idx") + require.NotNil(idx) + require.Equal(&mockIndex{"foo", "foo", "idx", []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true), + }}, idx) + + found := false + for _, span := range tracer.Spans { + if span == "plan.createIndex" { + found = true + break + } + } + + require.True(found) +} + +func TestCreateIndexChecksum(t *testing.T) { + require := require.New(t) + + table := &checksumTable{ + memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }), + "1", + } + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + exprs := []sql.Expression{ + expression.NewGetFieldWithTable(2, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", true), + } + + ci := NewCreateIndex( + "idx", NewResolvedTable(table), exprs, "mock", + map[string]string{"async": "false"}, + ) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + _, err := 
ci.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + require.Equal([]string{"idx"}, driver.saved) + require.Equal("1", driver.config["idx"][sql.ChecksumKey]) +} + +func TestCreateIndexChecksumWithUnderlying(t *testing.T) { + require := require.New(t) + + table := + &underlyingTable{ + &underlyingTable{ + &underlyingTable{ + &checksumTable{ + memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }), + "1", + }, + }, + }, + } + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + + exprs := []sql.Expression{ + expression.NewGetFieldWithTable(2, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", true), + } + + ci := NewCreateIndex( + "idx", NewResolvedTable(table), exprs, "mock", + map[string]string{"async": "false"}, + ) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + _, err := ci.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + require.Equal([]string{"idx"}, driver.saved) + require.Equal("1", driver.config["idx"][sql.ChecksumKey]) +} + +func TestCreateIndexWithIter(t *testing.T) { + require := require.New(t) + foo := memory.NewPartitionedTable("foo", sql.Schema{ + {Name: "one", Source: "foo", Type: sql.Int64}, + {Name: "two", Source: "foo", Type: sql.Int64}, + }, 2) + + rows := [][2]int64{ + {1, 2}, + {-1, -2}, + {0, 0}, + {math.MaxInt64, math.MinInt64}, + } + for _, r := range rows { + err := foo.Insert(sql.NewEmptyContext(), sql.NewRow(r[0], r[1])) + require.NoError(err) + } + + exprs := []sql.Expression{expression.NewPlus( + expression.NewGetField(0, sql.Int64, "one", false), + expression.NewGetField(0, sql.Int64, "two", false)), + } + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", foo) + catalog.AddDatabase(db) + + ci := NewCreateIndex("idx", NewResolvedTable(foo), exprs, "mock", 
make(map[string]string)) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + columns, exprs, err := getColumnsAndPrepareExpressions(ci.Exprs) + require.NoError(err) + + iter, err := foo.IndexKeyValues(sql.NewEmptyContext(), columns) + require.NoError(err) + + iter = &evalPartitionKeyValueIter{ + ctx: sql.NewEmptyContext(), + columns: columns, + exprs: exprs, + iter: iter, + } + + var ( + vals [][]interface{} + i int + ) + + for { + _, kviter, err := iter.Next() + if err == io.EOF { + break + } + require.NoError(err) + + vals = append(vals, nil) + + for { + values, _, err := kviter.Next() + if err == io.EOF { + break + } + require.NoError(err) + + vals[i] = append(vals[i], values...) + } + + require.NoError(kviter.Close()) + + i++ + } + require.NoError(iter.Close()) + + require.Equal([][]interface{}{ + {int64(3), int64(0)}, + {int64(-3), int64(-1)}, + }, vals) } type mockIndex struct { - id string - table string db string + table string + id string exprs []sql.Expression } var _ sql.Index = (*mockIndex)(nil) -func (i *mockIndex) ID() string { return i.id } -func (i *mockIndex) Table() string { return i.table } -func (i *mockIndex) Database() string { return i.db } -func (i *mockIndex) Expressions() []sql.Expression { return i.exprs } -func (i *mockIndex) Get(key interface{}) (sql.IndexLookup, error) { +func (i *mockIndex) ID() string { return i.id } +func (i *mockIndex) Table() string { return i.table } +func (i *mockIndex) Database() string { return i.db } +func (i *mockIndex) Expressions() []string { + exprs := make([]string, len(i.exprs)) + for i, e := range i.exprs { + exprs[i] = e.String() + } + + return exprs +} +func (i *mockIndex) Get(key ...interface{}) (sql.IndexLookup, error) { panic("unimplemented") } -func (i *mockIndex) Has(key interface{}) (bool, error) { +func (i *mockIndex) Has(sql.Partition, ...interface{}) (bool, error) { panic("unimplemented") } +func (*mockIndex) Driver() string { return "mock" } type mockDriver struct { + config 
map[string]map[string]string deleted []string saved []string } @@ -80,32 +371,40 @@ type mockDriver struct { var _ sql.IndexDriver = (*mockDriver)(nil) func (*mockDriver) ID() string { return "mock" } -func (*mockDriver) Create(path, db, table, id string, exprs []sql.Expression, config map[string]string) (sql.Index, error) { - return &mockIndex{id, table, db, exprs}, nil +func (d *mockDriver) Create(db, table, id string, exprs []sql.Expression, config map[string]string) (sql.Index, error) { + if d.config == nil { + d.config = make(map[string]map[string]string) + } + d.config[id] = config + + return &mockIndex{db, table, id, exprs}, nil } -func (*mockDriver) Load(path string) (sql.Index, error) { +func (*mockDriver) LoadAll(db, table string) ([]sql.Index, error) { panic("not implemented") } -func (d *mockDriver) Save(ctx context.Context, path string, index sql.Index, iter sql.IndexKeyValueIter) error { + +func (d *mockDriver) Save(ctx *sql.Context, index sql.Index, iter sql.PartitionIndexKeyValueIter) error { d.saved = append(d.saved, index.ID()) return nil } -func (d *mockDriver) Delete(path string, index sql.Index) error { +func (d *mockDriver) Delete(index sql.Index, _ sql.PartitionIter) error { d.deleted = append(d.deleted, index.ID()) return nil } -type indexableTable struct { +type checksumTable struct { sql.Table + checksum string } -func (indexableTable) IndexKeyValueIter(colNames []string) (sql.IndexKeyValueIter, error) { - return nil, nil +func (t *checksumTable) Checksum() (string, error) { + return t.checksum, nil } -func (indexableTable) WithProjectFiltersAndIndex( - columns, filters []sql.Expression, - index sql.IndexValueIter, -) (sql.RowIter, error) { - return nil, nil +func (t *checksumTable) Underlying() sql.Table { return t.Table } + +type underlyingTable struct { + sql.Table } + +func (t *underlyingTable) Underlying() sql.Table { return t.Table } diff --git a/sql/plan/cross_join.go b/sql/plan/cross_join.go index 69af567e3..f7dc1d80d 100644 --- 
a/sql/plan/cross_join.go +++ b/sql/plan/cross_join.go @@ -5,7 +5,7 @@ import ( "reflect" opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // CrossJoin is a cross join between two tables. @@ -66,34 +66,13 @@ func (p *CrossJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { }), nil } -// TransformUp implements the Transformable interface. -func (p *CrossJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *CrossJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return f(NewCrossJoin(left, right)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *CrossJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewCrossJoin(left, right), nil + return NewCrossJoin(children[0], children[1]), nil } func (p *CrossJoin) String() string { @@ -147,21 +126,26 @@ func (i *crossJoinIterator) Next() (sql.Row, error) { return nil, err } - return append(i.leftRow, rightRow...), nil + var row sql.Row + row = append(row, i.leftRow...) + row = append(row, rightRow...) 
+ + return row, nil } } -func (i *crossJoinIterator) Close() error { - if err := i.l.Close(); err != nil { - if i.r != nil { - _ = i.r.Close() - } - return err +func (i *crossJoinIterator) Close() (err error) { + if i.l != nil { + err = i.l.Close() } if i.r != nil { - return i.r.Close() + if err == nil { + err = i.r.Close() + } else { + i.r.Close() + } } - return nil + return err } diff --git a/sql/plan/cross_join_test.go b/sql/plan/cross_join_test.go index 4783f11b2..f45548e99 100644 --- a/sql/plan/cross_join_test.go +++ b/sql/plan/cross_join_test.go @@ -4,9 +4,9 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) var lSchema = sql.Schema{ @@ -38,50 +38,53 @@ func TestCrossJoin(t *testing.T) { {Name: "rcol4", Type: sql.Int64}, } - ltable := mem.NewTable("left", lSchema) - rtable := mem.NewTable("right", rSchema) + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) insertData(t, ltable) insertData(t, rtable) - j := NewCrossJoin(ltable, rtable) + j := NewCrossJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + ) require.Equal(resultSchema, j.Schema()) iter, err := j.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err := iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal(8, len(row)) require.Equal("col1_1", row[0]) require.Equal("col2_1", row[1]) - require.Equal(int32(1111), row[2]) - require.Equal(int64(2222), row[3]) + require.Equal(int32(1), row[2]) + require.Equal(int64(2), row[3]) require.Equal("col1_1", row[4]) require.Equal("col2_1", row[5]) - require.Equal(int32(1111), row[6]) - require.Equal(int64(2222), row[7]) + require.Equal(int32(1), row[6]) + require.Equal(int64(2), row[7]) row, err = iter.Next() - require.Nil(err) + require.NoError(err) 
require.NotNil(row) require.Equal("col1_1", row[0]) require.Equal("col2_1", row[1]) - require.Equal(int32(1111), row[2]) - require.Equal(int64(2222), row[3]) + require.Equal(int32(1), row[2]) + require.Equal(int64(2), row[3]) require.Equal("col1_2", row[4]) require.Equal("col2_2", row[5]) - require.Equal(int32(3333), row[6]) - require.Equal(int64(4444), row[7]) + require.Equal(int32(3), row[6]) + require.Equal(int64(4), row[7]) for i := 0; i < 2; i++ { row, err = iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) } @@ -96,28 +99,34 @@ func TestCrossJoin_Empty(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - ltable := mem.NewTable("left", lSchema) - rtable := mem.NewTable("right", rSchema) + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) insertData(t, ltable) - j := NewCrossJoin(ltable, rtable) + j := NewCrossJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + ) iter, err := j.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err := iter.Next() require.Equal(io.EOF, err) require.Nil(row) - ltable = mem.NewTable("left", lSchema) - rtable = mem.NewTable("right", rSchema) + ltable = memory.NewTable("left", lSchema) + rtable = memory.NewTable("right", rSchema) insertData(t, rtable) - j = NewCrossJoin(ltable, rtable) + j = NewCrossJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + ) iter, err = j.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err = iter.Next() @@ -125,11 +134,16 @@ func TestCrossJoin_Empty(t *testing.T) { require.Nil(row) } -func insertData(t *testing.T, table *mem.Table) { +func insertData(t *testing.T, table *memory.Table) { t.Helper() require := require.New(t) - err := table.Insert(sql.NewRow("col1_1", "col2_1", int32(1111), int64(2222))) - require.Nil(err) - err = table.Insert(sql.NewRow("col1_2", "col2_2", int32(3333), int64(4444))) - require.Nil(err) + + rows := []sql.Row{ 
+ sql.NewRow("col1_1", "col2_1", int32(1), int64(2)), + sql.NewRow("col1_2", "col2_2", int32(3), int64(4)), + } + + for _, r := range rows { + require.NoError(table.Insert(sql.NewEmptyContext(), r)) + } } diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index b60f19353..d9acf8a09 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -1,18 +1,20 @@ package plan import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // ErrCreateTable is thrown when the database doesn't support table creation -var ErrCreateTable = errors.NewKind("tables cannot be created on database %s") +var ErrCreateTableNotSupported = errors.NewKind("tables cannot be created on database %s") +var ErrDropTableNotSupported = errors.NewKind("tables cannot be dropped on database %s") // CreateTable is a node describing the creation of some table. type CreateTable struct { - Database sql.Database - name string - schema sql.Schema + db sql.Database + name string + schema sql.Schema } // NewCreateTable creates a new CreateTable node @@ -22,48 +24,139 @@ func NewCreateTable(db sql.Database, name string, schema sql.Schema) *CreateTabl } return &CreateTable{ - Database: db, - name: name, - schema: schema, + db: db, + name: name, + schema: schema, } } +var _ sql.Databaser = (*CreateTable)(nil) + +// Database implements the sql.Databaser interface. +func (c *CreateTable) Database() sql.Database { + return c.db +} + +// WithDatabase implements the sql.Databaser interface. +func (c *CreateTable) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *c + nc.db = db + return &nc, nil +} + // Resolved implements the Resolvable interface. func (c *CreateTable) Resolved() bool { - _, ok := c.Database.(*sql.UnresolvedDatabase) + _, ok := c.db.(sql.UnresolvedDatabase) return !ok } // RowIter implements the Node interface. 
func (c *CreateTable) RowIter(s *sql.Context) (sql.RowIter, error) { - d, ok := c.Database.(sql.Alterable) - if !ok { - return nil, ErrCreateTable.New(c.Database.Name()) + creatable, ok := c.db.(sql.TableCreator) + if ok { + return sql.RowsToRowIter(), creatable.CreateTable(s, c.name, c.schema) } - return sql.RowsToRowIter(), d.Create(c.name, c.schema) + return nil, ErrCreateTableNotSupported.New(c.db.Name()) } // Schema implements the Node interface. -func (c *CreateTable) Schema() sql.Schema { - return sql.Schema{} -} +func (c *CreateTable) Schema() sql.Schema { return nil } // Children implements the Node interface. -func (c *CreateTable) Children() []sql.Node { - return nil -} +func (c *CreateTable) Children() []sql.Node { return nil } -// TransformUp implements the Transformable interface. -func (c *CreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewCreateTable(c.Database, c.name, c.schema)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (c *CreateTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { +// WithChildren implements the Node interface. +func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) + } return c, nil } func (c *CreateTable) String() string { return "CreateTable" } + +// DropTable is a node describing dropping one or more tables +type DropTable struct { + db sql.Database + names []string + ifExists bool +} + +// NewDropTable creates a new DropTable node +func NewDropTable(db sql.Database, ifExists bool, tableNames ...string) *DropTable { + return &DropTable{ + db: db, + names: tableNames, + ifExists: ifExists, + } +} + +var _ sql.Databaser = (*DropTable)(nil) + +// Database implements the sql.Databaser interface. +func (d *DropTable) Database() sql.Database { + return d.db +} + +// WithDatabase implements the sql.Databaser interface. 
+func (d *DropTable) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *d + nc.db = db + return &nc, nil +} + +// Resolved implements the Resolvable interface. +func (d *DropTable) Resolved() bool { + _, ok := d.db.(sql.UnresolvedDatabase) + return !ok +} + +// RowIter implements the Node interface. +func (d *DropTable) RowIter(s *sql.Context) (sql.RowIter, error) { + droppable, ok := d.db.(sql.TableDropper) + if !ok { + return nil, ErrDropTableNotSupported.New(d.db.Name()) + } + + var err error + for _, tableName := range d.names { + _, ok := d.db.Tables()[tableName] + if !ok { + if d.ifExists { + continue + } + return nil, sql.ErrTableNotFound.New(tableName) + } + err = droppable.DropTable(s, tableName) + if err != nil { + break + } + } + + return sql.RowsToRowIter(), err +} + +// Schema implements the Node interface. +func (d *DropTable) Schema() sql.Schema { return nil } + +// Children implements the Node interface. +func (d *DropTable) Children() []sql.Node { return nil } + +// WithChildren implements the Node interface. 
+func (d *DropTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 0) + } + return d, nil +} + +func (d *DropTable) String() string { + ifExists := "" + if d.ifExists { + ifExists = "if exists " + } + return fmt.Sprintf("Drop table %s%s", ifExists, d.names) +} diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index 1ede00f3b..a226a631c 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -4,15 +4,15 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestCreateTable(t *testing.T) { require := require.New(t) - db := mem.NewDatabase("test") + db := memory.NewDatabase("test") tables := db.Tables() _, ok := tables["testTable"] require.False(ok) @@ -22,24 +22,72 @@ func TestCreateTable(t *testing.T) { {Name: "c2", Type: sql.Int32}, } - c := NewCreateTable(db, "testTable", s) + createTable(t, db, "testTable", s) - rows, err := c.RowIter(sql.NewEmptyContext()) + tables = db.Tables() + + newTable, ok := tables["testTable"] + require.True(ok) + + require.Equal(newTable.Schema(), s) + + for _, s := range newTable.Schema() { + require.Equal("testTable", s.Source) + } +} + +func TestDropTable(t *testing.T) { + require := require.New(t) + + db := memory.NewDatabase("test") + + s := sql.Schema{ + {Name: "c1", Type: sql.Text}, + {Name: "c2", Type: sql.Int32}, + } + createTable(t, db, "testTable1", s) + createTable(t, db, "testTable2", s) + createTable(t, db, "testTable3", s) + + d := NewDropTable(db, false, "testTable1", "testTable2") + rows, err := d.RowIter(sql.NewEmptyContext()) require.NoError(err) r, err := rows.Next() require.Equal(err, io.EOF) require.Nil(r) - tables = db.Tables() - - newTable, ok := tables["testTable"] + _, ok := db.Tables()["testTable1"] + 
require.False(ok) + _, ok = db.Tables()["testTable2"] + require.False(ok) + _, ok = db.Tables()["testTable3"] require.True(ok) - require.Equal(newTable.Schema(), s) + d = NewDropTable(db, false, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.Error(err) - for _, s := range newTable.Schema() { - require.Equal(newTable.Name(), s.Source) - } + d = NewDropTable(db, true, "testTable1") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + d = NewDropTable(db, true, "testTable1", "testTable2", "testTable3") + _, err = d.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + _, ok = db.Tables()["testTable3"] + require.False(ok) } + +func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema) { + c := NewCreateTable(db, name, schema) + + rows, err := c.RowIter(sql.NewEmptyContext()) + require.NoError(t, err) + + r, err := rows.Next() + require.Equal(t, err, io.EOF) + require.Nil(t, r) +} \ No newline at end of file diff --git a/sql/plan/delete.go b/sql/plan/delete.go new file mode 100644 index 000000000..95498b397 --- /dev/null +++ b/sql/plan/delete.go @@ -0,0 +1,125 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" + "io" +) + +var ErrDeleteFromNotSupported = errors.NewKind("table doesn't support DELETE FROM") + +// DeleteFrom is a node describing a deletion from some table. +type DeleteFrom struct { + sql.Node +} + +// NewDeleteFrom creates a DeleteFrom node. +func NewDeleteFrom(n sql.Node) *DeleteFrom { + return &DeleteFrom{n} +} + +// Schema implements the Node interface. +func (p *DeleteFrom) Schema() sql.Schema { + return sql.Schema{{ + Name: "updated", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }} +} + +// Resolved implements the Resolvable interface. 
+func (p *DeleteFrom) Resolved() bool { + return p.Node.Resolved() +} + +func (p *DeleteFrom) Children() []sql.Node { + return []sql.Node{p.Node} +} + +func getDeletable(node sql.Node) (sql.Deleter, error) { + switch node := node.(type) { + case sql.Deleter: + return node, nil + case *ResolvedTable: + return getDeletableTable(node.Table) + } + for _, child := range node.Children() { + deleter, _ := getDeletable(child) + if deleter != nil { + return deleter, nil + } + } + return nil, ErrDeleteFromNotSupported.New() +} + +func getDeletableTable(t sql.Table) (sql.Deleter, error) { + switch t := t.(type) { + case sql.Deleter: + return t, nil + case sql.TableWrapper: + return getDeletableTable(t.Underlying()) + default: + return nil, ErrDeleteFromNotSupported.New() + } +} + +// Execute deletes the rows in the database. +func (p *DeleteFrom) Execute(ctx *sql.Context) (int, error) { + deletable, err := getDeletable(p.Node) + if err != nil { + return 0, err + } + + iter, err := p.Node.RowIter(ctx) + if err != nil { + return 0, err + } + + i := 0 + for { + row, err := iter.Next() + if err == io.EOF { + break + } + + if err != nil { + _ = iter.Close() + return i, err + } + + if err := deletable.Delete(ctx, row); err != nil { + _ = iter.Close() + return i, err + } + + i++ + } + + return i, nil +} + +// RowIter implements the Node interface. +func (p *DeleteFrom) RowIter(ctx *sql.Context) (sql.RowIter, error) { + n, err := p.Execute(ctx) + if err != nil { + return nil, err + } + + return sql.RowsToRowIter(sql.NewRow(int64(n))), nil +} + +// WithChildren implements the Node interface. 
+func (p *DeleteFrom) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewDeleteFrom(children[0]), nil +} + +func (p DeleteFrom) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("Delete") + _ = pr.WriteChildren(p.Node.String()) + return pr.String() +} diff --git a/sql/plan/describe.go b/sql/plan/describe.go index d21975389..e84cdc8a0 100644 --- a/sql/plan/describe.go +++ b/sql/plan/describe.go @@ -2,8 +2,9 @@ package plan import ( "io" + "strings" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Describe is a node that describes its children. @@ -32,26 +33,20 @@ func (d *Describe) RowIter(ctx *sql.Context) (sql.RowIter, error) { return &describeIter{schema: d.Child.Schema()}, nil } -// TransformUp implements the Transformable interface. -func (d *Describe) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Describe) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDescribe(child)) -} -// TransformExpressionsUp implements the Transformable interface. 
-func (d *Describe) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDescribe(child), nil + return NewDescribe(children[0]), nil } func (d Describe) String() string { - return "Describe" + p := sql.NewTreePrinter() + _ = p.WriteNode("Describe") + _ = p.WriteChildren(d.Child.String()) + return p.String() } type describeIter struct { @@ -66,9 +61,57 @@ func (i *describeIter) Next() (sql.Row, error) { f := i.schema[i.i] i.i++ - return sql.NewRow(f.Name, f.Type.Type().String()), nil + return sql.NewRow(f.Name, sql.MySQLTypeName(f.Type)), nil } func (i *describeIter) Close() error { return nil } + +// DescribeQuery returns the description of the query plan. +type DescribeQuery struct { + UnaryNode + Format string +} + +// DescribeSchema is the schema returned by a DescribeQuery node. +var DescribeSchema = sql.Schema{ + {Name: "plan", Type: sql.Text}, +} + +// NewDescribeQuery creates a new DescribeQuery node. +func NewDescribeQuery(format string, child sql.Node) *DescribeQuery { + return &DescribeQuery{UnaryNode{Child: child}, format} +} + +// Schema implements the Node interface. +func (d *DescribeQuery) Schema() sql.Schema { + return DescribeSchema +} + +// RowIter implements the Node interface. +func (d *DescribeQuery) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var rows []sql.Row + for _, l := range strings.Split(d.Child.String(), "\n") { + if strings.TrimSpace(l) != "" { + rows = append(rows, sql.NewRow(l)) + } + } + return sql.RowsToRowIter(rows...), nil +} + +func (d *DescribeQuery) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DescribeQuery(format=%s)", d.Format) + _ = pr.WriteChildren(d.Child.String()) + return pr.String() +} + +// WithChildren implements the Node interface. 
+func (d *DescribeQuery) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + + return NewDescribeQuery(d.Format, children[0]), nil +} diff --git a/sql/plan/describe_test.go b/sql/plan/describe_test.go index 660e84c9d..18c0b9f84 100644 --- a/sql/plan/describe_test.go +++ b/sql/plan/describe_test.go @@ -4,32 +4,33 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestDescribe(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - table := mem.NewTable("test", sql.Schema{ + table := memory.NewTable("test", sql.Schema{ {Name: "c1", Type: sql.Text}, {Name: "c2", Type: sql.Int32}, }) - d := NewDescribe(table) + d := NewDescribe(NewResolvedTable(table)) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) n, err := iter.Next() - require.Nil(err) + require.NoError(err) require.Equal(sql.NewRow("c1", "TEXT"), n) n, err = iter.Next() - require.Nil(err) - require.Equal(sql.NewRow("c2", "INT32"), n) + require.NoError(err) + require.Equal(sql.NewRow("c2", "INTEGER"), n) n, err = iter.Next() require.Equal(io.EOF, err) @@ -40,13 +41,52 @@ func TestDescribe_Empty(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - d := NewDescribe(NewUnresolvedTable("test_table")) + d := NewDescribe(NewUnresolvedTable("test_table", "")) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) n, err := iter.Next() require.Equal(io.EOF, err) require.Nil(n) } + +func TestDescribeQuery(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Source: "foo", Name: "a", Type: sql.Text}, + {Source: "foo", Name: "b", 
Type: sql.Text}, + }) + + node := NewDescribeQuery("tree", NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Text, "foo", "a", false), + expression.NewGetFieldWithTable(1, sql.Text, "foo", "b", false), + }, + NewFilter( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Text, "foo", "a", false), + expression.NewLiteral("foo", sql.Text), + ), + NewResolvedTable(table), + ), + )) + + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"Project(foo.a, foo.b)"}, + {" └─ Filter(foo.a = \"foo\")"}, + {" └─ Table(foo)"}, + {" ├─ Column(a, TEXT, nullable=false)"}, + {" └─ Column(b, TEXT, nullable=false)"}, + } + + require.Equal(expected, rows) +} diff --git a/sql/plan/distinct.go b/sql/plan/distinct.go index 345d6c9a0..8a3da753a 100644 --- a/sql/plan/distinct.go +++ b/sql/plan/distinct.go @@ -1,10 +1,9 @@ package plan import ( - "fmt" + "io" - "github.com/mitchellh/hashstructure" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Distinct is a node that ensures all rows that come from it are unique. @@ -34,25 +33,16 @@ func (d *Distinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, err } - return sql.NewSpanIter(span, newDistinctIter(it)), nil + return sql.NewSpanIter(span, newDistinctIter(ctx, it)), nil } -// TransformUp implements the Transformable interface. -func (d *Distinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *Distinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. 
-func (d *Distinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewDistinct(child), nil + return NewDistinct(children[0]), nil } func (d Distinct) String() string { @@ -64,18 +54,21 @@ func (d Distinct) String() string { // distinctIter keeps track of the hashes of all rows that have been emitted. // It does not emit any rows whose hashes have been seen already. -// TODO: come up with a way to use less memory than keeping all hashes in mem. +// TODO: come up with a way to use less memory than keeping all hashes in memory. // Even though they are just 64-bit integers, this could be a problem in large // result sets. type distinctIter struct { childIter sql.RowIter - seen map[uint64]struct{} + seen sql.KeyValueCache + dispose sql.DisposeFunc } -func newDistinctIter(child sql.RowIter) *distinctIter { +func newDistinctIter(ctx *sql.Context, child sql.RowIter) *distinctIter { + cache, dispose := ctx.Memory.NewHistoryCache() return &distinctIter{ childIter: child, - seen: make(map[uint64]struct{}), + seen: cache, + dispose: dispose, } } @@ -83,29 +76,38 @@ func (di *distinctIter) Next() (sql.Row, error) { for { row, err := di.childIter.Next() if err != nil { + if err == io.EOF { + di.Dispose() + } return nil, err } - hash, err := hashstructure.Hash(row, nil) - if err != nil { - return nil, fmt.Errorf("unable to hash row: %s", err) + hash := sql.CacheKey(row) + if _, err := di.seen.Get(hash); err == nil { + continue } - if _, ok := di.seen[hash]; ok { - continue + if err := di.seen.Put(hash, struct{}{}); err != nil { + return nil, err } - di.seen[hash] = struct{}{} return row, nil } } func (di *distinctIter) Close() error { + di.Dispose() return di.childIter.Close() } +func (di *distinctIter) Dispose() { + if di.dispose != nil { + di.dispose() + } +} + // OrderedDistinct is a Distinct node optimized for sorted row sets. 
-// It's 2 orders of magnitude faster and uses 2 orders of magnitude less mem. +// It's 2 orders of magnitude faster and uses 2 orders of magnitude less memory. type OrderedDistinct struct { UnaryNode } @@ -135,22 +137,13 @@ func (d *OrderedDistinct) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, newOrderedDistinctIter(it, d.Child.Schema())), nil } -// TransformUp implements the Transformable interface. -func (d *OrderedDistinct) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := d.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (d *OrderedDistinct) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) } - return f(NewOrderedDistinct(child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (d *OrderedDistinct) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := d.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewOrderedDistinct(child), nil + return NewOrderedDistinct(children[0]), nil } func (d OrderedDistinct) String() string { diff --git a/sql/plan/distinct_test.go b/sql/plan/distinct_test.go index 78bf3e292..098308d47 100644 --- a/sql/plan/distinct_test.go +++ b/sql/plan/distinct_test.go @@ -4,10 +4,10 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestDistinct(t *testing.T) { @@ -18,20 +18,27 @@ func TestDistinct(t *testing.T) { {Name: "name", Type: sql.Text, Nullable: true}, {Name: "email", Type: sql.Text, Nullable: true}, } - child := mem.NewTable("test", childSchema) 
- require.NoError(child.Insert(sql.NewRow("john", "john@doe.com"))) - require.NoError(child.Insert(sql.NewRow("jane", "jane@doe.com"))) - require.NoError(child.Insert(sql.NewRow("john", "johnx@doe.com"))) - require.NoError(child.Insert(sql.NewRow("martha", "marthax@doe.com"))) - require.NoError(child.Insert(sql.NewRow("martha", "martha@doe.com"))) + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("john", "john@doe.com"), + sql.NewRow("jane", "jane@doe.com"), + sql.NewRow("john", "johnx@doe.com"), + sql.NewRow("martha", "marthax@doe.com"), + sql.NewRow("martha", "martha@doe.com"), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := NewProject([]sql.Expression{ expression.NewGetField(0, sql.Text, "name", true), - }, child) + }, NewResolvedTable(child)) d := NewDistinct(p) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) var results []string @@ -58,20 +65,27 @@ func TestOrderedDistinct(t *testing.T) { {Name: "name", Type: sql.Text, Nullable: true}, {Name: "email", Type: sql.Text, Nullable: true}, } - child := mem.NewTable("test", childSchema) - require.NoError(child.Insert(sql.NewRow("jane", "jane@doe.com"))) - require.NoError(child.Insert(sql.NewRow("john", "john@doe.com"))) - require.NoError(child.Insert(sql.NewRow("john", "johnx@doe.com"))) - require.NoError(child.Insert(sql.NewRow("martha", "martha@doe.com"))) - require.NoError(child.Insert(sql.NewRow("martha", "marthax@doe.com"))) + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("jane", "jane@doe.com"), + sql.NewRow("john", "john@doe.com"), + sql.NewRow("john", "johnx@doe.com"), + sql.NewRow("martha", "martha@doe.com"), + sql.NewRow("martha", "marthax@doe.com"), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := NewProject([]sql.Expression{ expression.NewGetField(0, sql.Text, "name", true), - }, child) + }, 
NewResolvedTable(child)) d := NewOrderedDistinct(p) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) var results []string @@ -102,11 +116,11 @@ func BenchmarkDistinct(b *testing.B) { expression.NewGetField(3, sql.Int32, "intfield", false), expression.NewGetField(4, sql.Int64, "bigintfield", false), expression.NewGetField(5, sql.Blob, "blobfield", false), - }, benchtable) + }, NewResolvedTable(benchtable)) d := NewDistinct(p) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) var rows int @@ -135,11 +149,11 @@ func BenchmarkOrderedDistinct(b *testing.B) { expression.NewGetField(3, sql.Int32, "intfield", false), expression.NewGetField(4, sql.Int64, "bigintfield", false), expression.NewGetField(5, sql.Blob, "blobfield", false), - }, benchtable) + }, NewResolvedTable(benchtable)) d := NewOrderedDistinct(p) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) var rows int diff --git a/sql/plan/drop_index.go b/sql/plan/drop_index.go new file mode 100644 index 000000000..f08caf711 --- /dev/null +++ b/sql/plan/drop_index.go @@ -0,0 +1,116 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/internal/similartext" + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" +) + +var ( + // ErrIndexNotFound is returned when the index cannot be found. + ErrIndexNotFound = errors.NewKind("unable to find index %q on table %q of database %q") + // ErrTableNotValid is returned when the table is not valid + ErrTableNotValid = errors.NewKind("table is not valid") + // ErrTableNotNameable is returned when the table is not nameable. + ErrTableNotNameable = errors.NewKind("can't get name from table") + // ErrIndexNotAvailable is returned when trying to delete an index that is + // still not ready for usage. 
+ ErrIndexNotAvailable = errors.NewKind("index %q is still not ready for usage and can't be deleted") +) + +// DropIndex is a node to drop an index. +type DropIndex struct { + Name string + Table sql.Node + Catalog *sql.Catalog + CurrentDatabase string +} + +// NewDropIndex creates a new DropIndex node. +func NewDropIndex(name string, table sql.Node) *DropIndex { + return &DropIndex{name, table, nil, ""} +} + +// Resolved implements the Node interface. +func (d *DropIndex) Resolved() bool { return d.Table.Resolved() } + +// Schema implements the Node interface. +func (d *DropIndex) Schema() sql.Schema { return nil } + +// Children implements the Node interface. +func (d *DropIndex) Children() []sql.Node { return []sql.Node{d.Table} } + +// RowIter implements the Node interface. +func (d *DropIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { + db, err := d.Catalog.Database(d.CurrentDatabase) + if err != nil { + return nil, err + } + + n, ok := d.Table.(sql.Nameable) + if !ok { + return nil, ErrTableNotNameable.New() + } + + tables := db.Tables() + table, ok := tables[n.Name()] + if !ok { + if len(tables) == 0 { + return nil, sql.ErrTableNotFound.New(n.Name()) + } + + similar := similartext.FindFromMap(tables, n.Name()) + return nil, sql.ErrTableNotFound.New(n.Name() + similar) + } + + index := d.Catalog.Index(db.Name(), d.Name) + if index == nil { + return nil, ErrIndexNotFound.New(d.Name, n.Name(), db.Name()) + } + d.Catalog.ReleaseIndex(index) + + if !d.Catalog.CanRemoveIndex(index) { + return nil, ErrIndexNotAvailable.New(d.Name) + } + + done, err := d.Catalog.DeleteIndex(db.Name(), d.Name, true) + if err != nil { + return nil, err + } + + driver := d.Catalog.IndexDriver(index.Driver()) + if driver == nil { + return nil, ErrInvalidIndexDriver.New(index.Driver()) + } + + <-done + + partitions, err := table.Partitions(ctx) + if err != nil { + return nil, err + } + + if err := driver.Delete(index, partitions); err != nil { + return nil, err + } + + return 
sql.RowsToRowIter(), nil +} + +func (d *DropIndex) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DropIndex(%s)", d.Name) + _ = pr.WriteChildren(d.Table.String()) + return pr.String() +} + +// WithChildren implements the Node interface. +func (d *DropIndex) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1) + } + + nd := *d + nd.Table = children[0] + return &nd, nil +} diff --git a/sql/plan/drop_index_test.go b/sql/plan/drop_index_test.go new file mode 100644 index 000000000..cd2e214fe --- /dev/null +++ b/sql/plan/drop_index_test.go @@ -0,0 +1,143 @@ +package plan + +import ( + "testing" + "time" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestDeleteIndex(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }) + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + var expressions = []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true), + } + + done, ready, err := catalog.AddIndex(&mockIndex{id: "idx", db: "foo", table: "foo", exprs: expressions}) + require.NoError(err) + close(done) + <-ready + + idx := catalog.Index("foo", "idx") + require.NotNil(idx) + catalog.ReleaseIndex(idx) + + di := NewDropIndex("idx", NewResolvedTable(table)) + di.Catalog = catalog + di.CurrentDatabase = "foo" + + _, err = di.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + time.Sleep(50 * time.Millisecond) + + require.Equal([]string{"idx"}, driver.deleted) + 
require.Nil(catalog.Index("foo", "idx")) +} + +func TestDeleteIndexNotReady(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }) + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + var expressions = []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true), + } + + done, ready, err := catalog.AddIndex(&mockIndex{id: "idx", db: "foo", table: "foo", exprs: expressions}) + require.NoError(err) + + idx := catalog.Index("foo", "idx") + require.NotNil(idx) + catalog.ReleaseIndex(idx) + + di := NewDropIndex("idx", NewResolvedTable(table)) + di.Catalog = catalog + di.CurrentDatabase = "foo" + + _, err = di.RowIter(sql.NewEmptyContext()) + require.Error(err) + require.True(ErrIndexNotAvailable.Is(err)) + + time.Sleep(50 * time.Millisecond) + + require.Equal(([]string)(nil), driver.deleted) + require.NotNil(catalog.Index("foo", "idx")) + + close(done) + <-ready +} + +func TestDeleteIndexOutdated(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo"}, + {Name: "b", Source: "foo"}, + {Name: "c", Source: "foo"}, + }) + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := memory.NewDatabase("foo") + db.AddTable("foo", table) + catalog.AddDatabase(db) + + var expressions = []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true), + expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true), + } + + done, ready, err := catalog.AddIndex(&mockIndex{id: "idx", db: "foo", table: "foo", exprs: expressions}) + require.NoError(err) + close(done) + <-ready + + idx := 
catalog.Index("foo", "idx") + require.NotNil(idx) + catalog.ReleaseIndex(idx) + catalog.MarkOutdated(idx) + + di := NewDropIndex("idx", NewResolvedTable(table)) + di.Catalog = catalog + di.CurrentDatabase = "foo" + + _, err = di.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + time.Sleep(50 * time.Millisecond) + + require.Equal([]string{"idx"}, driver.deleted) + require.Nil(catalog.Index("foo", "idx")) +} diff --git a/sql/plan/empty_table.go b/sql/plan/empty_table.go new file mode 100644 index 000000000..198cef41d --- /dev/null +++ b/sql/plan/empty_table.go @@ -0,0 +1,26 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// EmptyTable is a node representing an empty table. +var EmptyTable = new(emptyTable) + +type emptyTable struct{} + +func (emptyTable) Schema() sql.Schema { return nil } +func (emptyTable) Children() []sql.Node { return nil } +func (emptyTable) Resolved() bool { return true } +func (e *emptyTable) String() string { return "EmptyTable" } + +func (emptyTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +// WithChildren implements the Node interface. +func (e *emptyTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 0) + } + + return e, nil +} diff --git a/sql/plan/exchange.go b/sql/plan/exchange.go new file mode 100644 index 000000000..3eb59808d --- /dev/null +++ b/sql/plan/exchange.go @@ -0,0 +1,325 @@ +package plan + +import ( + "context" + "fmt" + "io" + "sync" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +// ErrNoPartitionable is returned when no Partitionable node is found +// in the Exchange tree. +var ErrNoPartitionable = errors.NewKind("no partitionable node found in exchange tree") + +// Exchange is a node that can parallelize the underlying tree iterating +// partitions concurrently. 
+type Exchange struct { + UnaryNode + Parallelism int +} + +// NewExchange creates a new Exchange node. +func NewExchange( + parallelism int, + child sql.Node, +) *Exchange { + return &Exchange{ + UnaryNode: UnaryNode{Child: child}, + Parallelism: parallelism, + } +} + +// RowIter implements the sql.Node interface. +func (e *Exchange) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var t sql.Table + Inspect(e.Child, func(n sql.Node) bool { + if table, ok := n.(sql.Table); ok { + t = table + return false + } + return true + }) + if t == nil { + return nil, ErrNoPartitionable.New() + } + + partitions, err := t.Partitions(ctx) + if err != nil { + return nil, err + } + + return newExchangeRowIter(ctx, e.Parallelism, partitions, e.Child), nil +} + +func (e *Exchange) String() string { + p := sql.NewTreePrinter() + _ = p.WriteNode("Exchange(parallelism=%d)", e.Parallelism) + _ = p.WriteChildren(e.Child.String()) + return p.String() +} + +// WithChildren implements the Node interface. +func (e *Exchange) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + + return NewExchange(e.Parallelism, children[0]), nil +} + +type exchangeRowIter struct { + ctx *sql.Context + parallelism int + partitions sql.PartitionIter + tree sql.Node + mut sync.RWMutex + tokensChan chan struct{} + started bool + rows chan sql.Row + err chan error + + quitMut sync.RWMutex + quitChan chan struct{} +} + +func newExchangeRowIter( + ctx *sql.Context, + parallelism int, + iter sql.PartitionIter, + tree sql.Node, +) *exchangeRowIter { + return &exchangeRowIter{ + ctx: ctx, + parallelism: parallelism, + rows: make(chan sql.Row, parallelism), + err: make(chan error, 1), + started: false, + tree: tree, + partitions: iter, + quitChan: make(chan struct{}), + } +} + +func (it *exchangeRowIter) releaseToken() { + it.mut.Lock() + defer it.mut.Unlock() + + if it.tokensChan != nil { + it.tokensChan <- 
struct{}{} + } +} + +func (it *exchangeRowIter) closeTokens() { + it.mut.Lock() + defer it.mut.Unlock() + + close(it.tokensChan) + it.tokensChan = nil +} + +func (it *exchangeRowIter) tokens() chan struct{} { + it.mut.RLock() + defer it.mut.RUnlock() + return it.tokensChan +} + +func (it *exchangeRowIter) fillTokens() { + it.mut.Lock() + defer it.mut.Unlock() + + it.tokensChan = make(chan struct{}, it.parallelism) + for i := 0; i < it.parallelism; i++ { + it.tokensChan <- struct{}{} + } +} + +func (it *exchangeRowIter) start() { + it.fillTokens() + + var partitions = make(chan sql.Partition) + go it.iterPartitions(partitions) + + var wg sync.WaitGroup + + for { + select { + case <-it.ctx.Done(): + it.err <- context.Canceled + it.closeTokens() + return + case <-it.quit(): + it.closeTokens() + return + case p, ok := <-partitions: + if !ok { + it.closeTokens() + + wg.Wait() + close(it.rows) + return + } + + wg.Add(1) + go func(p sql.Partition) { + it.iterPartition(p) + wg.Done() + + it.releaseToken() + }(p) + } + } +} + +func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) { + defer func() { + if x := recover(); x != nil { + it.err <- fmt.Errorf("mysql_server caught panic:\n%v", x) + } + + close(ch) + }() + + for { + select { + case <-it.ctx.Done(): + it.err <- context.Canceled + return + case <-it.quit(): + return + case <-it.tokens(): + } + + p, err := it.partitions.Next() + if err != nil { + if err != io.EOF { + it.err <- err + } + return + } + + ch <- p + } +} + +func (it *exchangeRowIter) iterPartition(p sql.Partition) { + node, err := TransformUp(it.tree, func(n sql.Node) (sql.Node, error) { + if t, ok := n.(sql.Table); ok { + return &exchangePartition{p, t}, nil + } + + return n, nil + }) + if err != nil { + it.err <- err + return + } + + rows, err := node.RowIter(it.ctx) + if err != nil { + it.err <- err + return + } + + defer func() { + if err := rows.Close(); err != nil { + it.err <- err + } + }() + + for { + select { + case <-it.ctx.Done(): + 
it.err <- context.Canceled + return + case <-it.quit(): + return + default: + } + + row, err := rows.Next() + if err != nil { + if err == io.EOF { + break + } + + it.err <- err + return + } + + it.rows <- row + } +} + +func (it *exchangeRowIter) Next() (sql.Row, error) { + if !it.started { + it.started = true + go it.start() + } + + select { + case err := <-it.err: + _ = it.Close() + return nil, err + case row, ok := <-it.rows: + if !ok { + return nil, io.EOF + } + return row, nil + } +} + +func (it *exchangeRowIter) quit() chan struct{} { + it.quitMut.RLock() + defer it.quitMut.RUnlock() + return it.quitChan +} + +func (it *exchangeRowIter) Close() error { + it.quitMut.Lock() + if it.quitChan != nil { + close(it.quitChan) + it.quitChan = nil + } + it.quitMut.Unlock() + + if it.partitions != nil { + return it.partitions.Close() + } + + return nil +} + +type exchangePartition struct { + sql.Partition + table sql.Table +} + +var _ sql.Node = (*exchangePartition)(nil) + +func (p *exchangePartition) String() string { + return fmt.Sprintf("Partition(%s)", string(p.Key())) +} + +func (exchangePartition) Children() []sql.Node { return nil } + +func (exchangePartition) Resolved() bool { return true } + +func (p *exchangePartition) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return p.table.PartitionRows(ctx, p.Partition) +} + +func (p *exchangePartition) Schema() sql.Schema { + return p.table.Schema() +} + +// WithChildren implements the Node interface. 
+func (p *exchangePartition) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} diff --git a/sql/plan/exchange_test.go b/sql/plan/exchange_test.go new file mode 100644 index 000000000..5a8e8e317 --- /dev/null +++ b/sql/plan/exchange_test.go @@ -0,0 +1,192 @@ +package plan + +import ( + "context" + "fmt" + "io" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestExchange(t *testing.T) { + children := NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Text, "partition", false), + expression.NewArithmetic( + expression.NewGetField(1, sql.Int64, "val", false), + expression.NewLiteral(int64(1), sql.Int64), + "+", + ), + }, + NewFilter( + expression.NewLessThan( + expression.NewGetField(1, sql.Int64, "val", false), + expression.NewLiteral(int64(4), sql.Int64), + ), + &partitionable{nil, 3, 6}, + ), + ) + + expected := []sql.Row{ + {"1", int64(2)}, + {"1", int64(3)}, + {"1", int64(4)}, + {"2", int64(2)}, + {"2", int64(3)}, + {"2", int64(4)}, + {"3", int64(2)}, + {"3", int64(3)}, + {"3", int64(4)}, + } + + for i := 1; i <= 4; i++ { + t.Run(fmt.Sprint(i), func(t *testing.T) { + require := require.New(t) + + exchange := NewExchange(i, children) + iter, err := exchange.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.ElementsMatch(expected, rows) + }) + } +} + +func TestExchangeCancelled(t *testing.T) { + children := NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Text, "partition", false), + expression.NewArithmetic( + expression.NewGetField(1, sql.Int64, "val", false), + expression.NewLiteral(int64(1), sql.Int64), + "+", + ), + }, + NewFilter( + expression.NewLessThan( + expression.NewGetField(1, sql.Int64, "val", false), + 
expression.NewLiteral(int64(4), sql.Int64), + ), + &partitionable{nil, 3, 6}, + ), + ) + + exchange := NewExchange(3, children) + require := require.New(t) + + c, cancel := context.WithCancel(context.Background()) + ctx := sql.NewContext(c) + cancel() + + iter, err := exchange.RowIter(ctx) + require.NoError(err) + + _, err = iter.Next() + require.Equal(context.Canceled, err) +} + +func TestExchangePanicRecover(t *testing.T) { + ctx := sql.NewContext(context.Background()) + it := &partitionPanic{} + ex := newExchangeRowIter(ctx, 1, it, nil) + ex.start() + it.Close() + + require.True(t, it.closed) +} + +type partitionable struct { + sql.Node + partitions int + rowsPerPartition int +} + +// WithChildren implements the Node interface. +func (p *partitionable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} + +func (partitionable) Children() []sql.Node { return nil } + +func (p partitionable) Partitions(*sql.Context) (sql.PartitionIter, error) { + return &exchangePartitionIter{p.partitions}, nil +} + +func (p partitionable) PartitionRows(_ *sql.Context, part sql.Partition) (sql.RowIter, error) { + return &partitionRows{part, p.rowsPerPartition}, nil +} + +func (partitionable) Schema() sql.Schema { + return sql.Schema{ + {Name: "partition", Type: sql.Text, Source: "foo"}, + {Name: "val", Type: sql.Int64, Source: "foo"}, + } +} + +func (partitionable) Name() string { return "partitionable" } + +type Partition string + +func (p Partition) Key() []byte { + return []byte(p) +} + +type exchangePartitionIter struct { + num int +} + +func (i *exchangePartitionIter) Next() (sql.Partition, error) { + if i.num <= 0 { + return nil, io.EOF + } + + i.num-- + return Partition(fmt.Sprint(i.num + 1)), nil +} + +func (i *exchangePartitionIter) Close() error { + i.num = -1 + return nil +} + +type partitionRows struct { + sql.Partition + num int +} + +func (r 
*partitionRows) Next() (sql.Row, error) { + if r.num <= 0 { + return nil, io.EOF + } + + r.num-- + return sql.NewRow(string(r.Key()), int64(r.num+1)), nil +} + +func (r *partitionRows) Close() error { + r.num = -1 + return nil +} + +type partitionPanic struct { + sql.Partition + closed bool +} + +func (*partitionPanic) Next() (sql.Partition, error) { + panic("partitionPanic.Next") +} + +func (p *partitionPanic) Close() error { + p.closed = true + return nil +} diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 267653785..f160b3f66 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -1,6 +1,8 @@ package plan -import "gopkg.in/src-d/go-mysql-server.v0/sql" +import ( + "github.com/src-d/go-mysql-server/sql" +) // Filter skips rows that don't match a certain expression. type Filter struct { @@ -34,31 +36,25 @@ func (p *Filter) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, NewFilterIter(ctx, p.Expression, i)), nil } -// TransformUp implements the Transformable interface. -func (p *Filter) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *Filter) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewFilter(p.Expression, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (p *Filter) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - expr, err := p.Expression.TransformUp(f) - if err != nil { - return nil, err - } + return NewFilter(p.Expression, children[0]), nil +} - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithExpressions implements the Expressioner interface. 
+func (p *Filter) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), 1) } - return NewFilter(expr, child), nil + return NewFilter(exprs[0], p.Child), nil } -func (p Filter) String() string { +func (p *Filter) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("Filter(%s)", p.Expression) _ = pr.WriteChildren(p.Child.String()) @@ -66,7 +62,7 @@ func (p Filter) String() string { } // Expressions implements the Expressioner interface. -func (p Filter) Expressions() []sql.Expression { +func (p *Filter) Expressions() []sql.Expression { return []sql.Expression{p.Expression} } @@ -95,12 +91,12 @@ func (i *FilterIter) Next() (sql.Row, error) { return nil, err } - result, err := i.cond.Eval(i.ctx, row) + ok, err := sql.EvaluateCondition(i.ctx, i.cond, row) if err != nil { return nil, err } - if result == true { + if ok { return row, nil } } diff --git a/sql/plan/filter_test.go b/sql/plan/filter_test.go index f6f79062a..de5912fe7 100644 --- a/sql/plan/filter_test.go +++ b/sql/plan/filter_test.go @@ -3,10 +3,10 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestFilter(t *testing.T) { @@ -19,28 +19,32 @@ func TestFilter(t *testing.T) { {Name: "col3", Type: sql.Int32, Nullable: true}, {Name: "col4", Type: sql.Int64, Nullable: true}, } - child := mem.NewTable("test", childSchema) - err := child.Insert(sql.NewRow("col1_1", "col2_1", int32(1111), int64(2222))) - require.Nil(err) - err = child.Insert(sql.NewRow("col1_2", "col2_2", int32(3333), int64(4444))) - require.Nil(err) - err = child.Insert(sql.NewRow("col1_3", "col2_3", nil, int64(4444))) - require.Nil(err) + child := 
memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("col1_1", "col2_1", int32(1111), int64(2222)), + sql.NewRow("col1_2", "col2_2", int32(3333), int64(4444)), + sql.NewRow("col1_3", "col2_3", nil, int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } f := NewFilter( expression.NewEquals( expression.NewGetField(0, sql.Text, "col1", true), expression.NewLiteral("col1_1", sql.Text)), - child) + NewResolvedTable(child)) require.Equal(1, len(f.Children())) iter, err := f.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err := iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal("col1_1", row[0]) @@ -53,14 +57,14 @@ func TestFilter(t *testing.T) { f = NewFilter(expression.NewEquals( expression.NewGetField(2, sql.Int32, "col3", true), expression.NewLiteral(int32(1111), - sql.Int32)), child) + sql.Int32)), NewResolvedTable(child)) iter, err = f.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err = iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal(int32(1111), row[2]) @@ -69,14 +73,14 @@ func TestFilter(t *testing.T) { f = NewFilter(expression.NewEquals( expression.NewGetField(3, sql.Int64, "col4", true), expression.NewLiteral(int64(4444), sql.Int64)), - child) + NewResolvedTable(child)) iter, err = f.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err = iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal(int32(3333), row[2]) diff --git a/sql/plan/generate.go b/sql/plan/generate.go new file mode 100644 index 000000000..c259b8d8d --- /dev/null +++ b/sql/plan/generate.go @@ -0,0 +1,134 @@ +package plan + +import ( + "fmt" + "io" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// Generate will explode rows using a generator. 
+type Generate struct { + UnaryNode + Column *expression.GetField +} + +// NewGenerate creates a new generate node. +func NewGenerate(child sql.Node, col *expression.GetField) *Generate { + return &Generate{UnaryNode{child}, col} +} + +// Schema implements the sql.Node interface. +func (g *Generate) Schema() sql.Schema { + s := g.Child.Schema() + col := s[g.Column.Index()] + s[g.Column.Index()] = &sql.Column{ + Name: g.Column.Name(), + Type: sql.UnderlyingType(col.Type), + Nullable: col.Nullable, + } + return s +} + +// RowIter implements the sql.Node interface. +func (g *Generate) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.Generate") + + childIter, err := g.Child.RowIter(ctx) + if err != nil { + return nil, err + } + + return sql.NewSpanIter(span, &generateIter{ + child: childIter, + idx: g.Column.Index(), + }), nil +} + +// Expressions implements the Expressioner interface. +func (g *Generate) Expressions() []sql.Expression { return []sql.Expression{g.Column} } + +// WithChildren implements the Node interface. +func (g *Generate) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(children), 1) + } + + return NewGenerate(children[0], g.Column), nil +} + +// WithExpressions implements the Expressioner interface. 
+func (g *Generate) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(g, len(exprs), 1) + } + + gf, ok := exprs[0].(*expression.GetField) + if !ok { + return nil, fmt.Errorf("Generate expects child to be expression.GetField, but is %T", exprs[0]) + } + + return NewGenerate(g.Child, gf), nil +} + +func (g *Generate) String() string { + tp := sql.NewTreePrinter() + _ = tp.WriteNode("Generate(%s)", g.Column) + _ = tp.WriteChildren(g.Child.String()) + return tp.String() +} + +type generateIter struct { + child sql.RowIter + idx int + + gen sql.Generator + row sql.Row +} + +func (i *generateIter) Next() (sql.Row, error) { + for { + if i.gen == nil { + var err error + i.row, err = i.child.Next() + if err != nil { + return nil, err + } + + i.gen, err = sql.ToGenerator(i.row[i.idx]) + if err != nil { + return nil, err + } + } + + val, err := i.gen.Next() + if err != nil { + if err == io.EOF { + if err := i.gen.Close(); err != nil { + return nil, err + } + + i.gen = nil + continue + } + return nil, err + } + + var row = make(sql.Row, len(i.row)) + copy(row, i.row) + row[i.idx] = val + return row, nil + } +} + +func (i *generateIter) Close() error { + if i.gen != nil { + if err := i.gen.Close(); err != nil { + _ = i.child.Close() + return err + } + } + + return i.child.Close() +} diff --git a/sql/plan/generate_test.go b/sql/plan/generate_test.go new file mode 100644 index 000000000..7f32db68c --- /dev/null +++ b/sql/plan/generate_test.go @@ -0,0 +1,85 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestGenerateRowIter(t *testing.T) { + require := require.New(t) + + child := newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + 
sql.RowsToRowIter( + sql.Row{"first", sql.NewArrayGenerator([]interface{}{"a", "b"}), int64(1)}, + sql.Row{"second", sql.NewArrayGenerator([]interface{}{"c", "d"}), int64(2)}, + ), + ) + + iter, err := NewGenerate( + child, + expression.NewGetFieldWithTable(1, sql.Array(sql.Text), "foo", "b", false), + ).RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"first", "a", int64(1)}, + {"first", "b", int64(1)}, + {"second", "c", int64(2)}, + {"second", "d", int64(2)}, + } + + require.Equal(expected, rows) +} + +func TestGenerateSchema(t *testing.T) { + require := require.New(t) + + schema := NewGenerate( + newFakeNode( + sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "b", Type: sql.Array(sql.Text), Source: "foo"}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + }, + nil, + ), + expression.NewGetField(1, sql.Array(sql.Text), "foobar", false), + ).Schema() + + expected := sql.Schema{ + {Name: "a", Type: sql.Text, Source: "foo"}, + {Name: "foobar", Type: sql.Text}, + {Name: "c", Type: sql.Int64, Source: "foo"}, + } + + require.Equal(expected, schema) +} + +type fakeNode struct { + schema sql.Schema + iter sql.RowIter +} + +func newFakeNode(s sql.Schema, iter sql.RowIter) *fakeNode { + return &fakeNode{s, iter} +} + +func (n *fakeNode) Children() []sql.Node { return nil } +func (n *fakeNode) Resolved() bool { return true } +func (n *fakeNode) Schema() sql.Schema { return n.schema } +func (n *fakeNode) RowIter(*sql.Context) (sql.RowIter, error) { return n.iter, nil } +func (n *fakeNode) String() string { return "fakeNode" } +func (*fakeNode) WithChildren(children ...sql.Node) (sql.Node, error) { + panic("placeholder") +} diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index 70f2992ea..4a9fb76b0 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -2,13 +2,14 @@ package plan import ( "fmt" + "hash/crc64" "io" "strings" opentracing 
"github.com/opentracing/opentracing-go" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) // ErrGroupBy is returned when the aggregation is not supported. @@ -27,7 +28,6 @@ func NewGroupBy( grouping []sql.Expression, child sql.Node, ) *GroupBy { - return &GroupBy{ UnaryNode: UnaryNode{Child: child}, Aggregate: aggregate, @@ -53,10 +53,16 @@ func (p *GroupBy) Schema() sql.Schema { name = e.String() } + var table string + if t, ok := e.(sql.Tableable); ok { + table = t.Table() + } + s[i] = &sql.Column{ Name: name, Type: e.Type(), Nullable: e.IsNullable(), + Source: table, } } @@ -75,39 +81,48 @@ func (p *GroupBy) RowIter(ctx *sql.Context) (sql.RowIter, error) { span.Finish() return nil, err } - return sql.NewSpanIter(span, newGroupByIter(ctx, p, i)), nil + + var iter sql.RowIter + if len(p.Grouping) == 0 { + iter = newGroupByIter(ctx, p.Aggregate, i) + } else { + iter = newGroupByGroupingIter(ctx, p.Aggregate, p.Grouping, i) + } + + return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Transformable interface. -func (p *GroupBy) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (p *GroupBy) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return f(NewGroupBy(p.Aggregate, p.Grouping, child)) + + return NewGroupBy(p.Aggregate, p.Grouping, children[0]), nil } -// TransformExpressionsUp implements the Transformable interface. 
-func (p *GroupBy) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - aggregate, err := transformExpressionsUp(f, p.Aggregate) - if err != nil { - return nil, err +// WithExpressions implements the Node interface. +func (p *GroupBy) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + expected := len(p.Aggregate) + len(p.Grouping) + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) } - grouping, err := transformExpressionsUp(f, p.Grouping) - if err != nil { - return nil, err + var agg = make([]sql.Expression, len(p.Aggregate)) + for i := 0; i < len(p.Aggregate); i++ { + agg[i] = exprs[i] } - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err + var grouping = make([]sql.Expression, len(p.Grouping)) + offset := len(p.Aggregate) + for i := 0; i < len(p.Grouping); i++ { + grouping[i] = exprs[i+offset] } - return NewGroupBy(aggregate, grouping, child), nil + return NewGroupBy(agg, grouping, p.Child), nil } -func (p GroupBy) String() string { +func (p *GroupBy) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("GroupBy") @@ -130,7 +145,7 @@ func (p GroupBy) String() string { } // Expressions implements the Expressioner interface. -func (p GroupBy) Expressions() []sql.Expression { +func (p *GroupBy) Expressions() []sql.Expression { var exprs []sql.Expression exprs = append(exprs, p.Aggregate...) exprs = append(exprs, p.Grouping...) 
@@ -138,144 +153,164 @@ func (p GroupBy) Expressions() []sql.Expression { } type groupByIter struct { - p *GroupBy - childIter sql.RowIter - rows []sql.Row - idx int + aggregate []sql.Expression + child sql.RowIter ctx *sql.Context + buf []sql.Row + done bool } -func newGroupByIter(s *sql.Context, p *GroupBy, child sql.RowIter) *groupByIter { +func newGroupByIter(ctx *sql.Context, aggregate []sql.Expression, child sql.RowIter) *groupByIter { return &groupByIter{ - p: p, - childIter: child, - rows: nil, - idx: -1, - ctx: s, + aggregate: aggregate, + child: child, + ctx: ctx, + buf: make([]sql.Row, len(aggregate)), } } func (i *groupByIter) Next() (sql.Row, error) { - if i.idx == -1 { - err := i.computeRows() + if i.done { + return nil, io.EOF + } + + i.done = true + + for j, a := range i.aggregate { + i.buf[j] = fillBuffer(a) + } + + for { + row, err := i.child.Next() if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if err := updateBuffers(i.ctx, i.buf, i.aggregate, row); err != nil { return nil, err } - i.idx = 0 - } - if i.idx >= len(i.rows) { - return nil, io.EOF } - row := i.rows[i.idx] - i.idx++ - return row, nil + + return evalBuffers(i.ctx, i.buf, i.aggregate) } func (i *groupByIter) Close() error { - i.rows = nil - return i.childIter.Close() + i.buf = nil + return i.child.Close() +} + +type groupByGroupingIter struct { + aggregate []sql.Expression + grouping []sql.Expression + aggregation sql.KeyValueCache + keys []uint64 + pos int + child sql.RowIter + ctx *sql.Context + dispose sql.DisposeFunc } -func (i *groupByIter) computeRows() error { - rows := []sql.Row{} +func newGroupByGroupingIter( + ctx *sql.Context, + aggregate, grouping []sql.Expression, + child sql.RowIter, +) *groupByGroupingIter { + return &groupByGroupingIter{ + aggregate: aggregate, + grouping: grouping, + child: child, + ctx: ctx, + } +} + +func (i *groupByGroupingIter) Next() (sql.Row, error) { + if i.aggregation == nil { + i.aggregation, i.dispose = 
i.ctx.Memory.NewHistoryCache() + if err := i.compute(); err != nil { + return nil, err + } + } + + if i.pos >= len(i.keys) { + return nil, io.EOF + } + + buffers, err := i.aggregation.Get(i.keys[i.pos]) + if err != nil { + return nil, err + } + i.pos++ + return evalBuffers(i.ctx, buffers.([]sql.Row), i.aggregate) +} + +func (i *groupByGroupingIter) compute() error { for { - childRow, err := i.childIter.Next() - if err == io.EOF { - break + row, err := i.child.Next() + if err != nil { + if err == io.EOF { + break + } + return err } + + key, err := groupingKey(i.ctx, i.grouping, row) if err != nil { return err } - rows = append(rows, childRow) - } - rows, err := groupBy(i.ctx, rows, i.p.Aggregate, i.p.Grouping) - if err != nil { - return err - } + if _, err := i.aggregation.Get(key); err != nil { + var buf = make([]sql.Row, len(i.aggregate)) + for j, a := range i.aggregate { + buf[j] = fillBuffer(a) + } - i.rows = rows - return nil -} + if err := i.aggregation.Put(key, buf); err != nil { + return err + } -func groupBy( - ctx *sql.Context, - rows []sql.Row, - aggExpr []sql.Expression, - groupExpr []sql.Expression, -) ([]sql.Row, error) { - //TODO: currently, we first group all rows, and then - // compute aggregations in a separate stage. We should - // compute aggregations incrementally instead. 
- - hrows := map[interface{}][]sql.Row{} - for _, row := range rows { - key, err := groupingKey(ctx, groupExpr, row) + i.keys = append(i.keys, key) + } + + b, err := i.aggregation.Get(key) if err != nil { - return nil, err + return err } - hrows[key] = append(hrows[key], row) - } - result := make([]sql.Row, 0, len(hrows)) - for _, rows := range hrows { - row, err := aggregate(ctx, aggExpr, rows) + err = updateBuffers(i.ctx, b.([]sql.Row), i.aggregate, row) if err != nil { - return nil, err + return err } - result = append(result, row) } - return result, nil + return nil +} + +func (i *groupByGroupingIter) Close() error { + i.aggregation = nil + return i.child.Close() } +var table = crc64.MakeTable(crc64.ISO) + func groupingKey( ctx *sql.Context, exprs []sql.Expression, row sql.Row, -) (interface{}, error) { - //TODO: use a more robust/efficient way of calculating grouping keys. +) (uint64, error) { vals := make([]string, 0, len(exprs)) + for _, expr := range exprs { v, err := expr.Eval(ctx, row) if err != nil { - return nil, err + return 0, err } vals = append(vals, fmt.Sprintf("%#v", v)) } - return strings.Join(vals, ","), nil -} - -func aggregate( - ctx *sql.Context, - exprs []sql.Expression, - rows []sql.Row, -) (sql.Row, error) { - buffers := make([]sql.Row, len(exprs)) - for i, expr := range exprs { - buffers[i] = fillBuffer(expr) - } - - for _, row := range rows { - for i, expr := range exprs { - if err := updateBuffer(ctx, buffers, i, expr, row); err != nil { - return nil, err - } - } - } - - fields := make([]interface{}, 0, len(exprs)) - for i, expr := range exprs { - field, err := expr.Eval(ctx, buffers[i]) - if err != nil { - return nil, err - } - - fields = append(fields, field) - } - - return sql.NewRow(fields...), nil + return crc64.Checksum([]byte(strings.Join(vals, ",")), table), nil } func fillBuffer(expr sql.Expression) sql.Row { @@ -285,10 +320,25 @@ func fillBuffer(expr sql.Expression) sql.Row { case *expression.Alias: return fillBuffer(n.Child) 
default: - return sql.NewRow(nil) + return nil } } +func updateBuffers( + ctx *sql.Context, + buffers []sql.Row, + aggregate []sql.Expression, + row sql.Row, +) error { + for i, a := range aggregate { + if err := updateBuffer(ctx, buffers, i, a, row); err != nil { + return err + } + } + + return nil +} + func updateBuffer( ctx *sql.Context, buffers []sql.Row, @@ -298,14 +348,51 @@ func updateBuffer( ) error { switch n := expr.(type) { case sql.Aggregation: - n.Update(ctx, buffers[idx], row) - return nil + return n.Update(ctx, buffers[idx], row) case *expression.Alias: return updateBuffer(ctx, buffers, idx, n.Child, row) - case *expression.GetField: - buffers[idx] = row + default: + val, err := expr.Eval(ctx, row) + if err != nil { + return err + } + buffers[idx] = sql.NewRow(val) return nil + } +} + +func evalBuffers( + ctx *sql.Context, + buffers []sql.Row, + aggregate []sql.Expression, +) (sql.Row, error) { + var row = make(sql.Row, len(aggregate)) + + for i, agg := range aggregate { + val, err := evalBuffer(ctx, agg, buffers[i]) + if err != nil { + return nil, err + } + row[i] = val + } + + return row, nil +} + +func evalBuffer( + ctx *sql.Context, + aggregation sql.Expression, + buffer sql.Row, +) (interface{}, error) { + switch n := aggregation.(type) { + case sql.Aggregation: + return n.Eval(ctx, buffer) + case *expression.Alias: + return evalBuffer(ctx, n.Child, buffer) default: - return ErrGroupBy.New(n.String()) + if len(buffer) > 0 { + return buffer[0], nil + } + return nil, nil } } diff --git a/sql/plan/group_by_test.go b/sql/plan/group_by_test.go index 976a171a2..69bbb83a5 100644 --- a/sql/plan/group_by_test.go +++ b/sql/plan/group_by_test.go @@ -3,46 +3,46 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/src-d/go-mysql-server/sql/expression/function/aggregation" "github.com/stretchr/testify/require" - 
"gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" ) -func TestGroupBy_Schema(t *testing.T) { +func TestGroupBySchema(t *testing.T) { require := require.New(t) - child := mem.NewTable("test", sql.Schema{}) + child := memory.NewTable("test", nil) agg := []sql.Expression{ expression.NewAlias(expression.NewLiteral("s", sql.Text), "c1"), expression.NewAlias(aggregation.NewCount(expression.NewStar()), "c2"), } - gb := NewGroupBy(agg, nil, child) + gb := NewGroupBy(agg, nil, NewResolvedTable(child)) require.Equal(sql.Schema{ {Name: "c1", Type: sql.Text}, - {Name: "c2", Type: sql.Int32}, + {Name: "c2", Type: sql.Int64}, }, gb.Schema()) } -func TestGroupBy_Resolved(t *testing.T) { +func TestGroupByResolved(t *testing.T) { require := require.New(t) - child := mem.NewTable("test", sql.Schema{}) + child := memory.NewTable("test", nil) agg := []sql.Expression{ expression.NewAlias(aggregation.NewCount(expression.NewStar()), "c2"), } - gb := NewGroupBy(agg, nil, child) + gb := NewGroupBy(agg, nil, NewResolvedTable(child)) require.True(gb.Resolved()) agg = []sql.Expression{ expression.NewStar(), } - gb = NewGroupBy(agg, nil, child) + gb = NewGroupBy(agg, nil, NewResolvedTable(child)) require.False(gb.Resolved()) } -func TestGroupBy_RowIter(t *testing.T) { +func TestGroupByRowIter(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() @@ -50,12 +50,19 @@ func TestGroupBy_RowIter(t *testing.T) { {Name: "col1", Type: sql.Text}, {Name: "col2", Type: sql.Int64}, } - child := mem.NewTable("test", childSchema) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_2", int64(4444))) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_2", int64(4444))) + child := memory.NewTable("test", childSchema) + + rows := 
[]sql.Row{ + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := NewSort( []SortField{ @@ -76,7 +83,7 @@ func TestGroupBy_RowIter(t *testing.T) { expression.NewGetField(0, sql.Text, "col1", true), expression.NewGetField(1, sql.Int64, "col2", true), }, - child, + NewResolvedTable(child), )) require.Equal(1, len(p.Children())) @@ -89,7 +96,16 @@ func TestGroupBy_RowIter(t *testing.T) { require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1]) } -func TestGroupBy_Error(t *testing.T) { +func TestGroupByEvalEmptyBuffer(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + r, err := evalBuffer(ctx, expression.NewGetField(0, sql.Text, "col1", true), sql.Row{}) + require.NoError(err) + require.Nil(r) +} + +func TestGroupByAggregationGrouping(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() @@ -98,12 +114,19 @@ func TestGroupBy_Error(t *testing.T) { {Name: "col2", Type: sql.Int64}, } - child := mem.NewTable("test", childSchema) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_1", int64(2222))) - child.Insert(sql.NewRow("col1_2", int64(4444))) - child.Insert(sql.NewRow("col1_1", int64(1111))) - child.Insert(sql.NewRow("col1_2", int64(4444))) + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + sql.NewRow("col1_1", int64(1111)), + sql.NewRow("col1_2", int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } p := NewGroupBy( []sql.Expression{ @@ -114,9 +137,88 @@ func TestGroupBy_Error(t *testing.T) { aggregation.NewCount(expression.NewGetField(0, sql.Text, "col1", true)), 
expression.NewGetField(1, sql.Int64, "col2", true), }, - child, + NewResolvedTable(child), + ) + + rows, err := sql.NodeToRows(ctx, p) + require.NoError(err) + + expected := []sql.Row{ + {int64(3), false}, + {int64(2), false}, + } + + require.Equal(expected, rows) +} + +func BenchmarkGroupBy(b *testing.B) { + table := benchmarkTable(b) + + node := NewGroupBy( + []sql.Expression{ + aggregation.NewMax( + expression.NewGetField(1, sql.Int64, "b", false), + ), + }, + nil, + NewResolvedTable(table), ) - _, err := sql.NodeToRows(ctx, p) - require.Error(err) + expected := []sql.Row{{int64(200)}} + + bench := func(node sql.Node, expected []sql.Row) func(*testing.B) { + return func(b *testing.B) { + require := require.New(b) + + for i := 0; i < b.N; i++ { + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.ElementsMatch(expected, rows) + } + } + } + + b.Run("no grouping", bench(node, expected)) + + node = NewGroupBy( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + aggregation.NewMax( + expression.NewGetField(1, sql.Int64, "b", false), + ), + }, + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + }, + NewResolvedTable(table), + ) + + expected = []sql.Row{} + for i := int64(0); i < 50; i++ { + expected = append(expected, sql.NewRow(i, int64(200))) + } + + b.Run("grouping", bench(node, expected)) +} + +func benchmarkTable(t testing.TB) sql.Table { + t.Helper() + require := require.New(t) + + table := memory.NewTable("test", sql.Schema{ + {Name: "a", Type: sql.Int64}, + {Name: "b", Type: sql.Int64}, + }) + + for i := int64(0); i < 50; i++ { + for j := int64(200); j > 0; j-- { + row := sql.NewRow(i, j) + require.NoError(table.Insert(sql.NewEmptyContext(), row)) + } + } + + return table } diff --git a/sql/plan/having.go b/sql/plan/having.go new file mode 100644 index 000000000..48a8bde25 --- /dev/null +++ b/sql/plan/having.go @@ -0,0 
+1,62 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// Having node is a filter that supports aggregate expressions. A having node +// is identical to a filter node in behaviour. The difference is that some +// analyzer rules work specifically on having clauses and not filters. For +// that reason, Having is a completely new node instead of using just filter. +type Having struct { + UnaryNode + Cond sql.Expression +} + +var _ sql.Expressioner = (*Having)(nil) + +// NewHaving creates a new having node. +func NewHaving(cond sql.Expression, child sql.Node) *Having { + return &Having{UnaryNode{Child: child}, cond} +} + +// Resolved implements the sql.Node interface. +func (h *Having) Resolved() bool { return h.Cond.Resolved() && h.Child.Resolved() } + +// Expressions implements the sql.Expressioner interface. +func (h *Having) Expressions() []sql.Expression { return []sql.Expression{h.Cond} } + +// WithChildren implements the Node interface. +func (h *Having) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(children), 1) + } + + return NewHaving(h.Cond, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (h *Having) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(h, len(exprs), 1) + } + + return NewHaving(exprs[0], h.Child), nil +} + +// RowIter implements the sql.Node interface. 
+func (h *Having) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.Having") + iter, err := h.Child.RowIter(ctx) + if err != nil { + span.Finish() + return nil, err + } + + return sql.NewSpanIter(span, NewFilterIter(ctx, h.Cond, iter)), nil +} + +func (h *Having) String() string { + p := sql.NewTreePrinter() + _ = p.WriteNode("Having(%s)", h.Cond) + _ = p.WriteChildren(h.Child.String()) + return p.String() +} diff --git a/sql/plan/having_test.go b/sql/plan/having_test.go new file mode 100644 index 000000000..9651a631a --- /dev/null +++ b/sql/plan/having_test.go @@ -0,0 +1,95 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestHaving(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + childSchema := sql.Schema{ + {Name: "col1", Type: sql.Text, Nullable: true}, + {Name: "col2", Type: sql.Text, Nullable: true}, + {Name: "col3", Type: sql.Int32, Nullable: true}, + {Name: "col4", Type: sql.Int64, Nullable: true}, + } + child := memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("col1_1", "col2_1", int32(1111), int64(2222)), + sql.NewRow("col1_2", "col2_2", int32(3333), int64(4444)), + sql.NewRow("col1_3", "col2_3", nil, int64(4444)), + } + + for _, r := range rows { + require.NoError(child.Insert(sql.NewEmptyContext(), r)) + } + + f := NewHaving( + expression.NewEquals( + expression.NewGetField(0, sql.Text, "col1", true), + expression.NewLiteral("col1_1", sql.Text)), + NewResolvedTable(child), + ) + + require.Equal(1, len(f.Children())) + + iter, err := f.RowIter(ctx) + require.NoError(err) + require.NotNil(iter) + + row, err := iter.Next() + require.NoError(err) + require.NotNil(row) + + require.Equal("col1_1", row[0]) + require.Equal("col2_1", row[1]) + + row, err = iter.Next() + require.NotNil(err) + 
require.Nil(row) + + f = NewHaving( + expression.NewEquals( + expression.NewGetField(2, sql.Int32, "col3", true), + expression.NewLiteral(int32(1111), sql.Int32), + ), + NewResolvedTable(child), + ) + + iter, err = f.RowIter(ctx) + require.NoError(err) + require.NotNil(iter) + + row, err = iter.Next() + require.NoError(err) + require.NotNil(row) + + require.Equal(int32(1111), row[2]) + require.Equal(int64(2222), row[3]) + + f = NewHaving( + expression.NewEquals( + expression.NewGetField(3, sql.Int64, "col4", true), + expression.NewLiteral(int64(4444), sql.Int64), + ), + NewResolvedTable(child), + ) + + iter, err = f.RowIter(ctx) + require.NoError(err) + require.NotNil(iter) + + row, err = iter.Next() + require.NoError(err) + require.NotNil(row) + + require.Equal(int32(3333), row[2]) + require.Equal(int64(4444), row[3]) +} diff --git a/sql/plan/innerjoin.go b/sql/plan/innerjoin.go deleted file mode 100644 index 6efd36bc2..000000000 --- a/sql/plan/innerjoin.go +++ /dev/null @@ -1,119 +0,0 @@ -package plan - -import ( - "reflect" - - opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" -) - -// InnerJoin is an inner join between two tables. -type InnerJoin struct { - BinaryNode - Cond sql.Expression -} - -// NewInnerJoin creates a new inner join node from two tables. -func NewInnerJoin(left, right sql.Node, cond sql.Expression) *InnerJoin { - return &InnerJoin{ - BinaryNode: BinaryNode{ - Left: left, - Right: right, - }, - Cond: cond, - } -} - -// Schema implements the Node interface. -func (j *InnerJoin) Schema() sql.Schema { - return append(j.Left.Schema(), j.Right.Schema()...) -} - -// Resolved implements the Resolvable interface. -func (j *InnerJoin) Resolved() bool { - return j.Left.Resolved() && j.Right.Resolved() && j.Cond.Resolved() -} - -// RowIter implements the Node interface. 
-func (j *InnerJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { - var left, right string - if leftTable, ok := j.Left.(sql.Nameable); ok { - left = leftTable.Name() - } else { - left = reflect.TypeOf(j.Left).String() - } - - if rightTable, ok := j.Right.(sql.Nameable); ok { - right = rightTable.Name() - } else { - right = reflect.TypeOf(j.Right).String() - } - - span, ctx := ctx.Span("plan.InnerJoin", opentracing.Tags{ - "left": left, - "right": right, - }) - - l, err := j.Left.RowIter(ctx) - if err != nil { - span.Finish() - return nil, err - } - - return sql.NewSpanIter(span, NewFilterIter( - ctx, - j.Cond, - &crossJoinIterator{ - l: l, - rp: j.Right, - s: ctx, - }, - )), nil -} - -// TransformUp implements the Transformable interface. -func (j *InnerJoin) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := j.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformUp(f) - if err != nil { - return nil, err - } - - return f(NewInnerJoin(left, right, j.Cond)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (j *InnerJoin) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := j.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - right, err := j.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - cond, err := j.Cond.TransformUp(f) - if err != nil { - return nil, err - } - - return NewInnerJoin(left, right, cond), nil -} - -func (j InnerJoin) String() string { - pr := sql.NewTreePrinter() - _ = pr.WriteNode("InnerJoin(%s)", j.Cond) - _ = pr.WriteChildren(j.Left.String(), j.Right.String()) - return pr.String() -} - -// Expressions implements the Expressioner interface. 
-func (j InnerJoin) Expressions() []sql.Expression { - return []sql.Expression{j.Cond} -} diff --git a/sql/plan/innerjoin_test.go b/sql/plan/innerjoin_test.go deleted file mode 100644 index 6075246a1..000000000 --- a/sql/plan/innerjoin_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package plan - -import ( - "testing" - - "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" -) - -func TestInnerJoin(t *testing.T) { - require := require.New(t) - finalSchema := append(lSchema, rSchema...) - - ltable := mem.NewTable("left", lSchema) - rtable := mem.NewTable("right", rSchema) - insertData(t, ltable) - insertData(t, rtable) - - j := NewInnerJoin(ltable, rtable, expression.NewEquals( - expression.NewGetField(0, sql.Text, "lcol1", false), - expression.NewGetField(4, sql.Text, "rcol1", false), - )) - - require.Equal(finalSchema, j.Schema()) - - rows := collectRows(t, j) - require.Len(rows, 2) - - require.Equal([]sql.Row{ - {"col1_1", "col2_1", int32(1111), int64(2222), "col1_1", "col2_1", int32(1111), int64(2222)}, - {"col1_2", "col2_2", int32(3333), int64(4444), "col1_2", "col2_2", int32(3333), int64(4444)}, - }, rows) -} - -func TestInnerJoinEmpty(t *testing.T) { - require := require.New(t) - ctx := sql.NewEmptyContext() - - ltable := mem.NewTable("left", lSchema) - rtable := mem.NewTable("right", rSchema) - - j := NewInnerJoin(ltable, rtable, expression.NewEquals( - expression.NewGetField(0, sql.Text, "lcol1", false), - expression.NewGetField(4, sql.Text, "rcol1", false), - )) - - iter, err := j.RowIter(ctx) - require.NoError(err) - - assertRows(t, iter, 0) -} diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 521053417..06e638d26 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -1,28 +1,36 @@ package plan import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" 
"io" "strings" - - "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) // ErrInsertIntoNotSupported is thrown when a table doesn't support inserts var ErrInsertIntoNotSupported = errors.NewKind("table doesn't support INSERT INTO") +var ErrReplaceIntoNotSupported = errors.NewKind("table doesn't support REPLACE INTO") +var ErrInsertIntoMismatchValueCount = errors.NewKind("number of values does not match number of columns provided") +var ErrInsertIntoUnsupportedValues = errors.NewKind("%T is unsupported for inserts") +var ErrInsertIntoDuplicateColumn = errors.NewKind("duplicate column name %v") +var ErrInsertIntoNonexistentColumn = errors.NewKind("invalid column name %v") +var ErrInsertIntoNonNullableDefaultNullColumn = errors.NewKind("column name '%v' is non-nullable but attempted to set default value of null") +var ErrInsertIntoNonNullableProvidedNull = errors.NewKind("column name '%v' is non-nullable but attempted to set a value of null") // InsertInto is a node describing the insertion into some table. type InsertInto struct { BinaryNode - Columns []string + Columns []string + IsReplace bool } // NewInsertInto creates an InsertInto node. 
-func NewInsertInto(dst, src sql.Node, cols []string) *InsertInto { +func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string) *InsertInto { return &InsertInto{ BinaryNode: BinaryNode{Left: dst, Right: src}, Columns: cols, + IsReplace: isReplace, } } @@ -36,15 +44,65 @@ func (p *InsertInto) Schema() sql.Schema { }} } +func getInsertable(node sql.Node) (sql.Inserter, error) { + switch node := node.(type) { + case sql.Inserter: + return node, nil + case *ResolvedTable: + return getInsertableTable(node.Table) + default: + return nil, ErrInsertIntoNotSupported.New() + } +} + +func getInsertableTable(t sql.Table) (sql.Inserter, error) { + switch t := t.(type) { + case sql.Inserter: + return t, nil + case sql.TableWrapper: + return getInsertableTable(t.Underlying()) + default: + return nil, ErrInsertIntoNotSupported.New() + } +} + // Execute inserts the rows in the database. func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { - insertable, ok := p.Left.(sql.Inserter) - if !ok { - return 0, ErrInsertIntoNotSupported.New() + insertable, err := getInsertable(p.Left) + if err != nil { + return 0, err + } + + var replaceable sql.Replacer + if p.IsReplace { + var ok bool + replaceable, ok = insertable.(sql.Replacer) + if !ok { + return 0, ErrReplaceIntoNotSupported.New() + } } dstSchema := p.Left.Schema() projExprs := make([]sql.Expression, len(dstSchema)) + + // If no columns are given, we assume the full schema in order + if len(p.Columns) == 0 { + p.Columns = make([]string, len(dstSchema)) + for i, f := range dstSchema { + p.Columns[i] = f.Name + } + } else { + err = p.validateColumns(ctx, dstSchema) + if err != nil { + return 0, err + } + } + + err = p.validateValueCount(ctx) + if err != nil { + return 0, err + } + for i, f := range dstSchema { found := false for j, col := range p.Columns { @@ -56,8 +114,10 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { } if !found { - def, _ := f.Type.Convert(nil) - projExprs[i] = 
expression.NewLiteral(def, f.Type) + if !f.Nullable && f.Default == nil { + return 0, ErrInsertIntoNonNullableDefaultNullColumn.New(f.Name) + } + projExprs[i] = expression.NewLiteral(f.Default, f.Type) } } @@ -74,17 +134,51 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) { if err == io.EOF { break } - if err != nil { _ = iter.Close() return i, err } - if err := insertable.Insert(row); err != nil { + err = p.validateNullability(ctx, dstSchema, row) + if err != nil { _ = iter.Close() return i, err } + // Convert integer values in row to specified type in schema + for colIdx, oldValue := range row { + dstColType := projExprs[colIdx].Type() + + if sql.IsInteger(dstColType) && oldValue != nil { + newValue, err := dstColType.Convert(oldValue) + if err != nil { + return i, err + } + + row[colIdx] = newValue + } + } + + if replaceable != nil { + if err = replaceable.Delete(ctx, row); err != nil { + if err != sql.ErrDeleteRowNotFound { + _ = iter.Close() + return i, err + } + } else { + i++ + } + + if err = replaceable.Insert(ctx, row); err != nil { + _ = iter.Close() + return i, err + } + } else { + if err := insertable.Insert(ctx, row); err != nil { + _ = iter.Close() + return i, err + } + } i++ } @@ -101,39 +195,64 @@ func (p *InsertInto) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(sql.NewRow(int64(n))), nil } -// TransformUp implements the Transformable interface. -func (p *InsertInto) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - left, err := p.Left.TransformUp(f) - if err != nil { - return nil, err - } - - right, err := p.Right.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. 
+func (p *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2) } - return f(NewInsertInto(left, right, p.Columns)) + return NewInsertInto(children[0], children[1], p.IsReplace, p.Columns), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *InsertInto) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - left, err := p.Left.TransformExpressionsUp(f) - if err != nil { - return nil, err +func (p InsertInto) String() string { + pr := sql.NewTreePrinter() + if p.IsReplace { + _ = pr.WriteNode("Replace(%s)", strings.Join(p.Columns, ", ")) + } else { + _ = pr.WriteNode("Insert(%s)", strings.Join(p.Columns, ", ")) } + _ = pr.WriteChildren(p.Left.String(), p.Right.String()) + return pr.String() +} - right, err := p.Right.TransformExpressionsUp(f) - if err != nil { - return nil, err +func (p *InsertInto) validateValueCount(ctx *sql.Context) error { + switch node := p.Right.(type) { + case *Values: + for _, exprTuple := range node.ExpressionTuples { + if len(exprTuple) != len(p.Columns) { + return ErrInsertIntoMismatchValueCount.New() + } + } + default: + return ErrInsertIntoUnsupportedValues.New(node) } + return nil +} - return NewInsertInto(left, right, p.Columns), nil +func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) error { + dstColNames := make(map[string]struct{}) + for _, dstCol := range dstSchema { + dstColNames[dstCol.Name] = struct{}{} + } + columnNames := make(map[string]struct{}) + for _, columnName := range p.Columns { + if _, exists := dstColNames[columnName]; !exists { + return ErrInsertIntoNonexistentColumn.New(columnName) + } + if _, exists := columnNames[columnName]; !exists { + columnNames[columnName] = struct{}{} + } else { + return ErrInsertIntoDuplicateColumn.New(columnName) + } + } + return nil } -func (p InsertInto) String() string { - pr := sql.NewTreePrinter() - _ = 
pr.WriteNode("Insert(%s)", strings.Join(p.Columns, ", ")) - _ = pr.WriteChildren(p.Left.String(), p.Right.String()) - return pr.String() +func (p *InsertInto) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error { + for i, col := range dstSchema { + if !col.Nullable && row[i] == nil { + return ErrInsertIntoNonNullableProvidedNull.New(col.Name) + } + } + return nil } diff --git a/sql/plan/join.go b/sql/plan/join.go new file mode 100644 index 000000000..97f1abc13 --- /dev/null +++ b/sql/plan/join.go @@ -0,0 +1,549 @@ +package plan + +import ( + "io" + "os" + "reflect" + "strings" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/src-d/go-mysql-server/sql" +) + +const ( + inMemoryJoinKey = "INMEMORY_JOINS" + inMemoryJoinSessionVar = "inmemory_joins" +) + +var useInMemoryJoins = shouldUseMemoryJoinsByEnv() + +func shouldUseMemoryJoinsByEnv() bool { + v := strings.TrimSpace(strings.ToLower(os.Getenv(inMemoryJoinKey))) + return v == "on" || v == "1" +} + +// InnerJoin is an inner join between two tables. +type InnerJoin struct { + BinaryNode + Cond sql.Expression +} + +// NewInnerJoin creates a new inner join node from two tables. +func NewInnerJoin(left, right sql.Node, cond sql.Expression) *InnerJoin { + return &InnerJoin{ + BinaryNode: BinaryNode{ + Left: left, + Right: right, + }, + Cond: cond, + } +} + +// Schema implements the Node interface. +func (j *InnerJoin) Schema() sql.Schema { + return append(j.Left.Schema(), j.Right.Schema()...) +} + +// Resolved implements the Resolvable interface. +func (j *InnerJoin) Resolved() bool { + return j.Left.Resolved() && j.Right.Resolved() && j.Cond.Resolved() +} + +// RowIter implements the Node interface. +func (j *InnerJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return joinRowIter(ctx, innerJoin, j.Left, j.Right, j.Cond) +} + +// WithChildren implements the Node interface. 
+func (j *InnerJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) + } + + return NewInnerJoin(children[0], children[1], j.Cond), nil +} + +// WithExpressions implements the Expressioner interface. +func (j *InnerJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) + } + + return NewInnerJoin(j.Left, j.Right, exprs[0]), nil +} + +func (j *InnerJoin) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("InnerJoin(%s)", j.Cond) + _ = pr.WriteChildren(j.Left.String(), j.Right.String()) + return pr.String() +} + +// Expressions implements the Expressioner interface. +func (j *InnerJoin) Expressions() []sql.Expression { + return []sql.Expression{j.Cond} +} + +// LeftJoin is a left join between two tables. +type LeftJoin struct { + BinaryNode + Cond sql.Expression +} + +// NewLeftJoin creates a new left join node from two tables. +func NewLeftJoin(left, right sql.Node, cond sql.Expression) *LeftJoin { + return &LeftJoin{ + BinaryNode: BinaryNode{ + Left: left, + Right: right, + }, + Cond: cond, + } +} + +// Schema implements the Node interface. +func (j *LeftJoin) Schema() sql.Schema { + return append(j.Left.Schema(), makeNullable(j.Right.Schema())...) +} + +// Resolved implements the Resolvable interface. +func (j *LeftJoin) Resolved() bool { + return j.Left.Resolved() && j.Right.Resolved() && j.Cond.Resolved() +} + +// RowIter implements the Node interface. +func (j *LeftJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return joinRowIter(ctx, leftJoin, j.Left, j.Right, j.Cond) +} + +// WithChildren implements the Node interface. 
+func (j *LeftJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) + } + + return NewLeftJoin(children[0], children[1], j.Cond), nil +} + +// WithExpressions implements the Expressioner interface. +func (j *LeftJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) + } + + return NewLeftJoin(j.Left, j.Right, exprs[0]), nil +} + +func (j *LeftJoin) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("LeftJoin(%s)", j.Cond) + _ = pr.WriteChildren(j.Left.String(), j.Right.String()) + return pr.String() +} + +// Expressions implements the Expressioner interface. +func (j *LeftJoin) Expressions() []sql.Expression { + return []sql.Expression{j.Cond} +} + +// RightJoin is a right join between two tables. +type RightJoin struct { + BinaryNode + Cond sql.Expression +} + +// NewRightJoin creates a new right join node from two tables. +func NewRightJoin(left, right sql.Node, cond sql.Expression) *RightJoin { + return &RightJoin{ + BinaryNode: BinaryNode{ + Left: left, + Right: right, + }, + Cond: cond, + } +} + +// Schema implements the Node interface. +func (j *RightJoin) Schema() sql.Schema { + return append(makeNullable(j.Left.Schema()), j.Right.Schema()...) +} + +// Resolved implements the Resolvable interface. +func (j *RightJoin) Resolved() bool { + return j.Left.Resolved() && j.Right.Resolved() && j.Cond.Resolved() +} + +// RowIter implements the Node interface. +func (j *RightJoin) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return joinRowIter(ctx, rightJoin, j.Left, j.Right, j.Cond) +} + +// WithChildren implements the Node interface. 
+func (j *RightJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) + } + + return NewRightJoin(children[0], children[1], j.Cond), nil +} + +// WithExpressions implements the Expressioner interface. +func (j *RightJoin) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(exprs), 1) + } + + return NewRightJoin(j.Left, j.Right, exprs[0]), nil +} + +func (j *RightJoin) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("RightJoin(%s)", j.Cond) + _ = pr.WriteChildren(j.Left.String(), j.Right.String()) + return pr.String() +} + +// Expressions implements the Expressioner interface. +func (j *RightJoin) Expressions() []sql.Expression { + return []sql.Expression{j.Cond} +} + +type joinType byte + +const ( + innerJoin joinType = iota + leftJoin + rightJoin +) + +func (t joinType) String() string { + switch t { + case innerJoin: + return "InnerJoin" + case leftJoin: + return "LeftJoin" + case rightJoin: + return "RightJoin" + default: + return "INVALID" + } +} + +func joinRowIter( + ctx *sql.Context, + typ joinType, + left, right sql.Node, + cond sql.Expression, +) (sql.RowIter, error) { + var leftName, rightName string + if leftTable, ok := left.(sql.Nameable); ok { + leftName = leftTable.Name() + } else { + leftName = reflect.TypeOf(left).String() + } + + if rightTable, ok := right.(sql.Nameable); ok { + rightName = rightTable.Name() + } else { + rightName = reflect.TypeOf(right).String() + } + + span, ctx := ctx.Span("plan."+typ.String(), opentracing.Tags{ + "left": leftName, + "right": rightName, + }) + + var inMemorySession bool + _, val := ctx.Get(inMemoryJoinSessionVar) + if val != nil { + inMemorySession = true + } + + var mode = unknownMode + if useInMemoryJoins || inMemorySession { + mode = memoryMode + } + + cache, dispose := ctx.Memory.NewRowsCache() + if typ == 
rightJoin { + r, err := right.RowIter(ctx) + if err != nil { + span.Finish() + return nil, err + } + return sql.NewSpanIter(span, &joinIter{ + typ: typ, + primary: r, + secondaryProvider: left, + ctx: ctx, + cond: cond, + mode: mode, + secondaryRows: cache, + dispose: dispose, + }), nil + } + + l, err := left.RowIter(ctx) + if err != nil { + span.Finish() + return nil, err + } + + return sql.NewSpanIter(span, &joinIter{ + typ: typ, + primary: l, + secondaryProvider: right, + ctx: ctx, + cond: cond, + mode: mode, + secondaryRows: cache, + dispose: dispose, + }), nil +} + +// joinMode defines the mode in which a join will be performed. +type joinMode byte + +const ( + // unknownMode is the default mode. It will start iterating without really + // knowing in which mode it will end up computing the join. If it + // iterates the right side fully one time and so far it fits in memory, + // then it will switch to memory mode. Otherwise, if at some point during + // this first iteration it finds that it does not fit in memory, will + // switch to multipass mode. + unknownMode joinMode = iota + // memoryMode computes all the join directly in memory iterating each + // side of the join exactly once. + memoryMode + // multipassMode computes the join by iterating the left side once, + // and the right side one time for each row in the left side. + multipassMode +) + +// joinIter is a generic iterator for all join types. 
+type joinIter struct { + typ joinType + primary sql.RowIter + secondaryProvider rowIterProvider + secondary sql.RowIter + ctx *sql.Context + cond sql.Expression + + primaryRow sql.Row + foundMatch bool + rowSize int + + // used to compute in-memory + mode joinMode + secondaryRows sql.RowsCache + pos int + dispose sql.DisposeFunc +} + +func (i *joinIter) Dispose() { + if i.dispose != nil { + i.dispose() + i.dispose = nil + } +} + +func (i *joinIter) loadPrimary() error { + if i.primaryRow == nil { + r, err := i.primary.Next() + if err != nil { + if err == io.EOF { + i.Dispose() + } + return err + } + + i.primaryRow = r + i.foundMatch = false + } + + return nil +} + +func (i *joinIter) loadSecondaryInMemory() error { + iter, err := i.secondaryProvider.RowIter(i.ctx) + if err != nil { + return err + } + + for { + row, err := iter.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + if err := i.secondaryRows.Add(row); err != nil { + return err + } + } + + if len(i.secondaryRows.Get()) == 0 { + return io.EOF + } + + return nil +} + +func (i *joinIter) loadSecondary() (row sql.Row, err error) { + if i.mode == memoryMode { + if len(i.secondaryRows.Get()) == 0 { + if err = i.loadSecondaryInMemory(); err != nil { + return nil, err + } + } + + if i.pos >= len(i.secondaryRows.Get()) { + i.primaryRow = nil + i.pos = 0 + return nil, io.EOF + } + + row := i.secondaryRows.Get()[i.pos] + i.pos++ + return row, nil + } + + if i.secondary == nil { + var iter sql.RowIter + iter, err = i.secondaryProvider.RowIter(i.ctx) + if err != nil { + return nil, err + } + + i.secondary = iter + } + + rightRow, err := i.secondary.Next() + if err != nil { + if err == io.EOF { + i.secondary = nil + i.primaryRow = nil + + // If we got to this point and the mode is still unknown it means + // the right side fits in memory, so the mode changes to memory + // join. 
+ if i.mode == unknownMode { + i.mode = memoryMode + } + + return nil, io.EOF + } + return nil, err + } + + if i.mode == unknownMode { + var switchToMultipass bool + if !i.ctx.Memory.HasAvailable() { + switchToMultipass = true + } else { + err := i.secondaryRows.Add(rightRow) + if err != nil && !sql.ErrNoMemoryAvailable.Is(err) { + return nil, err + } + } + + if switchToMultipass { + i.Dispose() + i.secondaryRows = nil + i.mode = multipassMode + } + } + + return rightRow, nil +} + +func (i *joinIter) Next() (sql.Row, error) { + for { + if err := i.loadPrimary(); err != nil { + return nil, err + } + + primary := i.primaryRow + secondary, err := i.loadSecondary() + if err != nil { + if err == io.EOF { + if !i.foundMatch && (i.typ == leftJoin || i.typ == rightJoin) { + return i.buildRow(primary, nil), nil + } + continue + } + return nil, err + } + + row := i.buildRow(primary, secondary) + v, err := i.cond.Eval(i.ctx, row) + if err != nil { + return nil, err + } + + if v == false { + continue + } + + i.foundMatch = true + return row, nil + } +} + +// buildRow builds the resulting row using the rows from the primary and +// secondary branches depending on the join type. 
+func (i *joinIter) buildRow(primary, secondary sql.Row) sql.Row { + var row sql.Row + if i.rowSize > 0 { + row = make(sql.Row, i.rowSize) + } else { + row = make(sql.Row, len(primary)+len(secondary)) + i.rowSize = len(row) + } + + switch i.typ { + case rightJoin: + copy(row, secondary) + copy(row[i.rowSize-len(primary):], primary) + default: + copy(row, primary) + copy(row[len(primary):], secondary) + } + + return row +} + +func (i *joinIter) Close() (err error) { + i.Dispose() + i.secondary = nil + + if i.primary != nil { + if err = i.primary.Close(); err != nil { + if i.secondary != nil { + _ = i.secondary.Close() + } + return err + } + + } + + if i.secondary != nil { + err = i.secondary.Close() + } + + return err +} + +// makeNullable will return a copy of the received columns, but all of them +// will be turned into nullable columns. +func makeNullable(cols []*sql.Column) []*sql.Column { + var result = make([]*sql.Column, len(cols)) + for i := 0; i < len(cols); i++ { + col := *cols[i] + col.Nullable = true + result[i] = &col + } + return result +} diff --git a/sql/plan/join_test.go b/sql/plan/join_test.go new file mode 100644 index 000000000..07797602b --- /dev/null +++ b/sql/plan/join_test.go @@ -0,0 +1,288 @@ +package plan + +import ( + "context" + "fmt" + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestJoinSchema(t *testing.T) { + t1 := NewResolvedTable(memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Int64}, + })) + + t2 := NewResolvedTable(memory.NewTable("bar", sql.Schema{ + {Name: "b", Source: "bar", Type: sql.Int64}, + })) + + t.Run("inner", func(t *testing.T) { + j := NewInnerJoin(t1, t2, nil) + result := j.Schema() + + require.Equal(t, sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Int64}, + {Name: "b", Source: "bar", Type: sql.Int64}, + }, result) + }) + + 
t.Run("left", func(t *testing.T) { + j := NewLeftJoin(t1, t2, nil) + result := j.Schema() + + require.Equal(t, sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Int64}, + {Name: "b", Source: "bar", Type: sql.Int64, Nullable: true}, + }, result) + }) + + t.Run("right", func(t *testing.T) { + j := NewRightJoin(t1, t2, nil) + result := j.Schema() + + require.Equal(t, sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Int64, Nullable: true}, + {Name: "b", Source: "bar", Type: sql.Int64}, + }, result) + }) +} + +func TestInnerJoin(t *testing.T) { + testInnerJoin(t, sql.NewEmptyContext()) +} + +func TestInMemoryInnerJoin(t *testing.T) { + ctx := sql.NewEmptyContext() + ctx.Set(inMemoryJoinSessionVar, sql.Text, "true") + testInnerJoin(t, ctx) +} + +func TestMultiPassInnerJoin(t *testing.T) { + ctx := sql.NewContext(context.TODO(), sql.WithMemoryManager( + sql.NewMemoryManager(mockReporter{2, 1}), + )) + testInnerJoin(t, ctx) +} + +func testInnerJoin(t *testing.T, ctx *sql.Context) { + t.Helper() + + require := require.New(t) + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) + insertData(t, ltable) + insertData(t, rtable) + + j := NewInnerJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + expression.NewEquals( + expression.NewGetField(0, sql.Text, "lcol1", false), + expression.NewGetField(4, sql.Text, "rcol1", false), + )) + + rows := collectRows(t, j) + require.Len(rows, 2) + + require.Equal([]sql.Row{ + {"col1_1", "col2_1", int32(1), int64(2), "col1_1", "col2_1", int32(1), int64(2)}, + {"col1_2", "col2_2", int32(3), int64(4), "col1_2", "col2_2", int32(3), int64(4)}, + }, rows) +} +func TestInnerJoinEmpty(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) + + j := NewInnerJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + expression.NewEquals( + expression.NewGetField(0, sql.Text, "lcol1", 
false), + expression.NewGetField(4, sql.Text, "rcol1", false), + )) + + iter, err := j.RowIter(ctx) + require.NoError(err) + + assertRows(t, iter, 0) +} + +func BenchmarkInnerJoin(b *testing.B) { + t1 := memory.NewTable("foo", sql.Schema{ + {Name: "a", Source: "foo", Type: sql.Int64}, + {Name: "b", Source: "foo", Type: sql.Text}, + }) + + t2 := memory.NewTable("bar", sql.Schema{ + {Name: "a", Source: "bar", Type: sql.Int64}, + {Name: "b", Source: "bar", Type: sql.Text}, + }) + + for i := 0; i < 5; i++ { + t1.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t1_%d", i))) + t2.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t2_%d", i))) + } + + n1 := NewInnerJoin( + NewResolvedTable(t1), + NewResolvedTable(t2), + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewGetField(2, sql.Int64, "a", false), + ), + ) + + n2 := NewFilter( + expression.NewEquals( + expression.NewGetField(0, sql.Int64, "a", false), + expression.NewGetField(2, sql.Int64, "a", false), + ), + NewCrossJoin( + NewResolvedTable(t1), + NewResolvedTable(t2), + ), + ) + + expected := []sql.Row{ + {int64(0), "t1_0", int64(0), "t2_0"}, + {int64(1), "t1_1", int64(1), "t2_1"}, + {int64(2), "t1_2", int64(2), "t2_2"}, + {int64(3), "t1_3", int64(3), "t2_3"}, + {int64(4), "t1_4", int64(4), "t2_4"}, + } + + ctx := sql.NewContext(context.TODO(), sql.WithMemoryManager( + sql.NewMemoryManager(mockReporter{1, 5}), + )) + b.Run("inner join", func(b *testing.B) { + require := require.New(b) + + for i := 0; i < b.N; i++ { + iter, err := n1.RowIter(ctx) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal(expected, rows) + } + }) + + b.Run("in memory inner join", func(b *testing.B) { + useInMemoryJoins = true + require := require.New(b) + + for i := 0; i < b.N; i++ { + iter, err := n1.RowIter(ctx) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + 
require.Equal(expected, rows) + } + + useInMemoryJoins = false + }) + + b.Run("within memory threshold", func(b *testing.B) { + require := require.New(b) + + for i := 0; i < b.N; i++ { + iter, err := n1.RowIter(ctx) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal(expected, rows) + } + }) + + b.Run("cross join with filter", func(b *testing.B) { + require := require.New(b) + + for i := 0; i < b.N; i++ { + iter, err := n2.RowIter(ctx) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal(expected, rows) + } + }) +} + +func TestLeftJoin(t *testing.T) { + require := require.New(t) + + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) + insertData(t, ltable) + insertData(t, rtable) + + j := NewLeftJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + expression.NewEquals( + expression.NewPlus( + expression.NewGetField(2, sql.Text, "lcol3", false), + expression.NewLiteral(int32(2), sql.Int32), + ), + expression.NewGetField(6, sql.Text, "rcol3", false), + )) + + iter, err := j.RowIter(sql.NewEmptyContext()) + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.ElementsMatch([]sql.Row{ + {"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)}, + {"col1_2", "col2_2", int32(3), int64(4), nil, nil, nil, nil}, + }, rows) +} + +func TestRightJoin(t *testing.T) { + require := require.New(t) + + ltable := memory.NewTable("left", lSchema) + rtable := memory.NewTable("right", rSchema) + insertData(t, ltable) + insertData(t, rtable) + + j := NewRightJoin( + NewResolvedTable(ltable), + NewResolvedTable(rtable), + expression.NewEquals( + expression.NewPlus( + expression.NewGetField(2, sql.Text, "lcol3", false), + expression.NewLiteral(int32(2), sql.Int32), + ), + expression.NewGetField(6, sql.Text, "rcol3", false), + )) + + iter, err := 
j.RowIter(sql.NewEmptyContext()) + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.ElementsMatch([]sql.Row{ + {nil, nil, nil, nil, "col1_1", "col2_1", int32(1), int64(2)}, + {"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)}, + }, rows) +} + +type mockReporter struct { + val uint64 + max uint64 +} + +func (m mockReporter) UsedMemory() uint64 { return m.val } +func (m mockReporter) MaxMemory() uint64 { return m.max } diff --git a/sql/plan/limit.go b/sql/plan/limit.go index e8826bdf6..9ec72805e 100644 --- a/sql/plan/limit.go +++ b/sql/plan/limit.go @@ -4,22 +4,20 @@ import ( "io" opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) -var _ sql.Node = &Limit{} - // Limit is a node that only allows up to N rows to be retrieved. type Limit struct { UnaryNode - size int64 + Limit int64 } // NewLimit creates a new Limit node with the given size. func NewLimit(size int64, child sql.Node) *Limit { return &Limit{ UnaryNode: UnaryNode{Child: child}, - size: size, + Limit: size, } } @@ -30,7 +28,7 @@ func (l *Limit) Resolved() bool { // RowIter implements the Node interface. func (l *Limit) RowIter(ctx *sql.Context) (sql.RowIter, error) { - span, ctx := ctx.Span("plan.Limit", opentracing.Tag{Key: "limit", Value: l.size}) + span, ctx := ctx.Span("plan.Limit", opentracing.Tag{Key: "limit", Value: l.Limit}) li, err := l.Child.RowIter(ctx) if err != nil { @@ -40,27 +38,17 @@ func (l *Limit) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &limitIter{l, 0, li}), nil } -// TransformUp implements the Transformable interface. -func (l *Limit) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := l.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewLimit(l.size, child)) -} - -// TransformExpressionsUp implements the Transformable interface. 
-func (l *Limit) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := l.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (l *Limit) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) } - return NewLimit(l.size, child), nil + return NewLimit(l.Limit, children[0]), nil } func (l Limit) String() string { pr := sql.NewTreePrinter() - _ = pr.WriteNode("Limit(%d)", l.size) + _ = pr.WriteNode("Limit(%d)", l.Limit) _ = pr.WriteChildren(l.Child.String()) return pr.String() } @@ -72,18 +60,17 @@ type limitIter struct { } func (li *limitIter) Next() (sql.Row, error) { - for { - if li.currentPos >= li.l.size { - return nil, io.EOF - } - childRow, err := li.childIter.Next() - li.currentPos++ - if err != nil { - return nil, err - } + if li.currentPos >= li.l.Limit { + return nil, io.EOF + } - return childRow, nil + childRow, err := li.childIter.Next() + li.currentPos++ + if err != nil { + return nil, err } + + return childRow, nil } func (li *limitIter) Close() error { diff --git a/sql/plan/limit_test.go b/sql/plan/limit_test.go index 97d6fa30c..25fd94c21 100644 --- a/sql/plan/limit_test.go +++ b/sql/plan/limit_test.go @@ -8,28 +8,28 @@ import ( "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" ) -var testingTable *mem.Table +var testingTable *memory.Table var testingTableSize int func TestLimitPlan(t *testing.T) { require := require.New(t) - table, _ := getTestingTable() - limitPlan := NewLimit(0, table) + table, _ := getTestingTable(t) + limitPlan := NewLimit(0, NewResolvedTable(table)) require.Equal(1, len(limitPlan.Children())) - iterator, err := getLimitedIterator(1) - require.Nil(err) + iterator, err := 
getLimitedIterator(t, 1) + require.NoError(err) require.NotNil(iterator) } func TestLimitImplementsNode(t *testing.T) { require := require.New(t) - table, _ := getTestingTable() - limitPlan := NewLimit(0, table) + table, _ := getTestingTable(t) + limitPlan := NewLimit(0, NewResolvedTable(table)) childSchema := table.Schema() nodeSchema := limitPlan.Schema() require.True(reflect.DeepEqual(childSchema, nodeSchema)) @@ -38,31 +38,31 @@ func TestLimitImplementsNode(t *testing.T) { } func TestLimit0(t *testing.T) { - _, testingTableSize := getTestingTable() + _, size := getTestingTable(t) testingLimit := 0 - iterator, _ := getLimitedIterator(int64(testingLimit)) - testLimitOverflow(t, iterator, testingLimit, testingTableSize) + iterator, _ := getLimitedIterator(t, int64(testingLimit)) + testLimitOverflow(t, iterator, testingLimit, size) } func TestLimitLessThanTotal(t *testing.T) { - _, testingTableSize := getTestingTable() - testingLimit := testingTableSize - 1 - iterator, _ := getLimitedIterator(int64(testingLimit)) - testLimitOverflow(t, iterator, testingLimit, testingTableSize) + _, size := getTestingTable(t) + testingLimit := size - 1 + iterator, _ := getLimitedIterator(t, int64(testingLimit)) + testLimitOverflow(t, iterator, testingLimit, size) } func TestLimitEqualThanTotal(t *testing.T) { - _, testingTableSize := getTestingTable() - testingLimit := testingTableSize - iterator, _ := getLimitedIterator(int64(testingLimit)) - testLimitOverflow(t, iterator, testingLimit, testingTableSize) + _, size := getTestingTable(t) + testingLimit := size + iterator, _ := getLimitedIterator(t, int64(testingLimit)) + testLimitOverflow(t, iterator, testingLimit, size) } func TestLimitGreaterThanTotal(t *testing.T) { - _, testingTableSize := getTestingTable() - testingLimit := testingTableSize + 1 - iterator, _ := getLimitedIterator(int64(testingLimit)) - testLimitOverflow(t, iterator, testingLimit, testingTableSize) + _, size := getTestingTable(t) + testingLimit := size + 1 + 
iterator, _ := getLimitedIterator(t, int64(testingLimit)) + testLimitOverflow(t, iterator, testingLimit, size) } func testLimitOverflow(t *testing.T, iter sql.RowIter, limit int, dataSize int) { @@ -80,8 +80,8 @@ func testLimitOverflow(t *testing.T, iter sql.RowIter, limit int, dataSize int) } } -func getTestingTable() (*mem.Table, int) { - +func getTestingTable(t *testing.T) (*memory.Table, int) { + t.Helper() if &testingTable == nil { return testingTable, testingTableSize } @@ -89,19 +89,26 @@ func getTestingTable() (*mem.Table, int) { childSchema := sql.Schema{ {Name: "col1", Type: sql.Text}, } - testingTable = mem.NewTable("test", childSchema) - testingTable.Insert(sql.NewRow("11a")) - testingTable.Insert(sql.NewRow("22a")) - testingTable.Insert(sql.NewRow("33a")) - testingTableSize = 3 + testingTable = memory.NewTable("test", childSchema) + + rows := []sql.Row{ + sql.NewRow("11a"), + sql.NewRow("22a"), + sql.NewRow("33a"), + } + + for _, r := range rows { + require.NoError(t, testingTable.Insert(sql.NewEmptyContext(), r)) + } - return testingTable, testingTableSize + return testingTable, len(rows) } -func getLimitedIterator(limitSize int64) (sql.RowIter, error) { +func getLimitedIterator(t *testing.T, limitSize int64) (sql.RowIter, error) { + t.Helper() ctx := sql.NewEmptyContext() - table, _ := getTestingTable() - limitPlan := NewLimit(limitSize, table) + table, _ := getTestingTable(t) + limitPlan := NewLimit(limitSize, NewResolvedTable(table)) return limitPlan.RowIter(ctx) } diff --git a/sql/plan/lock.go b/sql/plan/lock.go new file mode 100644 index 000000000..8e51edec6 --- /dev/null +++ b/sql/plan/lock.go @@ -0,0 +1,174 @@ +package plan + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +// TableLock is a read or write lock on a table. +type TableLock struct { + Table sql.Node + // Write if it's true, read if it's false. 
+ Write bool +} + +// LockTables will lock tables for the session in which it's executed. +type LockTables struct { + Catalog *sql.Catalog + Locks []*TableLock +} + +// NewLockTables creates a new LockTables node. +func NewLockTables(locks []*TableLock) *LockTables { + return &LockTables{Locks: locks} +} + +// Children implements the sql.Node interface. +func (t *LockTables) Children() []sql.Node { + var children = make([]sql.Node, len(t.Locks)) + for i, l := range t.Locks { + children[i] = l.Table + } + return children +} + +// Resolved implements the sql.Node interface. +func (t *LockTables) Resolved() bool { + for _, l := range t.Locks { + if !l.Table.Resolved() { + return false + } + } + return true +} + +// Schema implements the sql.Node interface. +func (t *LockTables) Schema() sql.Schema { return nil } + +// RowIter implements the sql.Node interface. +func (t *LockTables) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.LockTables") + defer span.Finish() + + id := ctx.ID() + for _, l := range t.Locks { + lockable, err := getLockable(l.Table) + if err != nil { + // If a table is not lockable, just skip it + ctx.Warn(0, err.Error()) + continue + } + + if err := lockable.Lock(ctx, l.Write); err != nil { + ctx.Error(0, "unable to lock table: %s", err) + } else { + t.Catalog.LockTable(id, lockable.Name()) + } + } + + return sql.RowsToRowIter(), nil +} + +func (t *LockTables) String() string { + var children = make([]string, len(t.Locks)) + for i, l := range t.Locks { + if l.Write { + children[i] = fmt.Sprintf("[WRITE] %s", l.Table.String()) + } else { + children[i] = fmt.Sprintf("[READ] %s", l.Table.String()) + } + } + + p := sql.NewTreePrinter() + _ = p.WriteNode("LockTables") + _ = p.WriteChildren(children...) + return p.String() +} + +// WithChildren implements the Node interface. 
+func (t *LockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != len(t.Locks) { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), len(t.Locks)) + } + + var locks = make([]*TableLock, len(t.Locks)) + for i, n := range children { + locks[i] = &TableLock{ + Table: n, + Write: t.Locks[i].Write, + } + } + + return &LockTables{t.Catalog, locks}, nil +} + +// ErrTableNotLockable is returned whenever a lockable table can't be found. +var ErrTableNotLockable = errors.NewKind("table %s is not lockable") + +func getLockable(node sql.Node) (sql.Lockable, error) { + switch node := node.(type) { + case *ResolvedTable: + return getLockableTable(node.Table) + default: + return nil, ErrTableNotLockable.New("unknown") + } +} + +func getLockableTable(table sql.Table) (sql.Lockable, error) { + switch t := table.(type) { + case sql.Lockable: + return t, nil + case sql.TableWrapper: + return getLockableTable(t.Underlying()) + default: + return nil, ErrTableNotLockable.New(t.Name()) + } +} + +// UnlockTables will release all locks for the current session. +type UnlockTables struct { + Catalog *sql.Catalog +} + +// NewUnlockTables returns a new UnlockTables node. +func NewUnlockTables() *UnlockTables { + return new(UnlockTables) +} + +// Children implements the sql.Node interface. +func (t *UnlockTables) Children() []sql.Node { return nil } + +// Resolved implements the sql.Node interface. +func (t *UnlockTables) Resolved() bool { return true } + +// Schema implements the sql.Node interface. +func (t *UnlockTables) Schema() sql.Schema { return nil } + +// RowIter implements the sql.Node interface. 
+func (t *UnlockTables) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.UnlockTables") + defer span.Finish() + + if err := t.Catalog.UnlockTables(ctx, ctx.ID()); err != nil { + return nil, err + } + + return sql.RowsToRowIter(), nil +} + +func (t *UnlockTables) String() string { + p := sql.NewTreePrinter() + _ = p.WriteNode("UnlockTables") + return p.String() +} + +// WithChildren implements the Node interface. +func (t *UnlockTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } + + return t, nil +} diff --git a/sql/plan/lock_test.go b/sql/plan/lock_test.go new file mode 100644 index 000000000..63b806420 --- /dev/null +++ b/sql/plan/lock_test.go @@ -0,0 +1,83 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestLockTables(t *testing.T) { + require := require.New(t) + + t1 := newLockableTable(memory.NewTable("foo", nil)) + t2 := newLockableTable(memory.NewTable("bar", nil)) + node := NewLockTables([]*TableLock{ + {NewResolvedTable(t1), true}, + {NewResolvedTable(t2), false}, + }) + node.Catalog = sql.NewCatalog() + + _, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + require.Equal(1, t1.writeLocks) + require.Equal(0, t1.readLocks) + require.Equal(1, t2.readLocks) + require.Equal(0, t2.writeLocks) +} + +func TestUnlockTables(t *testing.T) { + require := require.New(t) + + db := memory.NewDatabase("db") + t1 := newLockableTable(memory.NewTable("foo", nil)) + t2 := newLockableTable(memory.NewTable("bar", nil)) + t3 := newLockableTable(memory.NewTable("baz", nil)) + db.AddTable("foo", t1) + db.AddTable("bar", t2) + db.AddTable("baz", t3) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + catalog.LockTable(0, "foo") + catalog.LockTable(0, "bar") + + node := NewUnlockTables() + 
node.Catalog = catalog + + _, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + require.Equal(1, t1.unlocks) + require.Equal(1, t2.unlocks) + require.Equal(0, t3.unlocks) +} + +type lockableTable struct { + sql.Table + readLocks int + writeLocks int + unlocks int +} + +func newLockableTable(t sql.Table) *lockableTable { + return &lockableTable{Table: t} +} + +var _ sql.Lockable = (*lockableTable)(nil) + +func (l *lockableTable) Lock(ctx *sql.Context, write bool) error { + if write { + l.writeLocks++ + } else { + l.readLocks++ + } + return nil +} + +func (l *lockableTable) Unlock(ctx *sql.Context, id uint32) error { + l.unlocks++ + return nil +} diff --git a/sql/plan/naturaljoin.go b/sql/plan/naturaljoin.go new file mode 100644 index 000000000..6ccf0182b --- /dev/null +++ b/sql/plan/naturaljoin.go @@ -0,0 +1,45 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// NaturalJoin is a join that automatically joins by all the columns with the +// same name. +// NaturalJoin is a placeholder node, it should be transformed into an INNER +// JOIN during analysis. +type NaturalJoin struct { + BinaryNode +} + +// NewNaturalJoin returns a new NaturalJoin node. +func NewNaturalJoin(left, right sql.Node) *NaturalJoin { + return &NaturalJoin{BinaryNode{left, right}} +} + +// RowIter implements the Node interface. +func (NaturalJoin) RowIter(*sql.Context) (sql.RowIter, error) { + panic("NaturalJoin is a placeholder, RowIter called") +} + +// Schema implements the Node interface. +func (NaturalJoin) Schema() sql.Schema { + panic("NaturalJoin is a placeholder, Schema called") +} + +// Resolved implements the Node interface. +func (NaturalJoin) Resolved() bool { return false } + +func (j NaturalJoin) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("NaturalJoin") + _ = pr.WriteChildren(j.Left.String(), j.Right.String()) + return pr.String() +} + +// WithChildren implements the Node interface. 
+func (j *NaturalJoin) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 2 { + return nil, sql.ErrInvalidChildrenNumber.New(j, len(children), 2) + } + + return NewNaturalJoin(children[0], children[1]), nil +} diff --git a/sql/plan/nothing.go b/sql/plan/nothing.go new file mode 100644 index 000000000..43792405d --- /dev/null +++ b/sql/plan/nothing.go @@ -0,0 +1,25 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// Nothing is a node that will return no rows. +var Nothing nothing + +type nothing struct{} + +func (nothing) String() string { return "NOTHING" } +func (nothing) Resolved() bool { return true } +func (nothing) Schema() sql.Schema { return nil } +func (nothing) Children() []sql.Node { return nil } +func (nothing) RowIter(*sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +// WithChildren implements the Node interface. +func (n nothing) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + + return Nothing, nil +} diff --git a/sql/plan/offset.go b/sql/plan/offset.go index a02357202..527b2c7cf 100644 --- a/sql/plan/offset.go +++ b/sql/plan/offset.go @@ -2,20 +2,20 @@ package plan import ( opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Offset is a node that skips the first N rows. type Offset struct { UnaryNode - n int64 + Offset int64 } // NewOffset creates a new Offset node. func NewOffset(n int64, child sql.Node) *Offset { return &Offset{ UnaryNode: UnaryNode{Child: child}, - n: n, + Offset: n, } } @@ -26,37 +26,27 @@ func (o *Offset) Resolved() bool { // RowIter implements the Node interface. 
func (o *Offset) RowIter(ctx *sql.Context) (sql.RowIter, error) { - span, ctx := ctx.Span("plan.Offset", opentracing.Tag{Key: "offset", Value: o.n}) + span, ctx := ctx.Span("plan.Offset", opentracing.Tag{Key: "offset", Value: o.Offset}) it, err := o.Child.RowIter(ctx) if err != nil { span.Finish() return nil, err } - return sql.NewSpanIter(span, &offsetIter{o.n, it}), nil + return sql.NewSpanIter(span, &offsetIter{o.Offset, it}), nil } -// TransformUp implements the Transformable interface. -func (o *Offset) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := o.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewOffset(o.n, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (o *Offset) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := o.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. +func (o *Offset) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1) } - return NewOffset(o.n, child), nil + return NewOffset(o.Offset, children[0]), nil } func (o Offset) String() string { pr := sql.NewTreePrinter() - _ = pr.WriteNode("Offset(%d)", o.n) + _ = pr.WriteNode("Offset(%d)", o.Offset) _ = pr.WriteChildren(o.Child.String()) return pr.String() } diff --git a/sql/plan/offset_test.go b/sql/plan/offset_test.go index d80998f18..607905536 100644 --- a/sql/plan/offset_test.go +++ b/sql/plan/offset_test.go @@ -3,16 +3,16 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestOffsetPlan(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - table, _ := getTestingTable() - offset := NewOffset(0, table) + table, _ := getTestingTable(t) + offset := NewOffset(0, NewResolvedTable(table)) 
require.Equal(1, len(offset.Children())) iter, err := offset.RowIter(ctx) @@ -24,8 +24,8 @@ func TestOffset(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - table, n := getTestingTable() - offset := NewOffset(1, table) + table, n := getTestingTable(t) + offset := NewOffset(1, NewResolvedTable(table)) iter, err := offset.RowIter(ctx) require.NoError(err) diff --git a/sql/plan/process.go b/sql/plan/process.go new file mode 100644 index 000000000..2e0bec439 --- /dev/null +++ b/sql/plan/process.go @@ -0,0 +1,273 @@ +package plan + +import ( + "io" + + "github.com/src-d/go-mysql-server/sql" +) + +// QueryProcess represents a running query process node. It will use a callback +// to notify when it has finished running. +type QueryProcess struct { + UnaryNode + Notify NotifyFunc +} + +// NotifyFunc is a function to notify about some event. +type NotifyFunc func() + +// NewQueryProcess creates a new QueryProcess node. +func NewQueryProcess(node sql.Node, notify NotifyFunc) *QueryProcess { + return &QueryProcess{UnaryNode{Child: node}, notify} +} + +// WithChildren implements the Node interface. +func (p *QueryProcess) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + + return NewQueryProcess(children[0], p.Notify), nil +} + +// RowIter implements the sql.Node interface. +func (p *QueryProcess) RowIter(ctx *sql.Context) (sql.RowIter, error) { + iter, err := p.Child.RowIter(ctx) + if err != nil { + return nil, err + } + + return &trackedRowIter{iter: iter, onDone: p.Notify}, nil +} + +func (p *QueryProcess) String() string { return p.Child.String() } + +// ProcessIndexableTable is a wrapper for sql.Tables inside a query process +// that support indexing. +// It notifies the process manager about the status of a query when a +// partition is processed. 
+type ProcessIndexableTable struct { + sql.IndexableTable + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc +} + +// NewProcessIndexableTable returns a new ProcessIndexableTable. +func NewProcessIndexableTable(t sql.IndexableTable, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessIndexableTable { + return &ProcessIndexableTable{t, onPartitionDone, onPartitionStart, OnRowNext} +} + +// Underlying implements sql.TableWrapper interface. +func (t *ProcessIndexableTable) Underlying() sql.Table { + return t.IndexableTable +} + +// IndexKeyValues implements the sql.IndexableTable interface. +func (t *ProcessIndexableTable) IndexKeyValues( + ctx *sql.Context, + columns []string, +) (sql.PartitionIndexKeyValueIter, error) { + iter, err := t.IndexableTable.IndexKeyValues(ctx, columns) + if err != nil { + return nil, err + } + + return &trackedPartitionIndexKeyValueIter{iter, t.OnPartitionDone, t.OnPartitionStart, t.OnRowNext}, nil +} + +// PartitionRows implements the sql.Table interface. +func (t *ProcessIndexableTable) PartitionRows(ctx *sql.Context, p sql.Partition) (sql.RowIter, error) { + iter, err := t.IndexableTable.PartitionRows(ctx, p) + if err != nil { + return nil, err + } + + partitionName := partitionName(p) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil +} + +var _ sql.IndexableTable = (*ProcessIndexableTable)(nil) + +// NamedNotifyFunc is a function to notify about some event with a string argument. +type NamedNotifyFunc func(name string) + +// ProcessTable is a wrapper for sql.Tables inside a query process. 
It +// notifies the process manager about the status of a query when a partition +// is processed. +type ProcessTable struct { + sql.Table + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc +} + +// NewProcessTable returns a new ProcessTable. +func NewProcessTable(t sql.Table, onPartitionDone, onPartitionStart, OnRowNext NamedNotifyFunc) *ProcessTable { + return &ProcessTable{t, onPartitionDone, onPartitionStart, OnRowNext} +} + +// Underlying implements sql.TableWrapper interface. +func (t *ProcessTable) Underlying() sql.Table { + return t.Table +} + +// PartitionRows implements the sql.Table interface. +func (t *ProcessTable) PartitionRows(ctx *sql.Context, p sql.Partition) (sql.RowIter, error) { + iter, err := t.Table.PartitionRows(ctx, p) + if err != nil { + return nil, err + } + + partitionName := partitionName(p) + if t.OnPartitionStart != nil { + t.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if t.OnPartitionDone != nil { + onDone = func() { + t.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if t.OnRowNext != nil { + onNext = func() { + t.OnRowNext(partitionName) + } + } + + return &trackedRowIter{iter: iter, onNext: onNext, onDone: onDone}, nil +} + +type trackedRowIter struct { + iter sql.RowIter + onDone NotifyFunc + onNext NotifyFunc +} + +func (i *trackedRowIter) done() { + if i.onDone != nil { + i.onDone() + i.onDone = nil + } +} + +func (i *trackedRowIter) Next() (sql.Row, error) { + row, err := i.iter.Next() + if err != nil { + if err == io.EOF { + i.done() + } + return nil, err + } + + if i.onNext != nil { + i.onNext() + } + + return row, nil +} + +func (i *trackedRowIter) Close() error { + i.done() + return i.iter.Close() +} + +type trackedPartitionIndexKeyValueIter struct { + sql.PartitionIndexKeyValueIter + OnPartitionDone NamedNotifyFunc + OnPartitionStart NamedNotifyFunc + OnRowNext NamedNotifyFunc +} + +func (i *trackedPartitionIndexKeyValueIter) Next() 
(sql.Partition, sql.IndexKeyValueIter, error) { + p, iter, err := i.PartitionIndexKeyValueIter.Next() + if err != nil { + return nil, nil, err + } + + partitionName := partitionName(p) + if i.OnPartitionStart != nil { + i.OnPartitionStart(partitionName) + } + + var onDone NotifyFunc + if i.OnPartitionDone != nil { + onDone = func() { + i.OnPartitionDone(partitionName) + } + } + + var onNext NotifyFunc + if i.OnRowNext != nil { + onNext = func() { + i.OnRowNext(partitionName) + } + } + + return p, &trackedIndexKeyValueIter{iter, onDone, onNext}, nil +} + +type trackedIndexKeyValueIter struct { + iter sql.IndexKeyValueIter + onDone NotifyFunc + onNext NotifyFunc +} + +func (i *trackedIndexKeyValueIter) done() { + if i.onDone != nil { + i.onDone() + i.onDone = nil + } +} + +func (i *trackedIndexKeyValueIter) Close() (err error) { + i.done() + if i.iter != nil { + err = i.iter.Close() + } + return err +} + +func (i *trackedIndexKeyValueIter) Next() ([]interface{}, []byte, error) { + v, k, err := i.iter.Next() + if err != nil { + if err == io.EOF { + i.done() + } + return nil, nil, err + } + + if i.onNext != nil { + i.onNext() + } + + return v, k, nil +} + +func partitionName(p sql.Partition) string { + if n, ok := p.(sql.Nameable); ok { + return n.Name() + } + return string(p.Key()) +} diff --git a/sql/plan/process_test.go b/sql/plan/process_test.go new file mode 100644 index 000000000..de819edb0 --- /dev/null +++ b/sql/plan/process_test.go @@ -0,0 +1,168 @@ +package plan + +import ( + "io" + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestQueryProcess(t *testing.T) { + require := require.New(t) + + table := memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64}, + }) + + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(1))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(2))) + + var 
notifications int + + node := NewQueryProcess( + NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + }, + NewResolvedTable(table), + ), + func() { + notifications++ + }, + ) + + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {int64(1)}, + {int64(2)}, + } + + require.ElementsMatch(expected, rows) + require.Equal(1, notifications) +} + +func TestProcessTable(t *testing.T) { + require := require.New(t) + + table := memory.NewPartitionedTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64}, + }, 2) + + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(1))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(2))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) + + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int + + node := NewProject( + []sql.Expression{ + expression.NewGetField(0, sql.Int64, "a", false), + }, + NewResolvedTable( + NewProcessTable( + table, + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ + }, + ), + ), + ) + + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + {int64(4)}, + } + + require.ElementsMatch(expected, rows) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) +} + +func TestProcessIndexableTable(t *testing.T) { + require := require.New(t) + + table := memory.NewPartitionedTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + }, 2) + + 
table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(1))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(2))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(3))) + table.Insert(sql.NewEmptyContext(), sql.NewRow(int64(4))) + + var partitionDoneNotifications int + var partitionStartNotifications int + var rowNextNotifications int + + pt := NewProcessIndexableTable( + table, + func(partitionName string) { + partitionDoneNotifications++ + }, + func(partitionName string) { + partitionStartNotifications++ + }, + func(partitionName string) { + rowNextNotifications++ + }, + ) + + iter, err := pt.IndexKeyValues(sql.NewEmptyContext(), []string{"a"}) + require.NoError(err) + + var values [][]interface{} + for { + _, kviter, err := iter.Next() + if err == io.EOF { + break + } + require.NoError(err) + + for { + v, _, err := kviter.Next() + if err == io.EOF { + break + } + values = append(values, v) + require.NoError(err) + } + } + + expectedValues := [][]interface{}{ + {int64(1)}, + {int64(2)}, + {int64(3)}, + {int64(4)}, + } + + require.ElementsMatch(expectedValues, values) + require.Equal(2, partitionDoneNotifications) + require.Equal(2, partitionStartNotifications) + require.Equal(4, rowNextNotifications) +} diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go new file mode 100644 index 000000000..a447fb79d --- /dev/null +++ b/sql/plan/processlist.go @@ -0,0 +1,119 @@ +package plan + +import ( + "sort" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +type process struct { + id int64 + user string + host string + db string + command string + time int64 + state string + info string +} + +func (p process) toRow() sql.Row { + return sql.NewRow( + p.id, + p.user, + p.host, + p.db, + p.command, + p.time, + p.state, + p.info, + ) +} + +var processListSchema = sql.Schema{ + {Name: "Id", Type: sql.Int64}, + {Name: "User", Type: sql.Text}, + {Name: "Host", Type: sql.Text}, + {Name: "db", Type: sql.Text}, + {Name: "Command", Type: sql.Text}, + 
{Name: "Time", Type: sql.Int64}, +	{Name: "State", Type: sql.Text}, +	{Name: "Info", Type: sql.Text}, +} + +// ShowProcessList shows a list of all current running processes. +type ShowProcessList struct { +	Database string +	*sql.ProcessList +} + +// NewShowProcessList creates a new ShowProcessList node. +func NewShowProcessList() *ShowProcessList { return new(ShowProcessList) } + +// Children implements the Node interface. +func (p *ShowProcessList) Children() []sql.Node { return nil } + +// Resolved implements the Node interface. +func (p *ShowProcessList) Resolved() bool { return true } + +// WithChildren implements the Node interface. +func (p *ShowProcessList) WithChildren(children ...sql.Node) (sql.Node, error) { +	if len(children) != 0 { +		return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) +	} + +	return p, nil +} + +// Schema implements the Node interface. +func (p *ShowProcessList) Schema() sql.Schema { return processListSchema } + +// RowIter implements the Node interface. +func (p *ShowProcessList) RowIter(ctx *sql.Context) (sql.RowIter, error) { +	processes := p.Processes() +	var rows = make([]sql.Row, len(processes)) + +	for i, proc := range processes { +		var status []string +		var names []string +		for name := range proc.Progress { +			names = append(names, name) +		} +		sort.Strings(names) + +		for _, name := range names { +			progress := proc.Progress[name] + +			printer := sql.NewTreePrinter() +			_ = printer.WriteNode("\n" + progress.String()) +			children := []string{} +			for _, partitionProgress := range progress.PartitionsProgress { +				children = append(children, partitionProgress.String()) +			} +			sort.Strings(children) +			_ = printer.WriteChildren(children...)
+ + status = append(status, printer.String()) + } + + if len(status) == 0 { + status = []string{"running"} + } + + rows[i] = process{ + id: int64(proc.Connection), + user: proc.User, + time: int64(proc.Seconds()), + state: strings.Join(status, ""), + command: proc.Type.String(), + host: ctx.Session.Client().Address, + info: proc.Query, + db: p.Database, + }.toRow() + } + + return sql.RowsToRowIter(rows...), nil +} + +func (p *ShowProcessList) String() string { return "ProcessList" } diff --git a/sql/plan/processlist_test.go b/sql/plan/processlist_test.go new file mode 100644 index 000000000..12ee98b9a --- /dev/null +++ b/sql/plan/processlist_test.go @@ -0,0 +1,61 @@ +package plan + +import ( + "context" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowProcessList(t *testing.T) { + require := require.New(t) + + addr := "127.0.0.1:34567" + + n := NewShowProcessList() + p := sql.NewProcessList() + sess := sql.NewSession("0.0.0.0:3306", addr, "foo", 1) + ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithSession(sess)) + + ctx, err := p.AddProcess(ctx, sql.QueryProcess, "SELECT foo") + require.NoError(err) + + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) + + ctx = sql.NewContext(context.Background(), sql.WithPid(2), sql.WithSession(sess)) + ctx, err = p.AddProcess(ctx, sql.CreateIndexProcess, "SELECT bar") + require.NoError(err) + + p.AddTableProgress(ctx.Pid(), "foo", 2) + + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdatePartitionProgress(1, "a", "a-1", 7) + p.UpdatePartitionProgress(1, "a", "a-2", 9) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) + + n.ProcessList = p + n.Database = "foo" + + iter, err := n.RowIter(ctx) + require.NoError(err) + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {int64(1), "foo", addr, "foo", "query", int64(0), + ` +a (4/5 
partitions) + ├─ a-1 (7/? rows) + └─ a-2 (9/? rows) + +b (2/6 partitions) +`, "SELECT foo"}, + {int64(1), "foo", addr, "foo", "create_index", int64(0), "\nfoo (1/2 partitions)\n", "SELECT bar"}, + } + + require.ElementsMatch(expected, rows) +} diff --git a/sql/plan/project.go b/sql/plan/project.go index 2214f8691..8b166d449 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -4,7 +4,7 @@ import ( "strings" opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Project is a projection of certain expression from the children node. @@ -69,31 +69,7 @@ func (p *Project) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, &iter{p, i, ctx}), nil } -// TransformUp implements the Transformable interface. -func (p *Project) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := p.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewProject(p.Projections, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (p *Project) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - exprs, err := transformExpressionsUp(f, p.Projections) - if err != nil { - return nil, err - } - - child, err := p.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewProject(exprs, child), nil -} - -func (p Project) String() string { +func (p *Project) String() string { pr := sql.NewTreePrinter() var exprs = make([]string, len(p.Projections)) for i, expr := range p.Projections { @@ -105,10 +81,28 @@ func (p Project) String() string { } // Expressions implements the Expressioner interface. -func (p Project) Expressions() []sql.Expression { +func (p *Project) Expressions() []sql.Expression { return p.Projections } +// WithChildren implements the Node interface. 
+func (p *Project) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + + return NewProject(p.Projections, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(p.Projections) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), len(p.Projections)) + } + + return NewProject(exprs, p.Child), nil +} + type iter struct { p *Project childIter sql.RowIter diff --git a/sql/plan/project_test.go b/sql/plan/project_test.go index bef6da08f..84950b924 100644 --- a/sql/plan/project_test.go +++ b/sql/plan/project_test.go @@ -6,9 +6,9 @@ import ( "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) func TestProject(t *testing.T) { @@ -18,25 +18,28 @@ func TestProject(t *testing.T) { {Name: "col1", Type: sql.Text, Nullable: true}, {Name: "col2", Type: sql.Text, Nullable: true}, } - child := mem.NewTable("test", childSchema) - child.Insert(sql.NewRow("col1_1", "col2_1")) - child.Insert(sql.NewRow("col1_2", "col2_2")) - p := NewProject([]sql.Expression{expression.NewGetField(1, sql.Text, "col2", true)}, child) + child := memory.NewTable("test", childSchema) + child.Insert(sql.NewEmptyContext(), sql.NewRow("col1_1", "col2_1")) + child.Insert(sql.NewEmptyContext(), sql.NewRow("col1_2", "col2_2")) + p := NewProject( + []sql.Expression{expression.NewGetField(1, sql.Text, "col2", true)}, + NewResolvedTable(child), + ) require.Equal(1, len(p.Children())) schema := sql.Schema{ {Name: "col2", Type: sql.Text, Nullable: true}, } require.Equal(schema, p.Schema()) iter, err := p.RowIter(ctx) - 
require.Nil(err) + require.NoError(err) require.NotNil(iter) row, err := iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal(1, len(row)) require.Equal("col2_1", row[0]) row, err = iter.Next() - require.Nil(err) + require.NoError(err) require.NotNil(row) require.Equal(1, len(row)) require.Equal("col2_2", row[0]) @@ -44,7 +47,7 @@ func TestProject(t *testing.T) { require.Equal(io.EOF, err) require.Nil(row) - p = NewProject(nil, child) + p = NewProject(nil, NewResolvedTable(child)) require.Equal(0, len(p.Schema())) p = NewProject([]sql.Expression{ @@ -52,7 +55,7 @@ func TestProject(t *testing.T) { expression.NewGetField(1, sql.Text, "col2", true), "foo", ), - }, child) + }, NewResolvedTable(child)) schema = sql.Schema{ {Name: "foo", Type: sql.Text, Nullable: true}, } @@ -71,10 +74,10 @@ func BenchmarkProject(b *testing.B) { expression.NewGetField(3, sql.Int32, "intfield", false), expression.NewGetField(4, sql.Int64, "bigintfield", false), expression.NewGetField(5, sql.Blob, "blobfield", false), - }, benchtable) + }, NewResolvedTable(benchtable)) iter, err := d.RowIter(ctx) - require.Nil(err) + require.NoError(err) require.NotNil(iter) for { diff --git a/sql/plan/pushdown.go b/sql/plan/pushdown.go deleted file mode 100644 index d37978ab3..000000000 --- a/sql/plan/pushdown.go +++ /dev/null @@ -1,169 +0,0 @@ -package plan - -import ( - "fmt" - "strings" - - opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" -) - -// PushdownProjectionTable is a node wrapping a table implementing the -// sql.PushdownProjectionTable interface so it returns a RowIter with -// custom logic given the set of used columns that need to be projected. -// PushdownProjectionTable nodes don't propagate transformations. -type PushdownProjectionTable struct { - sql.PushdownProjectionTable - Columns []string -} - -// NewPushdownProjectionTable creates a new PushdownProjectionTable node. 
-func NewPushdownProjectionTable( - columns []string, - table sql.PushdownProjectionTable, -) *PushdownProjectionTable { - return &PushdownProjectionTable{table, columns} -} - -// TransformUp implements the Node interface. -func (t *PushdownProjectionTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - node, err := t.PushdownProjectionTable.TransformUp(f) - if err != nil { - return nil, err - } - - table, ok := node.(sql.PushdownProjectionTable) - if !ok { - return node, nil - } - - return f(NewPushdownProjectionTable(t.Columns, table)) -} - -// TransformExpressionsUp implements the Node interface. -func (t *PushdownProjectionTable) TransformExpressionsUp( - f sql.TransformExprFunc, -) (sql.Node, error) { - node, err := t.PushdownProjectionTable.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - table, ok := node.(sql.PushdownProjectionTable) - if !ok { - return node, nil - } - - return NewPushdownProjectionTable(t.Columns, table), nil -} - -// RowIter implements the Node interface. -func (t *PushdownProjectionTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { - span, ctx := ctx.Span("plan.PushdownProjectionTable", opentracing.Tags{ - "columns": len(t.Columns), - "table": t.Name(), - }) - - iter, err := t.WithProject(ctx, t.Columns) - if err != nil { - span.Finish() - return nil, err - } - - return sql.NewSpanIter(span, iter), nil -} - -func (t PushdownProjectionTable) String() string { - pr := sql.NewTreePrinter() - _ = pr.WriteNode("PushdownProjectionTable(%s)", strings.Join(t.Columns, ", ")) - _ = pr.WriteChildren(t.PushdownProjectionTable.String()) - return pr.String() -} - -// PushdownProjectionAndFiltersTable is a node wrapping a table implementing -// the sql.PushdownProjectionAndFiltersTable interface so it returns a RowIter -// with custom logic given the set of used columns that need to be projected -// and the filters that apply to that table. 
-// PushdownProjectionAndFiltersTable nodes don't propagate transformations. -type PushdownProjectionAndFiltersTable struct { - sql.PushdownProjectionAndFiltersTable - Columns []sql.Expression - Filters []sql.Expression -} - -// NewPushdownProjectionAndFiltersTable creates a new -// PushdownProjectionAndFiltersTable node. -func NewPushdownProjectionAndFiltersTable( - columns []sql.Expression, - filters []sql.Expression, - table sql.PushdownProjectionAndFiltersTable, -) *PushdownProjectionAndFiltersTable { - return &PushdownProjectionAndFiltersTable{table, columns, filters} -} - -// TransformUp implements the Node interface. -func (t *PushdownProjectionAndFiltersTable) TransformUp( - f sql.TransformNodeFunc, -) (sql.Node, error) { - return f(t) -} - -// TransformExpressionsUp implements the Node interface. -func (t *PushdownProjectionAndFiltersTable) TransformExpressionsUp( - f sql.TransformExprFunc, -) (sql.Node, error) { - filters, err := transformExpressionsUp(f, t.Filters) - if err != nil { - return nil, err - } - - return NewPushdownProjectionAndFiltersTable(t.Columns, filters, t.PushdownProjectionAndFiltersTable), nil -} - -// RowIter implements the Node interface. 
-func (t *PushdownProjectionAndFiltersTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { - span, ctx := ctx.Span("plan.PushdownProjectionAndFiltersTable", opentracing.Tags{ - "columns": len(t.Columns), - "filters": len(t.Filters), - "table": t.Name(), - }) - - iter, err := t.WithProjectAndFilters(ctx, t.Columns, t.Filters) - if err != nil { - span.Finish() - return nil, err - } - - return sql.NewSpanIter(span, iter), nil -} - -func (t PushdownProjectionAndFiltersTable) String() string { - pr := sql.NewTreePrinter() - _ = pr.WriteNode("PushdownProjectionAndFiltersTable") - - var columns = make([]string, len(t.Columns)) - for i, col := range t.Columns { - columns[i] = col.String() - } - - var filters = make([]string, len(t.Filters)) - for i, f := range t.Filters { - filters[i] = f.String() - } - - _ = pr.WriteChildren( - fmt.Sprintf("Columns(%s)", strings.Join(columns, ", ")), - fmt.Sprintf("Filters(%s)", strings.Join(filters, ", ")), - t.PushdownProjectionAndFiltersTable.String(), - ) - - return pr.String() -} - -// Expressions implements the Expressioner interface. -func (t PushdownProjectionAndFiltersTable) Expressions() []sql.Expression { - var exprs []sql.Expression - exprs = append(exprs, t.Columns...) - exprs = append(exprs, t.Filters...) 
- return exprs -} diff --git a/sql/plan/pushdown_test.go b/sql/plan/pushdown_test.go deleted file mode 100644 index a23b55050..000000000 --- a/sql/plan/pushdown_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package plan - -import ( - "io" - "testing" - - "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" -) - -func TestPushdownProjectionTable(t *testing.T) { - require := require.New(t) - memTable := mem.NewTable("table", sql.Schema{ - {Name: "a", Type: sql.Int64, Nullable: false}, - {Name: "b", Type: sql.Int64, Nullable: false}, - {Name: "c", Type: sql.Int64, Nullable: false}, - }) - - table := NewPushdownProjectionTable( - []string{"a", "c"}, - &pushdownProjectionTable{memTable}, - ) - - rows := collectRows(t, table) - expected := []sql.Row{ - sql.Row{int64(1), nil, int64(1)}, - sql.Row{int64(2), nil, int64(2)}, - sql.Row{int64(3), nil, int64(3)}, - sql.Row{int64(4), nil, int64(4)}, - } - - require.Equal(expected, rows) -} - -func TestPushdownProjectionAndFiltersTable(t *testing.T) { - require := require.New(t) - memTable := mem.NewTable("table", sql.Schema{ - {Name: "a", Type: sql.Int64, Nullable: false}, - {Name: "b", Type: sql.Int64, Nullable: false}, - {Name: "c", Type: sql.Int64, Nullable: false}, - }) - - table := NewPushdownProjectionAndFiltersTable( - []sql.Expression{ - expression.NewGetField(0, sql.Int64, "a", false), - expression.NewGetField(2, sql.Int64, "c", false), - }, - []sql.Expression{ - expression.NewNot(expression.NewEquals( - expression.NewGetField(0, sql.Int64, "a", false), - expression.NewLiteral(int64(1), sql.Int64), - )), - expression.NewNot(expression.NewEquals( - expression.NewGetField(0, sql.Int64, "a", false), - expression.NewLiteral(int64(3), sql.Int64), - )), - }, - &pushdownProjectionAndFiltersTable{memTable}, - ) - - rows := collectRows(t, table) - expected := []sql.Row{ - sql.Row{int64(2), nil, int64(2)}, - 
sql.Row{int64(4), nil, int64(4)}, - } - - require.Equal(expected, rows) -} - -type pushdownProjectionTable struct { - sql.Table -} - -func (t *pushdownProjectionTable) WithProject(_ *sql.Context, cols []string) (sql.RowIter, error) { - var fields []int -Loop: - for i, col := range t.Schema() { - for _, colName := range cols { - if colName == col.Name { - fields = append(fields, i) - continue Loop - } - } - } - - return &pushdownProjectionIter{len(t.Schema()), fields, 0}, nil -} - -type pushdownProjectionIter struct { - len int - fields []int - iter int64 -} - -func (it *pushdownProjectionIter) Next() (sql.Row, error) { - if it.iter > 3 { - return nil, io.EOF - } - - var row = make(sql.Row, it.len) - it.iter++ - for _, f := range it.fields { - row[f] = it.iter - } - return row, nil -} - -func (it *pushdownProjectionIter) Close() error { - it.iter = 4 - return nil -} - -type pushdownProjectionAndFiltersTable struct { - sql.Table -} - -func (pushdownProjectionAndFiltersTable) HandledFilters([]sql.Expression) []sql.Expression { - panic("not implemented") -} - -func (t *pushdownProjectionAndFiltersTable) WithProjectAndFilters(ctx *sql.Context, cols, filters []sql.Expression) (sql.RowIter, error) { - var fields []int -Loop: - for i, col := range t.Schema() { - for _, c := range cols { - if c, ok := c.(sql.Nameable); ok { - if c.Name() == col.Name { - fields = append(fields, i) - continue Loop - } - } - } - } - - return &pushdownProjectionAndFiltersIter{ - &pushdownProjectionIter{len(t.Schema()), fields, 0}, - ctx, - filters, - }, nil -} - -type pushdownProjectionAndFiltersIter struct { - sql.RowIter - ctx *sql.Context - filters []sql.Expression -} - -func (it *pushdownProjectionAndFiltersIter) Next() (sql.Row, error) { -Loop: - for { - row, err := it.RowIter.Next() - if err != nil { - return nil, err - } - - for _, f := range it.filters { - result, err := f.Eval(it.ctx, row) - if err != nil { - return nil, err - } - - if result != true { - continue Loop - } - } - - 
- 		return row, nil -	} -} diff --git a/sql/plan/resolved_table.go b/sql/plan/resolved_table.go new file mode 100644 index 000000000..dbbd689a1 --- /dev/null +++ b/sql/plan/resolved_table.go @@ -0,0 +1,117 @@ +package plan + +import ( +	"context" +	"io" + +	"github.com/src-d/go-mysql-server/sql" +) + +// ResolvedTable represents a resolved SQL Table. +type ResolvedTable struct { +	sql.Table +} + +var _ sql.Node = (*ResolvedTable)(nil) + +// NewResolvedTable creates a new instance of ResolvedTable. +func NewResolvedTable(table sql.Table) *ResolvedTable { +	return &ResolvedTable{table} +} + +// Resolved implements the Resolvable interface. +func (*ResolvedTable) Resolved() bool { +	return true +} + +// Children implements the Node interface. +func (*ResolvedTable) Children() []sql.Node { return nil } + +// RowIter implements the Node interface. +func (t *ResolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { +	span, ctx := ctx.Span("plan.ResolvedTable") + +	partitions, err := t.Table.Partitions(ctx) +	if err != nil { +		span.Finish() +		return nil, err +	} + +	return sql.NewSpanIter(span, &tableIter{ +		ctx:        ctx, +		table:      t.Table, +		partitions: partitions, +	}), nil +} + +// WithChildren implements the Node interface. 
+func (t *ResolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } + + return t, nil +} + +type tableIter struct { + ctx *sql.Context + table sql.Table + partitions sql.PartitionIter + partition sql.Partition + rows sql.RowIter +} + +func (i *tableIter) Next() (sql.Row, error) { + select { + case <-i.ctx.Done(): + return nil, context.Canceled + default: + } + + if i.partition == nil { + partition, err := i.partitions.Next() + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(); e != nil { + return nil, e + } + } + + return nil, err + } + + i.partition = partition + } + + if i.rows == nil { + rows, err := i.table.PartitionRows(i.ctx, i.partition) + if err != nil { + return nil, err + } + + i.rows = rows + } + + row, err := i.rows.Next() + if err != nil && err == io.EOF { + if err = i.rows.Close(); err != nil { + return nil, err + } + + i.partition = nil + i.rows = nil + return i.Next() + } + + return row, err +} + +func (i *tableIter) Close() error { + if i.rows != nil { + if err := i.rows.Close(); err != nil { + _ = i.partitions.Close() + return err + } + } + return i.partitions.Close() +} diff --git a/sql/plan/resolved_table_test.go b/sql/plan/resolved_table_test.go new file mode 100644 index 000000000..a90670468 --- /dev/null +++ b/sql/plan/resolved_table_test.go @@ -0,0 +1,142 @@ +package plan + +import ( + "context" + "fmt" + "io" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestResolvedTable(t *testing.T) { + var require = require.New(t) + + table := NewResolvedTable(newTableTest("test")) + require.NotNil(table) + + iter, err := table.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + require.Len(rows, 9) + + tableTest, ok := table.Table.(*dummyTable) + require.True(ok) + + for i, row := range rows { + 
expected := tableTest.rows[i] + require.ElementsMatch(expected, row) + } +} + +func TestResolvedTableCancelled(t *testing.T) { + var require = require.New(t) + + table := NewResolvedTable(newTableTest("test")) + require.NotNil(table) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + iter, err := table.RowIter(sql.NewContext(ctx)) + require.NoError(err) + + _, err = iter.Next() + require.Equal(context.Canceled, err) +} + +func newTableTest(source string) sql.Table { + schema := []*sql.Column{ + {Name: "col1", Type: sql.Int32, Source: source, Default: int32(0), Nullable: false}, + {Name: "col2", Type: sql.Int64, Source: source, Default: int64(0), Nullable: false}, + {Name: "col3", Type: sql.Text, Source: source, Default: "", Nullable: false}, + } + + keys := [][]byte{ + []byte("partition1"), + []byte("partition2"), + []byte("partition3"), + } + + rows := []sql.Row{ + sql.NewRow(int32(1), int64(9), "one, nine"), + sql.NewRow(int32(2), int64(8), "two, eight"), + sql.NewRow(int32(3), int64(7), "three, seven"), + sql.NewRow(int32(4), int64(6), "four, six"), + sql.NewRow(int32(5), int64(5), "five, five"), + sql.NewRow(int32(6), int64(4), "six, four"), + sql.NewRow(int32(7), int64(3), "seven, three"), + sql.NewRow(int32(8), int64(2), "eight, two"), + sql.NewRow(int32(9), int64(1), "nine, one"), + } + + partitions := map[string][]sql.Row{ + "partition1": []sql.Row{rows[0], rows[1], rows[2]}, + "partition2": []sql.Row{rows[3], rows[4], rows[5]}, + "partition3": []sql.Row{rows[6], rows[7], rows[8]}, + } + + return &dummyTable{schema, keys, partitions, rows} +} + +type dummyTable struct { + schema sql.Schema + keys [][]byte + partitions map[string][]sql.Row + rows []sql.Row +} + +var _ sql.Table = (*dummyTable)(nil) + +func (t *dummyTable) Name() string { return "dummy" } + +func (t *dummyTable) String() string { + panic("not implemented") +} + +func (*dummyTable) Insert(*sql.Context, sql.Row) error { + panic("not implemented") +} + +func (t 
*dummyTable) Schema() sql.Schema { return t.schema } + +func (t *dummyTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { + return &partitionIter{keys: t.keys}, nil +} + +func (t *dummyTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { + rows, ok := t.partitions[string(partition.Key())] + if !ok { + return nil, fmt.Errorf( + "partition not found: %q", partition.Key(), + ) + } + + return sql.RowsToRowIter(rows...), nil +} + +type partition struct { + key []byte +} + +func (p *partition) Key() []byte { return p.key } + +type partitionIter struct { + keys [][]byte + pos int +} + +func (p *partitionIter) Next() (sql.Partition, error) { + if p.pos >= len(p.keys) { + return nil, io.EOF + } + + key := p.keys[p.pos] + p.pos++ + return &partition{key}, nil +} + +func (p *partitionIter) Close() error { return nil } diff --git a/sql/plan/set.go b/sql/plan/set.go new file mode 100644 index 000000000..9b8101675 --- /dev/null +++ b/sql/plan/set.go @@ -0,0 +1,133 @@ +package plan + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "vitess.io/vitess/go/vt/sqlparser" +) + +// Set configuration variables. Right now, only session variables are supported. +type Set struct { + Variables []SetVariable +} + +// SetVariable is a key-value pair to represent the value that will be set on +// a variable. +type SetVariable struct { + Name string + Value sql.Expression +} + +// NewSet creates a new Set node. +func NewSet(vars ...SetVariable) *Set { + return &Set{vars} +} + +// Resolved implements the sql.Node interface. +func (s *Set) Resolved() bool { + for _, v := range s.Variables { + if _, ok := v.Value.(*expression.DefaultColumn); ok { + continue + } + if !v.Value.Resolved() { + return false + } + } + return true +} + +// Children implements the sql.Node interface. +func (s *Set) Children() []sql.Node { return nil } + +// WithChildren implements the Node interface. 
+func (s *Set) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + + return s, nil +} + +// WithExpressions implements the Expressioner interface. +func (s *Set) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.Variables) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.Variables)) + } + + var vars = make([]SetVariable, len(s.Variables)) + for i, v := range s.Variables { + vars[i] = SetVariable{ + Name: v.Name, + Value: exprs[i], + } + } + + return NewSet(vars...), nil +} + +// Expressions implements the sql.Expressioner interface. +func (s *Set) Expressions() []sql.Expression { + var exprs = make([]sql.Expression, len(s.Variables)) + for i, v := range s.Variables { + exprs[i] = v.Value + } + return exprs +} + +// RowIter implements the sql.Node interface. +func (s *Set) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, ctx := ctx.Span("plan.Set") + defer span.Finish() + + const ( + sessionPrefix = sqlparser.SessionStr + "." + globalPrefix = sqlparser.GlobalStr + "." + ) + for _, v := range s.Variables { + var ( + value interface{} + typ sql.Type + err error + ) + + name := strings.TrimPrefix( + strings.TrimPrefix(strings.TrimLeft(v.Name, "@"), sessionPrefix), + globalPrefix, + ) + + if _, ok := v.Value.(*expression.DefaultColumn); ok { + valtyp, ok := sql.DefaultSessionConfig()[name] + if !ok { + continue + } + value, typ = valtyp.Value, valtyp.Typ + } else { + value, err = v.Value.Eval(ctx, nil) + if err != nil { + return nil, err + } + typ = v.Value.Type() + } + + ctx.Set(name, typ, value) + } + + return sql.RowsToRowIter(), nil +} + +// Schema implements the sql.Node interface. 
+func (s *Set) Schema() sql.Schema { return nil } + +func (s *Set) String() string { +	p := sql.NewTreePrinter() +	_ = p.WriteNode("Set") +	var children = make([]string, len(s.Variables)) +	for i, v := range s.Variables { +		children[i] = fmt.Sprintf("%s = %s", v.Name, v.Value) +	} +	_ = p.WriteChildren(children...) +	return p.String() +} diff --git a/sql/plan/set_test.go b/sql/plan/set_test.go new file mode 100644 index 000000000..4af46b00e --- /dev/null +++ b/sql/plan/set_test.go @@ -0,0 +1,73 @@ +package plan + +import ( +	"context" +	"testing" + +	"github.com/src-d/go-mysql-server/sql" +	"github.com/src-d/go-mysql-server/sql/expression" +	"github.com/stretchr/testify/require" +) + +func TestSet(t *testing.T) { +	require := require.New(t) + +	ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession())) + +	s := NewSet( +		SetVariable{"foo", expression.NewLiteral("bar", sql.Text)}, +		SetVariable{"@@baz", expression.NewLiteral(int64(1), sql.Int64)}, +	) + +	_, err := s.RowIter(ctx) +	require.NoError(err) + +	typ, v := ctx.Get("foo") +	require.Equal(sql.Text, typ) +	require.Equal("bar", v) + +	typ, v = ctx.Get("baz") +	require.Equal(sql.Int64, typ) +	require.Equal(int64(1), v) +} + +func TestSetDefault(t *testing.T) { +	require := require.New(t) + +	ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession())) + +	s := NewSet( +		SetVariable{"auto_increment_increment", expression.NewLiteral(int64(123), sql.Int64)}, +		SetVariable{"@@sql_select_limit", expression.NewLiteral(int64(1), sql.Int64)}, +	) + +	_, err := s.RowIter(ctx) +	require.NoError(err) + +	typ, v := ctx.Get("auto_increment_increment") +	require.Equal(sql.Int64, typ) +	require.Equal(int64(123), v) + +	typ, v = ctx.Get("sql_select_limit") +	require.Equal(sql.Int64, typ) +	require.Equal(int64(1), v) + +	s = NewSet( +		SetVariable{"auto_increment_increment", expression.NewDefaultColumn("")}, +		SetVariable{"@@sql_select_limit", expression.NewDefaultColumn("")}, +	) + 
+ _, err = s.RowIter(ctx) + require.NoError(err) + + defaults := sql.DefaultSessionConfig() + + typ, v = ctx.Get("auto_increment_increment") + require.Equal(defaults["auto_increment_increment"].Typ, typ) + require.Equal(defaults["auto_increment_increment"].Value, v) + + typ, v = ctx.Get("sql_select_limit") + require.Equal(defaults["sql_select_limit"].Typ, typ) + require.Equal(defaults["sql_select_limit"].Value, v) + +} diff --git a/sql/plan/show_collation.go b/sql/plan/show_collation.go new file mode 100644 index 000000000..3c833b246 --- /dev/null +++ b/sql/plan/show_collation.go @@ -0,0 +1,52 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// ShowCollation shows all available collations. +type ShowCollation struct{} + +var collationSchema = sql.Schema{ + {Name: "Collation", Type: sql.Text}, + {Name: "Charset", Type: sql.Text}, + {Name: "Id", Type: sql.Int64}, + {Name: "Default", Type: sql.Text}, + {Name: "Compiled", Type: sql.Text}, + {Name: "Sortlen", Type: sql.Int64}, +} + +// NewShowCollation creates a new ShowCollation node. +func NewShowCollation() ShowCollation { + return ShowCollation{} +} + +// Children implements the sql.Node interface. +func (ShowCollation) Children() []sql.Node { return nil } + +func (ShowCollation) String() string { return "SHOW COLLATION" } + +// Resolved implements the sql.Node interface. +func (ShowCollation) Resolved() bool { return true } + +// RowIter implements the sql.Node interface. +func (ShowCollation) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(sql.Row{ + defaultCollation, + defaultCharacterSet, + int64(1), + "Yes", + "Yes", + int64(1), + }), nil +} + +// Schema implements the sql.Node interface. +func (ShowCollation) Schema() sql.Schema { return collationSchema } + +// WithChildren implements the Node interface. 
+func (s ShowCollation) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + + return s, nil +} diff --git a/sql/plan/show_create_database.go b/sql/plan/show_create_database.go new file mode 100644 index 000000000..96de06175 --- /dev/null +++ b/sql/plan/show_create_database.go @@ -0,0 +1,92 @@ +package plan + +import ( + "bytes" + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowCreateDatabase returns the SQL for creating a database. +type ShowCreateDatabase struct { + db sql.Database + IfNotExists bool +} + +const defaultCharacterSet = "utf8mb4" + +var showCreateDatabaseSchema = sql.Schema{ + {Name: "Database", Type: sql.Text}, + {Name: "Create Database", Type: sql.Text}, +} + +// NewShowCreateDatabase creates a new ShowCreateDatabase node. +func NewShowCreateDatabase(db sql.Database, ifNotExists bool) *ShowCreateDatabase { + return &ShowCreateDatabase{db, ifNotExists} +} + +var _ sql.Databaser = (*ShowCreateDatabase)(nil) + +// Database implements the sql.Databaser interface. +func (s *ShowCreateDatabase) Database() sql.Database { + return s.db +} + +// WithDatabase implements the sql.Databaser interface. +func (s *ShowCreateDatabase) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *s + nc.db = db + return &nc, nil +} + +// RowIter implements the sql.Node interface. +func (s *ShowCreateDatabase) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var name = s.db.Name() + + var buf bytes.Buffer + + buf.WriteString("CREATE DATABASE ") + if s.IfNotExists { + buf.WriteString("/*!32312 IF NOT EXISTS*/ ") + } + + buf.WriteRune('`') + buf.WriteString(name) + buf.WriteRune('`') + buf.WriteString(fmt.Sprintf( + " /*!40100 DEFAULT CHARACTER SET %s COLLATE %s */", + defaultCharacterSet, + defaultCollation, + )) + + return sql.RowsToRowIter( + sql.NewRow(name, buf.String()), + ), nil +} + +// Schema implements the sql.Node interface. 
+func (s *ShowCreateDatabase) Schema() sql.Schema { + return showCreateDatabaseSchema +} + +func (s *ShowCreateDatabase) String() string { + return fmt.Sprintf("SHOW CREATE DATABASE %s", s.db.Name()) +} + +// Children implements the sql.Node interface. +func (s *ShowCreateDatabase) Children() []sql.Node { return nil } + +// Resolved implements the sql.Node interface. +func (s *ShowCreateDatabase) Resolved() bool { + _, ok := s.db.(sql.UnresolvedDatabase) + return !ok +} + +// WithChildren implements the Node interface. +func (s *ShowCreateDatabase) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + + return s, nil +} diff --git a/sql/plan/show_create_database_test.go b/sql/plan/show_create_database_test.go new file mode 100644 index 000000000..c58fa065c --- /dev/null +++ b/sql/plan/show_create_database_test.go @@ -0,0 +1,34 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowCreateDatabase(t *testing.T) { + require := require.New(t) + + node := NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true) + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{ + {"foo", "CREATE DATABASE /*!32312 IF NOT EXISTS*/ `foo` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */"}, + }, rows) + + node = NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false) + iter, err = node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{ + {"foo", "CREATE DATABASE `foo` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */"}, + }, rows) +} diff --git a/sql/plan/show_create_table.go b/sql/plan/show_create_table.go new file mode 100644 index 000000000..500fb5930 --- /dev/null +++ 
b/sql/plan/show_create_table.go @@ -0,0 +1,139 @@ +package plan + +import ( + "fmt" + "io" + "strings" + + "github.com/src-d/go-mysql-server/internal/similartext" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowCreateTable is a node that shows the CREATE TABLE statement for a table. +type ShowCreateTable struct { + Catalog *sql.Catalog + CurrentDatabase string + Table string +} + +// Schema implements the Node interface. +func (n *ShowCreateTable) Schema() sql.Schema { + return sql.Schema{ + &sql.Column{Name: "Table", Type: sql.Text, Nullable: false}, + &sql.Column{Name: "Create Table", Type: sql.Text, Nullable: false}, + } +} + +// WithChildren implements the Node interface. +func (n *ShowCreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + + return n, nil +} + +// RowIter implements the Node interface +func (n *ShowCreateTable) RowIter(*sql.Context) (sql.RowIter, error) { + db, err := n.Catalog.Database(n.CurrentDatabase) + if err != nil { + return nil, err + } + + return &showCreateTablesIter{ + db: db, + table: n.Table, + }, nil +} + +// String implements the Stringer interface. 
+func (n *ShowCreateTable) String() string { + return fmt.Sprintf("SHOW CREATE TABLE %s", n.Table) +} + +type showCreateTablesIter struct { + db sql.Database + table string + didIteration bool +} + +func (i *showCreateTablesIter) Next() (sql.Row, error) { + if i.didIteration { + return nil, io.EOF + } + + i.didIteration = true + + tables := i.db.Tables() + if len(tables) == 0 { + return nil, sql.ErrTableNotFound.New(i.table) + } + + table, found := tables[i.table] + + if !found { + similar := similartext.FindFromMap(tables, i.table) + return nil, sql.ErrTableNotFound.New(i.table + similar) + } + + composedCreateTableStatement := produceCreateStatement(table) + + return sql.NewRow( + i.table, // "Table" string + composedCreateTableStatement, // "Create Table" string + ), nil +} + +func produceCreateStatement(table sql.Table) string { + schema := table.Schema() + colStmts := make([]string, len(schema)) + + // Statement creation parts for each column + for i, col := range schema { + stmt := fmt.Sprintf(" `%s` %s", col.Name, strings.ToLower(sql.MySQLTypeName(col.Type))) + + if !col.Nullable { + stmt = fmt.Sprintf("%s NOT NULL", stmt) + } + + switch def := col.Default.(type) { + case string: + if def != "" { + stmt = fmt.Sprintf("%s DEFAULT %q", stmt, def) + } + default: + if def != nil { + stmt = fmt.Sprintf("%s DEFAULT %v", stmt, col.Default) + } + } + + colStmts[i] = stmt + } + + return fmt.Sprintf( + "CREATE TABLE `%s` (\n%s\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", + table.Name(), + strings.Join(colStmts, ",\n"), + ) +} + +func (i *showCreateTablesIter) Close() error { + return nil +} + +// NewShowCreateTable creates a new ShowCreateTable node. +func NewShowCreateTable(db string, ctl *sql.Catalog, table string) sql.Node { + return &ShowCreateTable{ + CurrentDatabase: db, + Table: table, + Catalog: ctl} +} + +// Resolved implements the Resolvable interface. +func (n *ShowCreateTable) Resolved() bool { + return true +} + +// Children implements the Node interface. 
+func (n *ShowCreateTable) Children() []sql.Node { return nil } diff --git a/sql/plan/show_create_table_test.go b/sql/plan/show_create_table_test.go new file mode 100644 index 000000000..0da6418bf --- /dev/null +++ b/sql/plan/show_create_table_test.go @@ -0,0 +1,51 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowCreateTable(t *testing.T) { + var require = require.New(t) + + db := memory.NewDatabase("testdb") + + table := memory.NewTable( + "test-table", + sql.Schema{ + &sql.Column{Name: "baz", Type: sql.Text, Default: "", Nullable: false}, + &sql.Column{Name: "zab", Type: sql.Int32, Default: int32(0), Nullable: true}, + &sql.Column{Name: "bza", Type: sql.Uint64, Default: uint64(0), Nullable: true}, + &sql.Column{Name: "foo", Type: sql.VarChar(123), Default: "", Nullable: true}, + &sql.Column{Name: "pok", Type: sql.Char(123), Default: "", Nullable: true}, + }) + + db.AddTable(table.Name(), table) + + cat := sql.NewCatalog() + cat.AddDatabase(db) + + showCreateTable := NewShowCreateTable(db.Name(), cat, table.Name()) + + ctx := sql.NewEmptyContext() + rowIter, _ := showCreateTable.RowIter(ctx) + + row, err := rowIter.Next() + + require.Nil(err) + + expected := sql.NewRow( + table.Name(), + "CREATE TABLE `test-table` (\n `baz` text NOT NULL,\n"+ + " `zab` integer DEFAULT 0,\n"+ + " `bza` bigint unsigned DEFAULT 0,\n"+ + " `foo` varchar(123),\n"+ + " `pok` char(123)\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", + ) + + require.Equal(expected, row) +} diff --git a/sql/plan/show_indexes.go b/sql/plan/show_indexes.go new file mode 100644 index 000000000..d427509ea --- /dev/null +++ b/sql/plan/show_indexes.go @@ -0,0 +1,204 @@ +package plan + +import ( + "fmt" + "io" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowIndexes is a node that shows the indexes on a table. 
+type ShowIndexes struct { + db sql.Database + Table string + Registry *sql.IndexRegistry +} + +// NewShowIndexes creates a new ShowIndexes node. +func NewShowIndexes(db sql.Database, table string, registry *sql.IndexRegistry) sql.Node { + return &ShowIndexes{db, table, registry} +} + +var _ sql.Databaser = (*ShowIndexes)(nil) + +// Database implements the sql.Databaser interface. +func (n *ShowIndexes) Database() sql.Database { + return n.db +} + +// WithDatabase implements the sql.Databaser interface. +func (n *ShowIndexes) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *n + nc.db = db + return &nc, nil +} + +// Resolved implements the Resolvable interface. +func (n *ShowIndexes) Resolved() bool { + _, ok := n.db.(sql.UnresolvedDatabase) + return !ok +} + +// WithChildren implements the Node interface. +func (n *ShowIndexes) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0) + } + + return n, nil +} + +// String implements the Stringer interface. +func (n *ShowIndexes) String() string { + return fmt.Sprintf("ShowIndexes(%s)", n.Table) +} + +// Schema implements the Node interface. 
+func (n *ShowIndexes) Schema() sql.Schema { + return sql.Schema{ + &sql.Column{Name: "Table", Type: sql.Text}, + &sql.Column{Name: "Non_unique", Type: sql.Int32}, + &sql.Column{Name: "Key_name", Type: sql.Text}, + &sql.Column{Name: "Seq_in_index", Type: sql.Int32}, + &sql.Column{Name: "Column_name", Type: sql.Text, Nullable: true}, + &sql.Column{Name: "Collation", Type: sql.Text, Nullable: true}, + &sql.Column{Name: "Cardinality", Type: sql.Int64}, + &sql.Column{Name: "Sub_part", Type: sql.Int64, Nullable: true}, + &sql.Column{Name: "Packed", Type: sql.Text, Nullable: true}, + &sql.Column{Name: "Null", Type: sql.Text}, + &sql.Column{Name: "Index_type", Type: sql.Text}, + &sql.Column{Name: "Comment", Type: sql.Text}, + &sql.Column{Name: "Index_comment", Type: sql.Text}, + &sql.Column{Name: "Visible", Type: sql.Text}, + &sql.Column{Name: "Expression", Type: sql.Text, Nullable: true}, + } +} + +// Children implements the Node interface. +func (n *ShowIndexes) Children() []sql.Node { return nil } + +// RowIter implements the Node interface. 
+func (n *ShowIndexes) RowIter(*sql.Context) (sql.RowIter, error) { + return &showIndexesIter{ + db: n.db, + table: n.Table, + registry: n.Registry, + }, nil +} + +type showIndexesIter struct { + db sql.Database + table string + registry *sql.IndexRegistry + + idxs *indexesToShow +} + +func (i *showIndexesIter) Next() (sql.Row, error) { + if i.registry == nil { + return nil, io.EOF + } + + if i.idxs == nil { + i.idxs = &indexesToShow{ + indexes: i.registry.IndexesByTable(i.db.Name(), i.table), + } + } + + show, err := i.idxs.next() + if err != nil { + return nil, err + } + + var ( + nullable string + visible string + ) + columnName, expression := "NULL", show.expression + if ok, null := isColumn(show.expression, i.db.Tables()[i.table]); ok { + columnName, expression = expression, columnName + if null { + nullable = "YES" + } + } + if i.registry.CanUseIndex(show.index) { + visible = "YES" + } else { + visible = "NO" + } + return sql.NewRow( + i.table, // "Table" string + int32(1), // "Non_unique" int32, Values [0, 1] + show.index.ID(), // "Key_name" string + show.exPosition+1, // "Seq_in_index" int32 + columnName, // "Column_name" string + "NULL", // "Collation" string, Values [A, D, NULL] + int64(0), // "Cardinality" int64 (returning 0, it is not being calculated for the moment) + "NULL", // "Sub_part" int64 + "NULL", // "Packed" string + nullable, // "Null" string, Values [YES, ''] + show.index.Driver(), // "Index_type" string + "", // "Comment" string + "", // "Index_comment" string + visible, // "Visible" string, Values [YES, NO] + expression, // "Expression" string + ), nil +} + +func isColumn(ex string, table sql.Table) (bool, bool) { + for _, col := range table.Schema() { + if col.Source+"."+col.Name == ex { + return true, col.Nullable + } + } + + return false, false +} + +func (i *showIndexesIter) Close() error { + for _, idx := range i.idxs.indexes { + i.registry.ReleaseIndex(idx) + } + + return nil +} + +type indexesToShow struct { + indexes []sql.Index + 
pos int + epos int +} + +type idxToShow struct { + index sql.Index + expression string + exPosition int +} + +func (i *indexesToShow) next() (*idxToShow, error) { + if len(i.indexes) == 0 { + return nil, io.EOF + } + + index := i.indexes[i.pos] + expressions := index.Expressions() + if i.epos >= len(expressions) { + i.pos++ + if i.pos >= len(i.indexes) { + return nil, io.EOF + } + + index = i.indexes[i.pos] + i.epos = 0 + expressions = index.Expressions() + } + + show := &idxToShow{ + index: index, + expression: expressions[i.epos], + exPosition: i.epos, + } + + i.epos++ + return show, nil +} diff --git a/sql/plan/show_indexes_test.go b/sql/plan/show_indexes_test.go new file mode 100644 index 000000000..8ca0649e2 --- /dev/null +++ b/sql/plan/show_indexes_test.go @@ -0,0 +1,141 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "github.com/stretchr/testify/require" +) + +func TestShowIndexes(t *testing.T) { + var require = require.New(t) + + unresolved := NewShowIndexes(sql.UnresolvedDatabase(""), "table-test", nil) + require.False(unresolved.Resolved()) + require.Nil(unresolved.Children()) + + db := memory.NewDatabase("test") + + tests := []struct { + name string + table sql.Table + isExpression bool + }{ + { + name: "test1", + table: memory.NewTable( + "test1", + sql.Schema{ + &sql.Column{Name: "foo", Type: sql.Int32, Source: "test1", Default: int32(0), Nullable: false}, + }, + ), + }, + { + name: "test2", + table: memory.NewTable( + "test2", + sql.Schema{ + &sql.Column{Name: "bar", Type: sql.Int64, Source: "test2", Default: int64(0), Nullable: true}, + &sql.Column{Name: "rab", Type: sql.Int64, Source: "test2", Default: int32(0), Nullable: false}, + }, + ), + }, + { + name: "test3", + table: memory.NewTable( + "test3", + sql.Schema{ + &sql.Column{Name: "baz", Type: sql.Text, Source: "test3", Default: "", Nullable: false}, + 
&sql.Column{Name: "zab", Type: sql.Int32, Source: "test3", Default: int32(0), Nullable: true}, + &sql.Column{Name: "bza", Type: sql.Int64, Source: "test3", Default: int64(0), Nullable: true}, + }, + ), + }, + { + name: "test4", + table: memory.NewTable( + "test4", + sql.Schema{ + &sql.Column{Name: "oof", Type: sql.Text, Source: "test4", Default: "", Nullable: false}, + }, + ), + isExpression: true, + }, + } + + r := sql.NewIndexRegistry() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + db.AddTable(test.name, test.table) + + expressions := make([]sql.Expression, len(test.table.Schema())) + for i, col := range test.table.Schema() { + var ex sql.Expression = expression.NewGetFieldWithTable( + i, col.Type, test.name, col.Name, col.Nullable, + ) + + if test.isExpression { + ex = expression.NewEquals(ex, expression.NewLiteral("a", sql.Text)) + } + + expressions[i] = ex + } + + idx := &mockIndex{ + db: "test", + table: test.name, + id: test.name + "_idx", + exprs: expressions, + } + + created, ready, err := r.AddIndex(idx) + require.NoError(err) + close(created) + <-ready + + showIdxs := NewShowIndexes(db, test.name, r) + + ctx := sql.NewEmptyContext() + rowIter, err := showIdxs.RowIter(ctx) + require.NoError(err) + + rows, err := sql.RowIterToRows(rowIter) + require.NoError(err) + require.Len(rows, len(expressions)) + + for i, row := range rows { + var nullable string + columnName, ex := "NULL", expressions[i].String() + if ok, null := isColumn(ex, test.table); ok { + columnName, ex = ex, columnName + if null { + nullable = "YES" + } + } + + expected := sql.NewRow( + test.name, + int32(1), + idx.ID(), + i+1, + columnName, + "NULL", + int64(0), + "NULL", + "NULL", + nullable, + idx.Driver(), + "", + "", + "YES", + ex, + ) + + require.Equal(expected, row) + } + + }) + } +} diff --git a/sql/plan/show_tables.go b/sql/plan/show_tables.go index 0abbe38eb..930cb20b4 100644 --- a/sql/plan/show_tables.go +++ b/sql/plan/show_tables.go @@ -1,27 +1,51 @@ 
package plan import ( - "io" "sort" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // ShowTables is a node that shows the database tables. type ShowTables struct { - Database sql.Database + db sql.Database + Full bool +} + +var showTablesSchema = sql.Schema{ + {Name: "Table", Type: sql.Text}, +} + +var showTablesFullSchema = sql.Schema{ + {Name: "Table", Type: sql.Text}, + {Name: "Table_type", Type: sql.Text}, } // NewShowTables creates a new show tables node given a database. -func NewShowTables(database sql.Database) *ShowTables { +func NewShowTables(database sql.Database, full bool) *ShowTables { return &ShowTables{ - Database: database, + db: database, + Full: full, } } +var _ sql.Databaser = (*ShowTables)(nil) + +// Database implements the sql.Databaser interface. +func (p *ShowTables) Database() sql.Database { + return p.db +} + +// WithDatabase implements the sql.Databaser interface. +func (p *ShowTables) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *p + nc.db = db + return &nc, nil +} + // Resolved implements the Resolvable interface. func (p *ShowTables) Resolved() bool { - _, ok := p.Database.(*sql.UnresolvedDatabase) + _, ok := p.db.(sql.UnresolvedDatabase) return !ok } @@ -31,56 +55,44 @@ func (*ShowTables) Children() []sql.Node { } // Schema implements the Node interface. -func (*ShowTables) Schema() sql.Schema { - return sql.Schema{{ - Name: "table", - Type: sql.Text, - Nullable: false, - }} +func (p *ShowTables) Schema() sql.Schema { + if p.Full { + return showTablesFullSchema + } + + return showTablesSchema } // RowIter implements the Node interface. 
func (p *ShowTables) RowIter(ctx *sql.Context) (sql.RowIter, error) { tableNames := []string{} - for key := range p.Database.Tables() { + for key := range p.db.Tables() { tableNames = append(tableNames, key) } sort.Strings(tableNames) - return &showTablesIter{tableNames: tableNames}, nil -} + var rows = make([]sql.Row, len(tableNames)) + for i, n := range tableNames { + row := sql.Row{n} + if p.Full { + row = append(row, "BASE TABLE") + } + rows[i] = row + } -// TransformUp implements the Transformable interface. -func (p *ShowTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewShowTables(p.Database)) + return sql.RowsToRowIter(rows...), nil } -// TransformExpressionsUp implements the Transformable interface. -func (p *ShowTables) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { +// WithChildren implements the Node interface. +func (p *ShowTables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + return p, nil } func (p ShowTables) String() string { return "ShowTables" } - -type showTablesIter struct { - tableNames []string - idx int -} - -func (i *showTablesIter) Next() (sql.Row, error) { - if i.idx >= len(i.tableNames) { - return nil, io.EOF - } - row := sql.NewRow(i.tableNames[i.idx]) - i.idx++ - - return row, nil -} - -func (i *showTablesIter) Close() error { - i.tableNames = nil - return nil -} diff --git a/sql/plan/show_tables_test.go b/sql/plan/show_tables_test.go index 57cd3da0a..153c91c50 100644 --- a/sql/plan/show_tables_test.go +++ b/sql/plan/show_tables_test.go @@ -4,42 +4,42 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestShowTables(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - 
unresolvedShowTables := NewShowTables(&sql.UnresolvedDatabase{}) + unresolvedShowTables := NewShowTables(sql.UnresolvedDatabase(""), false) require.False(unresolvedShowTables.Resolved()) require.Nil(unresolvedShowTables.Children()) - db := mem.NewDatabase("test") - db.AddTable("test1", mem.NewTable("test1", nil)) - db.AddTable("test2", mem.NewTable("test2", nil)) - db.AddTable("test3", mem.NewTable("test3", nil)) + db := memory.NewDatabase("test") + db.AddTable("test1", memory.NewTable("test1", nil)) + db.AddTable("test2", memory.NewTable("test2", nil)) + db.AddTable("test3", memory.NewTable("test3", nil)) - resolvedShowTables := NewShowTables(db) + resolvedShowTables := NewShowTables(db, false) require.True(resolvedShowTables.Resolved()) require.Nil(resolvedShowTables.Children()) iter, err := resolvedShowTables.RowIter(ctx) - require.Nil(err) + require.NoError(err) res, err := iter.Next() - require.Nil(err) + require.NoError(err) require.Equal("test1", res[0]) res, err = iter.Next() - require.Nil(err) + require.NoError(err) require.Equal("test2", res[0]) res, err = iter.Next() - require.Nil(err) + require.NoError(err) require.Equal("test3", res[0]) _, err = iter.Next() diff --git a/sql/plan/showcolumns.go b/sql/plan/showcolumns.go new file mode 100644 index 000000000..d2094fe3f --- /dev/null +++ b/sql/plan/showcolumns.go @@ -0,0 +1,125 @@ +package plan + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowColumns shows the columns details of a table. 
+type ShowColumns struct { + UnaryNode + Full bool +} + +const defaultCollation = "utf8_bin" + +var ( + showColumnsSchema = sql.Schema{ + {Name: "Field", Type: sql.Text}, + {Name: "Type", Type: sql.Text}, + {Name: "Null", Type: sql.Text}, + {Name: "Key", Type: sql.Text}, + {Name: "Default", Type: sql.Text, Nullable: true}, + {Name: "Extra", Type: sql.Text}, + } + + showColumnsFullSchema = sql.Schema{ + {Name: "Field", Type: sql.Text}, + {Name: "Type", Type: sql.Text}, + {Name: "Collation", Type: sql.Text, Nullable: true}, + {Name: "Null", Type: sql.Text}, + {Name: "Key", Type: sql.Text}, + {Name: "Default", Type: sql.Text, Nullable: true}, + {Name: "Extra", Type: sql.Text}, + {Name: "Privileges", Type: sql.Text}, + {Name: "Comment", Type: sql.Text}, + } +) + +// NewShowColumns creates a new ShowColumns node. +func NewShowColumns(full bool, child sql.Node) *ShowColumns { + return &ShowColumns{UnaryNode{Child: child}, full} +} + +var _ sql.Node = (*ShowColumns)(nil) + +// Schema implements the sql.Node interface. +func (s *ShowColumns) Schema() sql.Schema { + if s.Full { + return showColumnsFullSchema + } + return showColumnsSchema +} + +// RowIter implements the sql.Node interface. 
+func (s *ShowColumns) RowIter(ctx *sql.Context) (sql.RowIter, error) { + span, _ := ctx.Span("plan.ShowColumns") + + schema := s.Child.Schema() + var rows = make([]sql.Row, len(schema)) + for i, col := range schema { + var row sql.Row + var collation interface{} + if col.Type == sql.Text { + collation = defaultCollation + } + + var null = "NO" + if col.Nullable { + null = "YES" + } + + var defaultVal string + if col.Default != nil { + defaultVal = fmt.Sprint(col.Default) + } + + if s.Full { + row = sql.Row{ + col.Name, + col.Type.String(), + collation, + null, + "", // Key + defaultVal, + "", // Extra + "", // Privileges + "", // Comment + } + } else { + row = sql.Row{ + col.Name, + col.Type.String(), + null, + "", // Key + defaultVal, + "", // Extra + } + } + + rows[i] = row + } + + return sql.NewSpanIter(span, sql.RowsToRowIter(rows...)), nil +} + +// WithChildren implements the Node interface. +func (s *ShowColumns) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + + return NewShowColumns(s.Full, children[0]), nil +} + +func (s *ShowColumns) String() string { + tp := sql.NewTreePrinter() + if s.Full { + _ = tp.WriteNode("ShowColumns(full)") + } else { + _ = tp.WriteNode("ShowColumns") + } + _ = tp.WriteChildren(s.Child.String()) + return tp.String() +} diff --git a/sql/plan/showcolumns_test.go b/sql/plan/showcolumns_test.go new file mode 100644 index 000000000..696627247 --- /dev/null +++ b/sql/plan/showcolumns_test.go @@ -0,0 +1,56 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowColumns(t *testing.T) { + require := require.New(t) + + table := NewResolvedTable(memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Text}, + {Name: "b", Type: sql.Int64, Nullable: true}, + {Name: "c", Type: sql.Int64, Default: int64(1)}, 
+ })) + + iter, err := NewShowColumns(false, table).RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + sql.Row{"a", "TEXT", "NO", "", "", ""}, + sql.Row{"b", "INT64", "YES", "", "", ""}, + sql.Row{"c", "INT64", "NO", "", "1", ""}, + } + + require.Equal(expected, rows) +} +func TestShowColumnsFull(t *testing.T) { + require := require.New(t) + + table := NewResolvedTable(memory.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Text}, + {Name: "b", Type: sql.Int64, Nullable: true}, + {Name: "c", Type: sql.Int64, Default: int64(1)}, + })) + + iter, err := NewShowColumns(true, table).RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + sql.Row{"a", "TEXT", "utf8_bin", "NO", "", "", "", "", ""}, + sql.Row{"b", "INT64", nil, "YES", "", "", "", "", ""}, + sql.Row{"c", "INT64", nil, "NO", "", "1", "", "", ""}, + } + + require.Equal(expected, rows) +} diff --git a/sql/plan/showdatabases.go b/sql/plan/showdatabases.go new file mode 100644 index 000000000..41873c7be --- /dev/null +++ b/sql/plan/showdatabases.go @@ -0,0 +1,67 @@ +package plan + +import ( + "sort" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowDatabases is a node that shows the databases. +type ShowDatabases struct { + Catalog *sql.Catalog +} + +// NewShowDatabases creates a new show databases node. +func NewShowDatabases() *ShowDatabases { + return new(ShowDatabases) +} + +// Resolved implements the Resolvable interface. +func (p *ShowDatabases) Resolved() bool { + return true +} + +// Children implements the Node interface. +func (*ShowDatabases) Children() []sql.Node { + return nil +} + +// Schema implements the Node interface. 
+func (*ShowDatabases) Schema() sql.Schema { + return sql.Schema{{ + Name: "Database", + Type: sql.Text, + Nullable: false, + }} +} + +// RowIter implements the Node interface. +func (p *ShowDatabases) RowIter(ctx *sql.Context) (sql.RowIter, error) { + dbs := p.Catalog.AllDatabases() + var rows = make([]sql.Row, 0, len(dbs)) + for _, db := range dbs { + if sql.InformationSchemaDatabaseName != db.Name() { + rows = append(rows, sql.Row{db.Name()}) + } + } + + sort.Slice(rows, func(i, j int) bool { + return strings.Compare(rows[i][0].(string), rows[j][0].(string)) < 0 + }) + + return sql.RowsToRowIter(rows...), nil +} + +// WithChildren implements the Node interface. +func (p *ShowDatabases) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} + +func (p ShowDatabases) String() string { + return "ShowDatabases" +} diff --git a/sql/plan/showtablestatus.go b/sql/plan/showtablestatus.go new file mode 100644 index 000000000..158a1d65c --- /dev/null +++ b/sql/plan/showtablestatus.go @@ -0,0 +1,130 @@ +package plan + +import ( + "fmt" + "sort" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +// ShowTableStatus returns the status of the tables in the databases. +type ShowTableStatus struct { + Databases []string + Catalog *sql.Catalog +} + +// NewShowTableStatus creates a new ShowTableStatus node. 
+func NewShowTableStatus(dbs ...string) *ShowTableStatus { + return &ShowTableStatus{Databases: dbs} +} + +var showTableStatusSchema = sql.Schema{ + {Name: "Name", Type: sql.Text}, + {Name: "Engine", Type: sql.Text}, + {Name: "Version", Type: sql.Text}, + {Name: "Row_format", Type: sql.Text}, + {Name: "Rows", Type: sql.Int64}, + {Name: "Avg_row_length", Type: sql.Int64}, + {Name: "Data_length", Type: sql.Int64}, + {Name: "Max_data_length", Type: sql.Int64}, + {Name: "Index_length", Type: sql.Int64}, + {Name: "Data_free", Type: sql.Int64}, + {Name: "Auto_increment", Type: sql.Int64}, + {Name: "Create_time", Type: sql.Timestamp, Nullable: true}, + {Name: "Update_time", Type: sql.Timestamp, Nullable: true}, + {Name: "Check_time", Type: sql.Timestamp, Nullable: true}, + {Name: "Collation", Type: sql.Text}, + {Name: "Checksum", Type: sql.Text, Nullable: true}, + {Name: "Create_options", Type: sql.Text, Nullable: true}, + {Name: "Comments", Type: sql.Text, Nullable: true}, +} + +// Children implements the sql.Node interface. +func (s *ShowTableStatus) Children() []sql.Node { return nil } + +// Resolved implements the sql.Node interface. +func (s *ShowTableStatus) Resolved() bool { return true } + +// Schema implements the sql.Node interface. +func (s *ShowTableStatus) Schema() sql.Schema { return showTableStatusSchema } + +// RowIter implements the sql.Node interface. 
+func (s *ShowTableStatus) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var rows []sql.Row + var tables []string + if len(s.Databases) > 0 { + for _, db := range s.Catalog.AllDatabases() { + if !stringContains(s.Databases, db.Name()) { + continue + } + + for t := range db.Tables() { + tables = append(tables, t) + } + } + } else { + db, err := s.Catalog.Database(s.Catalog.CurrentDatabase()) + if err != nil { + return nil, err + } + + for t := range db.Tables() { + tables = append(tables, t) + } + } + + sort.Strings(tables) + for _, t := range tables { + rows = append(rows, tableToStatusRow(t)) + } + + return sql.RowsToRowIter(rows...), nil +} + +func (s *ShowTableStatus) String() string { + return fmt.Sprintf("ShowTableStatus(%s)", strings.Join(s.Databases, ", ")) +} + +// WithChildren implements the Node interface. +func (s *ShowTableStatus) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) + } + + return s, nil +} + +func stringContains(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true + } + } + return false +} + +func tableToStatusRow(table string) sql.Row { + return sql.NewRow( + table, // Name + "InnoDB", // Engine + // This column is unused. With the removal of .frm files in MySQL 8.0, this + // column now reports a hardcoded value of 10, which is the last .frm file + // version used in MySQL 5.7. 
+ "10", // Version + "Fixed", // Row_format + int64(0), // Rows + int64(0), // Avg_row_length + int64(0), // Data_length + int64(0), // Max_data_length + int64(0), // Index_length + int64(0), // Data_free + int64(0), // Auto_increment + nil, // Create_time + nil, // Update_time + nil, // Check_time + "utf8_bin", // Collation + nil, // Create_options + nil, // Comments + ) +} diff --git a/sql/plan/showtablestatus_test.go b/sql/plan/showtablestatus_test.go new file mode 100644 index 000000000..7852409a9 --- /dev/null +++ b/sql/plan/showtablestatus_test.go @@ -0,0 +1,57 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowTableStatus(t *testing.T) { + require := require.New(t) + + catalog := sql.NewCatalog() + + db1 := memory.NewDatabase("a") + db1.AddTable("t1", memory.NewTable("t1", nil)) + db1.AddTable("t2", memory.NewTable("t2", nil)) + catalog.AddDatabase(db1) + + db2 := memory.NewDatabase("b") + db2.AddTable("t3", memory.NewTable("t3", nil)) + db2.AddTable("t4", memory.NewTable("t4", nil)) + catalog.AddDatabase(db2) + + node := NewShowTableStatus() + node.Catalog = catalog + + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"t1", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"t2", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + } + + require.Equal(expected, rows) + + node = NewShowTableStatus("a") + node.Catalog = catalog + + iter, err = node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + + expected = []sql.Row{ + {"t1", "InnoDB", "10", "Fixed", int64(0), int64(0), 
int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + {"t2", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil}, + } + + require.Equal(expected, rows) +} diff --git a/sql/plan/showvariables.go b/sql/plan/showvariables.go new file mode 100644 index 000000000..8cb7fa0c1 --- /dev/null +++ b/sql/plan/showvariables.go @@ -0,0 +1,89 @@ +package plan + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// ShowVariables is a node that shows the global and session variables +type ShowVariables struct { + config map[string]sql.TypedValue + pattern string +} + +// NewShowVariables returns a new ShowVariables reference. +// config is a variables lookup table +// like is a "like pattern". If like is an empty string it will return all variables. +func NewShowVariables(config map[string]sql.TypedValue, like string) *ShowVariables { + return &ShowVariables{ + config: config, + pattern: like, + } +} + +// Resolved implements sql.Node interface. The function always returns true. +func (sv *ShowVariables) Resolved() bool { + return true +} + +// WithChildren implements the Node interface. +func (sv *ShowVariables) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sv, len(children), 0) + } + + return sv, nil +} + +// String implements the Stringer interface. +func (sv *ShowVariables) String() string { + var like string + if sv.pattern != "" { + like = fmt.Sprintf(" LIKE '%s'", sv.pattern) + } + return fmt.Sprintf("SHOW VARIABLES%s", like) +} + +// Schema returns a new Schema reference for "SHOW VARIABLES" query. 
+func (*ShowVariables) Schema() sql.Schema { + return sql.Schema{ + &sql.Column{Name: "Variable_name", Type: sql.Text, Nullable: false}, + &sql.Column{Name: "Value", Type: sql.Text, Nullable: true}, + } +} + +// Children implements sql.Node interface. The function always returns nil. +func (*ShowVariables) Children() []sql.Node { return nil } + +// RowIter implements the sql.Node interface. +// The function returns an iterator for filtered variables (based on like pattern) +func (sv *ShowVariables) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var ( + rows []sql.Row + like sql.Expression + ) + if sv.pattern != "" { + like = expression.NewLike( + expression.NewGetField(0, sql.Text, "", false), + expression.NewGetField(1, sql.Text, sv.pattern, false), + ) + } + + for k, v := range sv.config { + if like != nil { + b, err := like.Eval(ctx, sql.NewRow(k, sv.pattern)) + if err != nil { + return nil, err + } + if !b.(bool) { + continue + } + } + + rows = append(rows, sql.NewRow(k, v.Value)) + } + + return sql.RowsToRowIter(rows...), nil +} diff --git a/sql/plan/showvariables_test.go b/sql/plan/showvariables_test.go new file mode 100644 index 000000000..1abbc7b96 --- /dev/null +++ b/sql/plan/showvariables_test.go @@ -0,0 +1,68 @@ +package plan + +import ( + "io" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowVariables(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + config := ctx.Session.GetAll() + sv := NewShowVariables(config, "") + require.True(sv.Resolved()) + + it, err := sv.RowIter(ctx) + require.NoError(err) + + for row, err := it.Next(); err == nil; row, err = it.Next() { + key := row[0].(string) + val := row[1] + + t.Logf("key: %s\tval: %v\n", key, val) + + require.Equal(config[key].Value, val) + delete(config, key) + } + if err != io.EOF { + require.NoError(err) + } + require.NoError(it.Close()) + require.Equal(0, len(config)) +} + +func 
TestShowVariablesWithLike(t *testing.T) { + require := require.New(t) + + vars := map[string]sql.TypedValue{ + "int1": {Typ: sql.Int32, Value: 1}, + "int2": {Typ: sql.Int32, Value: 2}, + "txt": {Typ: sql.Text, Value: "abcdefghijklmnoprstuwxyz"}, + } + + sv := NewShowVariables(vars, "int%") + require.True(sv.Resolved()) + + it, err := sv.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + for row, err := it.Next(); err == nil; row, err = it.Next() { + key := row[0].(string) + val := row[1] + require.Equal(vars[key].Value, val) + require.Equal(sql.Int32, vars[key].Typ) + delete(vars, key) + } + if err != io.EOF { + require.NoError(err) + } + require.NoError(it.Close()) + require.Equal(1, len(vars)) + + _, ok := vars["txt"] + require.True(ok) +} diff --git a/sql/plan/showwarnings.go b/sql/plan/showwarnings.go new file mode 100644 index 000000000..c990bfc81 --- /dev/null +++ b/sql/plan/showwarnings.go @@ -0,0 +1,50 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" +) + +// ShowWarnings is a node that shows the session warnings +type ShowWarnings []*sql.Warning + +// Resolved implements sql.Node interface. The function always returns true. +func (ShowWarnings) Resolved() bool { + return true +} + +// WithChildren implements the Node interface. +func (sw ShowWarnings) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(sw, len(children), 0) + } + + return sw, nil +} + +// String implements the Stringer interface. +func (ShowWarnings) String() string { + return "SHOW WARNINGS" +} + +// Schema returns a new Schema reference for "SHOW VARIABLES" query. +func (ShowWarnings) Schema() sql.Schema { + return sql.Schema{ + &sql.Column{Name: "Level", Type: sql.Text, Nullable: false}, + &sql.Column{Name: "Code", Type: sql.Int32, Nullable: true}, + &sql.Column{Name: "Message", Type: sql.Text, Nullable: false}, + } +} + +// Children implements sql.Node interface. 
The function always returns nil. +func (ShowWarnings) Children() []sql.Node { return nil } + +// RowIter implements the sql.Node interface. +// The function returns an iterator for warnings (considering offset and counter) +func (sw ShowWarnings) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var rows []sql.Row + for _, w := range sw { + rows = append(rows, sql.NewRow(w.Level, w.Code, w.Message)) + } + + return sql.RowsToRowIter(rows...), nil +} diff --git a/sql/plan/showwarnings_test.go b/sql/plan/showwarnings_test.go new file mode 100644 index 000000000..480cc7726 --- /dev/null +++ b/sql/plan/showwarnings_test.go @@ -0,0 +1,40 @@ +package plan + +import ( + "io" + "testing" + + "github.com/src-d/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestShowWarnings(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + ctx.Session.Warn(&sql.Warning{Level: "l1", Message: "w1", Code: 1}) + ctx.Session.Warn(&sql.Warning{Level: "l2", Message: "w2", Code: 2}) + ctx.Session.Warn(&sql.Warning{Level: "l4", Message: "w3", Code: 3}) + + sw := ShowWarnings(ctx.Session.Warnings()) + require.True(sw.Resolved()) + + it, err := sw.RowIter(ctx) + require.NoError(err) + + n := 3 + for row, err := it.Next(); err == nil; row, err = it.Next() { + level := row[0].(string) + code := row[1].(int) + message := row[2].(string) + + t.Logf("level: %s\tcode: %v\tmessage: %s\n", level, code, message) + + require.Equal(n, code) + n-- + } + if err != io.EOF { + require.NoError(err) + } + require.NoError(it.Close()) +} diff --git a/sql/plan/sort.go b/sql/plan/sort.go index 79b074802..fed1f7da3 100644 --- a/sql/plan/sort.go +++ b/sql/plan/sort.go @@ -6,8 +6,8 @@ import ( "sort" "strings" + "github.com/src-d/go-mysql-server/sql" "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // ErrUnableSort is thrown when something happens on sorting @@ -68,18 +68,16 @@ func NewSort(sortFields []SortField, child sql.Node) *Sort { } } +var _ 
sql.Expressioner = (*Sort)(nil) + // Resolved implements the Resolvable interface. func (s *Sort) Resolved() bool { - return s.UnaryNode.Child.Resolved() && s.expressionsResolved() -} - -func (s *Sort) expressionsResolved() bool { for _, f := range s.SortFields { if !f.Column.Resolved() { return false } } - return true + return s.Child.Resolved() } // RowIter implements the Node interface. @@ -90,38 +88,10 @@ func (s *Sort) RowIter(ctx *sql.Context) (sql.RowIter, error) { span.Finish() return nil, err } - return sql.NewSpanIter(span, newSortIter(s, i)), nil -} - -// TransformUp implements the Transformable interface. -func (s *Sort) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := s.Child.TransformUp(f) - if err != nil { - return nil, err - } - return f(NewSort(s.SortFields, child)) -} - -// TransformExpressionsUp implements the Transformable interface. -func (s *Sort) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - var sfs = make([]SortField, len(s.SortFields)) - for i, sf := range s.SortFields { - col, err := sf.Column.TransformUp(f) - if err != nil { - return nil, err - } - sfs[i] = SortField{col, sf.Order, sf.NullOrdering} - } - - child, err := s.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - - return NewSort(sfs, child), nil + return sql.NewSpanIter(span, newSortIter(ctx, s, i)), nil } -func (s Sort) String() string { +func (s *Sort) String() string { pr := sql.NewTreePrinter() var fields = make([]string, len(s.SortFields)) for i, f := range s.SortFields { @@ -133,7 +103,7 @@ func (s Sort) String() string { } // Expressions implements the Expressioner interface. 
-func (s Sort) Expressions() []sql.Expression { +func (s *Sort) Expressions() []sql.Expression { var exprs = make([]sql.Expression, len(s.SortFields)) for i, f := range s.SortFields { exprs[i] = f.Column @@ -141,19 +111,47 @@ func (s Sort) Expressions() []sql.Expression { return exprs } +// WithChildren implements the Node interface. +func (s *Sort) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1) + } + + return NewSort(s.SortFields, children[0]), nil +} + +// WithExpressions implements the Expressioner interface. +func (s *Sort) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != len(s.SortFields) { + return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.SortFields)) + } + + var fields = make([]SortField, len(s.SortFields)) + for i, expr := range exprs { + fields[i] = SortField{ + Column: expr, + NullOrdering: s.SortFields[i].NullOrdering, + Order: s.SortFields[i].Order, + } + } + + return NewSort(fields, s.Child), nil +} + type sortIter struct { + ctx *sql.Context s *Sort childIter sql.RowIter sortedRows []sql.Row idx int } -func newSortIter(s *Sort, child sql.RowIter) *sortIter { +func newSortIter(ctx *sql.Context, s *Sort, child sql.RowIter) *sortIter { return &sortIter{ - s: s, - childIter: child, - sortedRows: nil, - idx: -1, + ctx: ctx, + s: s, + childIter: child, + idx: -1, } } @@ -165,6 +163,7 @@ func (i *sortIter) Next() (sql.Row, error) { } i.idx = 0 } + if i.idx >= len(i.sortedRows) { return nil, io.EOF } @@ -179,24 +178,30 @@ func (i *sortIter) Close() error { } func (i *sortIter) computeSortedRows() error { - var rows []sql.Row + cache, dispose := i.ctx.Memory.NewRowsCache() + defer dispose() + for { - childRow, err := i.childIter.Next() + row, err := i.childIter.Next() if err == io.EOF { break } if err != nil { return err } - rows = append(rows, childRow) + + if err := cache.Add(row); err != nil { + return err + 
} } + rows := cache.Get() sorter := &sorter{ sortFields: i.s.SortFields, rows: rows, lastError: nil, } - sort.Sort(sorter) + sort.Stable(sorter) if sorter.lastError != nil { return sorter.lastError } diff --git a/sql/plan/sort_test.go b/sql/plan/sort_test.go index 1a05f8408..d5cb51ad8 100644 --- a/sql/plan/sort_test.go +++ b/sql/plan/sort_test.go @@ -3,9 +3,9 @@ package plan import ( "testing" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" ) @@ -27,16 +27,16 @@ func TestSort(t *testing.T) { {Name: "col2", Type: sql.Int32, Nullable: true}, } - child := mem.NewTable("test", schema) + child := memory.NewTable("test", schema) for _, row := range data { - require.NoError(child.Insert(row)) + require.NoError(child.Insert(sql.NewEmptyContext(), row)) } sf := []SortField{ {Column: expression.NewGetField(1, sql.Int32, "col2", true), Order: Ascending, NullOrdering: NullsFirst}, {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: Descending, NullOrdering: NullsLast}, } - s := NewSort(sf, child) + s := NewSort(sf, NewResolvedTable(child)) require.Equal(schema, s.Schema()) expected := []sql.Row{ @@ -68,15 +68,15 @@ func TestSortAscending(t *testing.T) { {Name: "col1", Type: sql.Text, Nullable: true}, } - child := mem.NewTable("test", schema) + child := memory.NewTable("test", schema) for _, row := range data { - require.NoError(child.Insert(row)) + require.NoError(child.Insert(sql.NewEmptyContext(), row)) } sf := []SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: Ascending, NullOrdering: NullsFirst}, } - s := NewSort(sf, child) + s := NewSort(sf, NewResolvedTable(child)) require.Equal(schema, s.Schema()) expected := []sql.Row{ @@ -108,15 +108,15 @@ func TestSortDescending(t 
*testing.T) { {Name: "col1", Type: sql.Text, Nullable: true}, } - child := mem.NewTable("test", schema) + child := memory.NewTable("test", schema) for _, row := range data { - require.NoError(child.Insert(row)) + require.NoError(child.Insert(sql.NewEmptyContext(), row)) } sf := []SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: Descending, NullOrdering: NullsFirst}, } - s := NewSort(sf, child) + s := NewSort(sf, NewResolvedTable(child)) require.Equal(schema, s.Schema()) expected := []sql.Row{ diff --git a/sql/plan/subqueryalias.go b/sql/plan/subqueryalias.go index 47c42c19c..da1264c88 100644 --- a/sql/plan/subqueryalias.go +++ b/sql/plan/subqueryalias.go @@ -1,7 +1,7 @@ package plan import ( - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // SubqueryAlias is a node that gives a subquery a name. @@ -45,16 +45,22 @@ func (n *SubqueryAlias) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.NewSpanIter(span, iter), nil } -// TransformUp implements the Node interface. -func (n *SubqueryAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(n) -} +// WithChildren implements the Node interface. +func (n *SubqueryAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 1) + } -// TransformExpressionsUp implements the Node interface. -func (n *SubqueryAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + nn := *n + nn.Child = children[0] - return n, nil + return &nn, nil } +// Opaque implements the OpaqueNode interface. 
+func (n *SubqueryAlias) Opaque() bool { + return true +} + func (n SubqueryAlias) String() string { pr := sql.NewTreePrinter() _ = pr.WriteNode("SubqueryAlias(%s)", n.name) diff --git a/sql/plan/subqueryalias_test.go b/sql/plan/subqueryalias_test.go index 80032f840..90c538e7f 100644 --- a/sql/plan/subqueryalias_test.go +++ b/sql/plan/subqueryalias_test.go @@ -3,10 +3,10 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" ) func TestSubqueryAliasSchema(t *testing.T) { @@ -22,14 +22,14 @@ func TestSubqueryAliasSchema(t *testing.T) { {Name: "baz", Type: sql.Text, Nullable: false, Source: "alias"}, } - table := mem.NewTable("bar", tableSchema) + table := memory.NewTable("bar", tableSchema) subquery := NewProject( []sql.Expression{ expression.NewGetField(0, sql.Text, "foo", false), expression.NewGetField(1, sql.Text, "baz", false), }, - table, + NewResolvedTable(table), ) require.Equal( diff --git a/sql/plan/tablealias.go b/sql/plan/tablealias.go index aa68d2d2a..be37fb109 100644 --- a/sql/plan/tablealias.go +++ b/sql/plan/tablealias.go @@ -4,7 +4,7 @@ import ( "reflect" opentracing "github.com/opentracing/opentracing-go" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // TableAlias is a node that acts as a table with a given name. @@ -23,22 +23,13 @@ func (t *TableAlias) Name() string { return t.name } -// TransformUp implements the Transformable interface. -func (t *TableAlias) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - child, err := t.Child.TransformUp(f) - if err != nil { - return nil, err +// WithChildren implements the Node interface. 
+func (t *TableAlias) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1) } - return f(NewTableAlias(t.name, child)) -} -// TransformExpressionsUp implements the Transformable interface. -func (t *TableAlias) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - child, err := t.Child.TransformExpressionsUp(f) - if err != nil { - return nil, err - } - return NewTableAlias(t.name, child), nil + return NewTableAlias(t.name, children[0]), nil } // RowIter implements the Node interface. diff --git a/sql/plan/tablealias_test.go b/sql/plan/tablealias_test.go index 33d565b3e..aee4d83c0 100644 --- a/sql/plan/tablealias_test.go +++ b/sql/plan/tablealias_test.go @@ -4,20 +4,20 @@ import ( "io" "testing" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestTableAlias(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() - table := mem.NewTable("bar", sql.Schema{ + table := memory.NewTable("bar", sql.Schema{ {Name: "a", Type: sql.Text, Nullable: true}, {Name: "b", Type: sql.Text, Nullable: true}, }) - alias := NewTableAlias("foo", table) + alias := NewTableAlias("foo", NewResolvedTable(table)) var rows = []sql.Row{ sql.NewRow("1", "2"), @@ -26,7 +26,7 @@ func TestTableAlias(t *testing.T) { } for _, r := range rows { - require.NoError(table.Insert(r)) + require.NoError(table.Insert(sql.NewEmptyContext(), r)) } require.Equal(table.Schema(), alias.Schema()) diff --git a/sql/plan/transaction.go b/sql/plan/transaction.go new file mode 100644 index 000000000..3d09366ef --- /dev/null +++ b/sql/plan/transaction.go @@ -0,0 +1,34 @@ +package plan + +import "github.com/src-d/go-mysql-server/sql" + +// Rollback undoes the changes performed in a transaction. 
+type Rollback struct{} + +// NewRollback creates a new Rollback node. +func NewRollback() *Rollback { return new(Rollback) } + +// RowIter implements the sql.Node interface. +func (*Rollback) RowIter(*sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +func (*Rollback) String() string { return "ROLLBACK" } + +// WithChildren implements the Node interface. +func (r *Rollback) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) + } + + return r, nil +} + +// Resolved implements the sql.Node interface. +func (*Rollback) Resolved() bool { return true } + +// Children implements the sql.Node interface. +func (*Rollback) Children() []sql.Node { return nil } + +// Schema implements the sql.Node interface. +func (*Rollback) Schema() sql.Schema { return nil } diff --git a/sql/plan/transform.go b/sql/plan/transform.go new file mode 100644 index 000000000..e437327ed --- /dev/null +++ b/sql/plan/transform.go @@ -0,0 +1,89 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +// TransformUp applies a transformation function to the given tree from the +// bottom up. +func TransformUp(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return f(node) + } + + children := node.Children() + if len(children) == 0 { + return f(node) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(node) +} + +// TransformExpressionsUp applies a transformation function to all expressions +// on the given tree from the bottom up. 
+func TransformExpressionsUp(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + if o, ok := node.(sql.OpaqueNode); ok && o.Opaque() { + return TransformExpressions(node, f) + } + + children := node.Children() + if len(children) == 0 { + return TransformExpressions(node, f) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformExpressionsUp(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return TransformExpressions(node, f) +} + +// TransformExpressions applies a transformation function to all expressions +// on the given node. +func TransformExpressions(node sql.Node, f sql.TransformExprFunc) (sql.Node, error) { + e, ok := node.(sql.Expressioner) + if !ok { + return node, nil + } + + exprs := e.Expressions() + if len(exprs) == 0 { + return node, nil + } + + newExprs := make([]sql.Expression, len(exprs)) + for i, e := range exprs { + e, err := expression.TransformUp(e, f) + if err != nil { + return nil, err + } + newExprs[i] = e + } + + return e.WithExpressions(newExprs...) 
+} diff --git a/sql/plan/transform_test.go b/sql/plan/transform_test.go index 521cb3f07..38d1bb0fb 100644 --- a/sql/plan/transform_test.go +++ b/sql/plan/transform_test.go @@ -3,9 +3,9 @@ package plan import ( "testing" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" "github.com/stretchr/testify/require" ) @@ -15,25 +15,30 @@ func TestTransformUp(t *testing.T) { aCol := expression.NewUnresolvedColumn("a") bCol := expression.NewUnresolvedColumn("a") - ur := &UnresolvedTable{"unresolved"} + ur := NewUnresolvedTable("unresolved", "") p := NewProject([]sql.Expression{aCol, bCol}, NewFilter(expression.NewEquals(aCol, bCol), ur)) schema := sql.Schema{ {Name: "a", Type: sql.Text}, {Name: "b", Type: sql.Text}, } - table := mem.NewTable("resolved", schema) + table := memory.NewTable("resolved", schema) - pt, err := p.TransformUp(func(n sql.Node) (sql.Node, error) { + pt, err := TransformUp(p, func(n sql.Node) (sql.Node, error) { switch n.(type) { case *UnresolvedTable: - return table, nil + return NewResolvedTable(table), nil default: return n, nil } }) require.NoError(err) - ep := NewProject([]sql.Expression{aCol, bCol}, NewFilter(expression.NewEquals(aCol, bCol), table)) + ep := NewProject( + []sql.Expression{aCol, bCol}, + NewFilter(expression.NewEquals(aCol, bCol), + NewResolvedTable(table), + ), + ) require.Equal(ep, pt) } diff --git a/sql/plan/unresolved.go b/sql/plan/unresolved.go index 62dced1ab..9bf70a639 100644 --- a/sql/plan/unresolved.go +++ b/sql/plan/unresolved.go @@ -3,8 +3,8 @@ package plan import ( "fmt" + "github.com/src-d/go-mysql-server/sql" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) // ErrUnresolvedTable is thrown when a table cannot be resolved @@ -12,13 +12,18 @@ var ErrUnresolvedTable = 
errors.NewKind("unresolved table") // UnresolvedTable is a table that has not been resolved yet but whose name is known. type UnresolvedTable struct { - // Name of the table. - Name string + name string + Database string } // NewUnresolvedTable creates a new Unresolved table. -func NewUnresolvedTable(name string) *UnresolvedTable { - return &UnresolvedTable{name} +func NewUnresolvedTable(name, db string) *UnresolvedTable { + return &UnresolvedTable{name, db} +} + +// Name implements the Nameable interface. +func (t *UnresolvedTable) Name() string { + return t.name } // Resolved implements the Resolvable interface. @@ -27,30 +32,25 @@ func (*UnresolvedTable) Resolved() bool { } // Children implements the Node interface. -func (*UnresolvedTable) Children() []sql.Node { - return []sql.Node{} -} +func (*UnresolvedTable) Children() []sql.Node { return nil } // Schema implements the Node interface. -func (*UnresolvedTable) Schema() sql.Schema { - return sql.Schema{} -} +func (*UnresolvedTable) Schema() sql.Schema { return nil } // RowIter implements the RowIter interface. func (*UnresolvedTable) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, ErrUnresolvedTable.New() } -// TransformUp implements the Transformable interface. -func (t *UnresolvedTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(NewUnresolvedTable(t.Name)) -} +// WithChildren implements the Node interface. +func (t *UnresolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } -// TransformExpressionsUp implements the Transformable interface. 
-func (t *UnresolvedTable) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { return t, nil } func (t UnresolvedTable) String() string { - return fmt.Sprintf("UnresolvedTable(%s)", t.Name) + return fmt.Sprintf("UnresolvedTable(%s)", t.name) } diff --git a/sql/plan/unresolved_test.go b/sql/plan/unresolved_test.go index b7426fc50..3865fe40e 100644 --- a/sql/plan/unresolved_test.go +++ b/sql/plan/unresolved_test.go @@ -3,12 +3,12 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestUnresolvedTable(t *testing.T) { require := require.New(t) - var n sql.Node = NewUnresolvedTable("test_table") + var n sql.Node = NewUnresolvedTable("test_table", "") require.NotNil(n) } diff --git a/sql/plan/update.go b/sql/plan/update.go new file mode 100644 index 000000000..a29fb7392 --- /dev/null +++ b/sql/plan/update.go @@ -0,0 +1,189 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + "gopkg.in/src-d/go-errors.v1" + "io" +) + +var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE") +var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but expression returned %T") + +// Update is a node for updating rows on tables. +type Update struct { + sql.Node + UpdateExprs []sql.Expression +} + +// NewUpdate creates an Update node. +func NewUpdate(n sql.Node, updateExprs []sql.Expression) *Update { + return &Update{n, updateExprs} +} + +// Expressions implements the Expressioner interface. +func (p *Update) Expressions() []sql.Expression { + return p.UpdateExprs +} + +// Schema implements the Node interface. +func (p *Update) Schema() sql.Schema { + return sql.Schema{ + { + Name: "matched", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + { + Name: "updated", + Type: sql.Int64, + Default: int64(0), + Nullable: false, + }, + } +} + +// Resolved implements the Resolvable interface. 
+func (p *Update) Resolved() bool { + if !p.Node.Resolved() { + return false + } + for _, updateExpr := range p.UpdateExprs { + if !updateExpr.Resolved() { + return false + } + } + return true +} + +func (p *Update) Children() []sql.Node { + return []sql.Node{p.Node} +} + +func getUpdatable(node sql.Node) (sql.Updater, error) { + switch node := node.(type) { + case sql.Updater: + return node, nil + case *ResolvedTable: + return getUpdatableTable(node.Table) + } + for _, child := range node.Children() { + updater, _ := getUpdatable(child) + if updater != nil { + return updater, nil + } + } + return nil, ErrUpdateNotSupported.New() +} + +func getUpdatableTable(t sql.Table) (sql.Updater, error) { + switch t := t.(type) { + case sql.Updater: + return t, nil + case sql.TableWrapper: + return getUpdatableTable(t.Underlying()) + default: + return nil, ErrUpdateNotSupported.New() + } +} + +// Execute inserts the rows in the database. +func (p *Update) Execute(ctx *sql.Context) (int, int, error) { + updatable, err := getUpdatable(p.Node) + if err != nil { + return 0, 0, err + } + schema := p.Node.Schema() + + iter, err := p.Node.RowIter(ctx) + if err != nil { + return 0, 0, err + } + + rowsMatched := 0 + rowsUpdated := 0 + for { + oldRow, err := iter.Next() + if err == io.EOF { + break + } + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsMatched++ + + newRow, err := p.applyUpdates(ctx, oldRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + if equals, err := oldRow.Equals(newRow, schema); err == nil { + if !equals { + err = updatable.Update(ctx, oldRow, newRow) + if err != nil { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + rowsUpdated++ + } + } else { + _ = iter.Close() + return rowsMatched, rowsUpdated, err + } + } + + return rowsMatched, rowsUpdated, nil +} + +// RowIter implements the Node interface. 
+func (p *Update) RowIter(ctx *sql.Context) (sql.RowIter, error) { + matched, updated, err := p.Execute(ctx) + if err != nil { + return nil, err + } + + return sql.RowsToRowIter(sql.NewRow(int64(matched), int64(updated))), nil +} + +// WithChildren implements the Node interface. +func (p *Update) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewUpdate(children[0], p.UpdateExprs), nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { + if len(newExprs) != len(p.UpdateExprs) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.UpdateExprs), 1) + } + return NewUpdate(p.Node, newExprs), nil +} + +func (p Update) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("Update") + _ = pr.WriteChildren(p.Node.String()) + for _, updateExpr := range p.UpdateExprs { + _ = pr.WriteChildren(updateExpr.String()) + } + return pr.String() +} + +func (p *Update) applyUpdates(ctx *sql.Context, row sql.Row) (sql.Row, error) { + var ok bool + prev := row + for _, updateExpr := range p.UpdateExprs { + val, err := updateExpr.Eval(ctx, prev) + if err != nil { + return nil, err + } + prev, ok = val.(sql.Row) + if !ok { + return nil, ErrUpdateUnexpectedSetResult.New(val) + } + } + return prev, nil +} diff --git a/sql/plan/use.go b/sql/plan/use.go new file mode 100644 index 000000000..b7eb517b7 --- /dev/null +++ b/sql/plan/use.go @@ -0,0 +1,65 @@ +package plan + +import ( + "fmt" + + "github.com/src-d/go-mysql-server/sql" +) + +// Use changes the current database. +type Use struct { + db sql.Database + Catalog *sql.Catalog +} + +// NewUse creates a new Use node. +func NewUse(db sql.Database) *Use { + return &Use{db: db} +} + +var _ sql.Node = (*Use)(nil) +var _ sql.Databaser = (*Use)(nil) + +// Database implements the sql.Databaser interface. 
+func (u *Use) Database() sql.Database { + return u.db +} + +// WithDatabase implements the sql.Databaser interface. +func (u *Use) WithDatabase(db sql.Database) (sql.Node, error) { + nc := *u + nc.db = db + return &nc, nil +} + +// Children implements the sql.Node interface. +func (Use) Children() []sql.Node { return nil } + +// Resolved implements the sql.Node interface. +func (u *Use) Resolved() bool { + _, ok := u.db.(sql.UnresolvedDatabase) + return !ok +} + +// Schema implements the sql.Node interface. +func (Use) Schema() sql.Schema { return nil } + +// RowIter implements the sql.Node interface. +func (u *Use) RowIter(ctx *sql.Context) (sql.RowIter, error) { + u.Catalog.SetCurrentDatabase(u.db.Name()) + return sql.RowsToRowIter(), nil +} + +// WithChildren implements the Node interface. +func (u *Use) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) + } + + return u, nil +} + +// String implements the sql.Node interface. +func (u *Use) String() string { + return fmt.Sprintf("USE(%s)", u.db.Name()) +} diff --git a/sql/plan/values.go b/sql/plan/values.go index 0537d9cfd..ea1d5be39 100644 --- a/sql/plan/values.go +++ b/sql/plan/values.go @@ -3,7 +3,7 @@ package plan import ( "fmt" - "gopkg.in/src-d/go-mysql-server.v0/sql" + "github.com/src-d/go-mysql-server/sql" ) // Values represents a set of tuples of expressions. @@ -76,34 +76,47 @@ func (p *Values) RowIter(ctx *sql.Context) (sql.RowIter, error) { return sql.RowsToRowIter(rows...), nil } -// TransformUp implements the Transformable interface. -func (p *Values) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(p) -} - -// TransformExpressionsUp implements the Transformable interface. 
-func (p *Values) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { - ets := make([][]sql.Expression, len(p.ExpressionTuples)) - var err error - for i, et := range p.ExpressionTuples { - ets[i], err = transformExpressionsUp(f, et) - if err != nil { - return nil, err - } - } - - return NewValues(ets), nil -} - -func (p Values) String() string { +func (p *Values) String() string { return fmt.Sprintf("Values(%d tuples)", len(p.ExpressionTuples)) } // Expressions implements the Expressioner interface. -func (p Values) Expressions() []sql.Expression { +func (p *Values) Expressions() []sql.Expression { var exprs []sql.Expression for _, tuple := range p.ExpressionTuples { exprs = append(exprs, tuple...) } return exprs } + +// WithChildren implements the Node interface. +func (p *Values) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) + } + + return p, nil +} + +// WithExpressions implements the Expressioner interface. +func (p *Values) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + var expected int + for _, t := range p.ExpressionTuples { + expected += len(t) + } + + if len(exprs) != expected { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), expected) + } + + var offset int + var tuples = make([][]sql.Expression, len(p.ExpressionTuples)) + for i, t := range p.ExpressionTuples { + for range t { + tuples[i] = append(tuples[i], exprs[offset]) + offset++ + } + } + + return NewValues(tuples), nil +} diff --git a/sql/plan/walk.go b/sql/plan/walk.go index 3a940a0b4..6e43524e0 100644 --- a/sql/plan/walk.go +++ b/sql/plan/walk.go @@ -1,8 +1,8 @@ package plan import ( - "gopkg.in/src-d/go-mysql-server.v0/sql" - "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" ) // Visitor visits nodes in the plan. 
@@ -52,9 +52,9 @@ func Inspect(node sql.Node, f func(sql.Node) bool) { // expression it finds. func WalkExpressions(v expression.Visitor, node sql.Node) { Inspect(node, func(node sql.Node) bool { - if node, ok := node.(sql.Expressioner); ok { - for _, e := range node.Expressions() { - expression.Walk(v, e) + if n, ok := node.(sql.Expressioner); ok { + for _, err := range n.Expressions() { + expression.Walk(v, err) } } return true diff --git a/sql/plan/walk_test.go b/sql/plan/walk_test.go index 303351703..a07230d96 100644 --- a/sql/plan/walk_test.go +++ b/sql/plan/walk_test.go @@ -3,13 +3,13 @@ package plan import ( "testing" + "github.com/src-d/go-mysql-server/sql" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestWalk(t *testing.T) { - t1 := NewUnresolvedTable("foo") - t2 := NewUnresolvedTable("bar") + t1 := NewUnresolvedTable("foo", "") + t2 := NewUnresolvedTable("bar", "") join := NewCrossJoin(t1, t2) filter := NewFilter(nil, join) project := NewProject(nil, filter) @@ -52,8 +52,8 @@ func (f visitor) Visit(n sql.Node) Visitor { } func TestInspect(t *testing.T) { - t1 := NewUnresolvedTable("foo") - t2 := NewUnresolvedTable("bar") + t1 := NewUnresolvedTable("foo", "") + t2 := NewUnresolvedTable("bar", "") join := NewCrossJoin(t1, t2) filter := NewFilter(nil, join) project := NewProject(nil, filter) diff --git a/sql/processlist.go b/sql/processlist.go new file mode 100644 index 000000000..5bc1fecf6 --- /dev/null +++ b/sql/processlist.go @@ -0,0 +1,329 @@ +package sql + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/src-d/go-errors.v1" +) + +// Progress between done items and total items. +type Progress struct { + Name string + Done int64 + Total int64 +} + +func (p Progress) totalString() string { + var total = "?" 
+ if p.Total > 0 { + total = fmt.Sprint(p.Total) + } + return total +} + +// TableProgress keeps track of a table progress, and for each of its partitions +type TableProgress struct { + Progress + PartitionsProgress map[string]PartitionProgress +} + +func NewTableProgress(name string, total int64) TableProgress { + return TableProgress{ + Progress: Progress{ + Name: name, + Total: total, + }, + PartitionsProgress: make(map[string]PartitionProgress), + } +} + +func (p TableProgress) String() string { + return fmt.Sprintf("%s (%d/%s partitions)", p.Name, p.Done, p.totalString()) +} + +// PartitionProgress keeps track of a partition progress +type PartitionProgress struct { + Progress +} + +func (p PartitionProgress) String() string { + return fmt.Sprintf("%s (%d/%s rows)", p.Name, p.Done, p.totalString()) +} + +// ProcessType is the type of process. +type ProcessType byte + +const ( + // QueryProcess is a query process. + QueryProcess ProcessType = iota + // CreateIndexProcess is a process to create an index. + CreateIndexProcess +) + +func (p ProcessType) String() string { + switch p { + case QueryProcess: + return "query" + case CreateIndexProcess: + return "create_index" + default: + return "invalid" + } +} + +// Process represents a process in the SQL server. +type Process struct { + Pid uint64 + Connection uint32 + User string + Type ProcessType + Query string + Progress map[string]TableProgress + StartedAt time.Time + Kill context.CancelFunc +} + +// Done needs to be called when this process has finished. +func (p *Process) Done() { p.Kill() } + +// Seconds returns the number of seconds this process has been running. +func (p *Process) Seconds() uint64 { + return uint64(time.Since(p.StartedAt) / time.Second) +} + +// ProcessList is a structure that keeps track of all the processes and their +// status. +type ProcessList struct { + mu sync.RWMutex + procs map[uint64]*Process +} + +// NewProcessList creates a new process list. 
+func NewProcessList() *ProcessList { + return &ProcessList{ + procs: make(map[uint64]*Process), + } +} + +// ErrPidAlreadyUsed is returned when the pid is already registered. +var ErrPidAlreadyUsed = errors.NewKind("pid %d is already in use") + +// AddProcess adds a new process to the list given a process type and a query. +// Steps is a map between the name of the items that need to be completed and +// the total amount in these items. -1 means unknown. +// It returns a new context that should be passed around from now on. That +// context will be cancelled if the process is killed. +func (pl *ProcessList) AddProcess( + ctx *Context, + typ ProcessType, + query string, +) (*Context, error) { + pl.mu.Lock() + defer pl.mu.Unlock() + + if _, ok := pl.procs[ctx.Pid()]; ok { + return nil, ErrPidAlreadyUsed.New(ctx.Pid()) + } + + newCtx, cancel := context.WithCancel(ctx) + ctx = ctx.WithContext(newCtx) + + pl.procs[ctx.Pid()] = &Process{ + Pid: ctx.Pid(), + Connection: ctx.ID(), + Type: typ, + Query: query, + Progress: make(map[string]TableProgress), + User: ctx.Session.Client().User, + StartedAt: time.Now(), + Kill: cancel, + } + + return ctx, nil +} + +// UpdateTableProgress updates the progress of the table with the given name for the +// process with the given pid. +func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + progress, ok := p.Progress[name] + if !ok { + progress = NewTableProgress(name, -1) + } + + progress.Done += delta + p.Progress[name] = progress +} + +// UpdatePartitionProgress updates the progress of the table partition with the +// given name for the process with the given pid. 
+func (pl *ProcessList) UpdatePartitionProgress(pid uint64, tableName, partitionName string, delta int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + partitionPg, ok := tablePg.PartitionsProgress[partitionName] + if !ok { + partitionPg = PartitionProgress{Progress: Progress{Name: partitionName, Total: -1}} + } + + partitionPg.Done += delta + tablePg.PartitionsProgress[partitionName] = partitionPg +} + +// AddTableProgress adds a new item to track progress from to the process with +// the given pid. If the pid does not exist, it will do nothing. +func (pl *ProcessList) AddTableProgress(pid uint64, name string, total int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + if pg, ok := p.Progress[name]; ok { + pg.Total = total + p.Progress[name] = pg + } else { + p.Progress[name] = NewTableProgress(name, total) + } +} + +// AddPartitionProgress adds a new item to track progress from to the process with +// the given pid. If the pid or the table does not exist, it will do nothing. +func (pl *ProcessList) AddPartitionProgress(pid uint64, tableName, partitionName string, total int64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + if pg, ok := tablePg.PartitionsProgress[partitionName]; ok { + pg.Total = total + tablePg.PartitionsProgress[partitionName] = pg + } else { + tablePg.PartitionsProgress[partitionName] = + PartitionProgress{Progress: Progress{Name: partitionName, Total: total}} + } +} + +// RemoveTableProgress removes an existing item tracking progress from the +// process with the given pid, if it exists. 
+func (pl *ProcessList) RemoveTableProgress(pid uint64, name string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + delete(p.Progress, name) +} + +// RemovePartitionProgress removes an existing item tracking progress from the +// process with the given pid, if it exists. +func (pl *ProcessList) RemovePartitionProgress(pid uint64, tableName, partitionName string) { + pl.mu.Lock() + defer pl.mu.Unlock() + + p, ok := pl.procs[pid] + if !ok { + return + } + + tablePg, ok := p.Progress[tableName] + if !ok { + return + } + + delete(tablePg.PartitionsProgress, partitionName) +} + +// Kill terminates all queries for a given connection id. +func (pl *ProcessList) Kill(connID uint32) { + pl.mu.Lock() + defer pl.mu.Unlock() + + for pid, proc := range pl.procs { + if proc.Connection == connID { + logrus.Infof("kill query: pid %d", pid) + proc.Done() + delete(pl.procs, pid) + } + } +} + +// KillOnlyQueries kills all queries, but not index creation queries, for a +// given connection id. +func (pl *ProcessList) KillOnlyQueries(connID uint32) { + pl.mu.Lock() + defer pl.mu.Unlock() + + for pid, proc := range pl.procs { + if proc.Connection == connID && proc.Type == QueryProcess { + logrus.Infof("kill query: pid %d", pid) + proc.Done() + delete(pl.procs, pid) + } + } +} + +// Done removes the finished process with the given pid from the process list. +// If the process does not exist, it will do nothing. +func (pl *ProcessList) Done(pid uint64) { + pl.mu.Lock() + defer pl.mu.Unlock() + + if proc, ok := pl.procs[pid]; ok { + proc.Done() + } + + delete(pl.procs, pid) +} + +// Processes returns the list of current running processes. 
+func (pl *ProcessList) Processes() []Process { + pl.mu.RLock() + defer pl.mu.RUnlock() + var result = make([]Process, 0, len(pl.procs)) + + for _, proc := range pl.procs { + p := *proc + var progress = make(map[string]TableProgress, len(p.Progress)) + for n, p := range p.Progress { + progress[n] = p + } + result = append(result, p) + } + + return result +} diff --git a/sql/processlist_test.go b/sql/processlist_test.go new file mode 100644 index 000000000..198f40b12 --- /dev/null +++ b/sql/processlist_test.go @@ -0,0 +1,140 @@ +package sql + +import ( + "context" + "sort" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestProcessList(t *testing.T) { + require := require.New(t) + + p := NewProcessList() + sess := NewSession("0.0.0.0:3306", "127.0.0.1:34567", "foo", 1) + ctx := NewContext(context.Background(), WithPid(1), WithSession(sess)) + ctx, err := p.AddProcess(ctx, QueryProcess, "SELECT foo") + require.NoError(err) + + require.Equal(uint64(1), ctx.Pid()) + require.Len(p.procs, 1) + + p.AddTableProgress(ctx.Pid(), "a", 5) + p.AddTableProgress(ctx.Pid(), "b", 6) + + expectedProcess := &Process{ + Pid: 1, + Connection: 1, + Type: QueryProcess, + Progress: map[string]TableProgress{ + "a": {Progress{Name: "a", Done: 0, Total: 5}, map[string]PartitionProgress{}}, + "b": {Progress{Name: "b", Done: 0, Total: 6}, map[string]PartitionProgress{}}, + }, + User: "foo", + Query: "SELECT foo", + StartedAt: p.procs[ctx.Pid()].StartedAt, + } + require.NotNil(p.procs[ctx.Pid()].Kill) + p.procs[ctx.Pid()].Kill = nil + require.Equal(expectedProcess, p.procs[ctx.Pid()]) + + p.AddPartitionProgress(ctx.Pid(), "b", "b-1", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-2", -1) + p.AddPartitionProgress(ctx.Pid(), "b", "b-3", -1) + + p.UpdatePartitionProgress(ctx.Pid(), "b", "b-2", 1) + + p.RemovePartitionProgress(ctx.Pid(), "b", "b-3") + + expectedProgress := map[string]TableProgress{ + "a": {Progress{Name: "a", Total: 5}, map[string]PartitionProgress{}}, + "b": 
{Progress{Name: "b", Total: 6}, map[string]PartitionProgress{ + "b-1": {Progress{Name: "b-1", Done: 0, Total: -1}}, + "b-2": {Progress{Name: "b-2", Done: 1, Total: -1}}, + }}, + } + require.Equal(expectedProgress, p.procs[ctx.Pid()].Progress) + + ctx = NewContext(context.Background(), WithPid(2), WithSession(sess)) + ctx, err = p.AddProcess(ctx, CreateIndexProcess, "SELECT bar") + require.NoError(err) + + p.AddTableProgress(ctx.Pid(), "foo", 2) + + require.Equal(uint64(2), ctx.Pid()) + require.Len(p.procs, 2) + + p.UpdateTableProgress(1, "a", 3) + p.UpdateTableProgress(1, "a", 1) + p.UpdateTableProgress(1, "b", 2) + p.UpdateTableProgress(2, "foo", 1) + + require.Equal(int64(4), p.procs[1].Progress["a"].Done) + require.Equal(int64(2), p.procs[1].Progress["b"].Done) + require.Equal(int64(1), p.procs[2].Progress["foo"].Done) + + var expected []Process + for _, p := range p.procs { + np := *p + np.Kill = nil + expected = append(expected, np) + } + + result := p.Processes() + for i := range result { + result[i].Kill = nil + } + + sortByPid(expected) + sortByPid(result) + require.Equal(expected, result) + + p.Done(2) + + require.Len(p.procs, 1) + _, ok := p.procs[1] + require.True(ok) +} + +func sortByPid(slice []Process) { + sort.Slice(slice, func(i, j int) bool { + return slice[i].Pid < slice[j].Pid + }) +} + +func TestKillConnection(t *testing.T) { + pl := NewProcessList() + + s1 := NewSession("", "", "", 1) + s2 := NewSession("", "", "", 2) + + var killed = make(map[uint64]bool) + for i := uint64(1); i <= 3; i++ { + // Odds get s1, evens get s2 + s := s1 + if i%2 == 0 { + s = s2 + } + + _, err := pl.AddProcess( + NewContext(context.Background(), WithPid(i), WithSession(s)), + QueryProcess, + "foo", + ) + require.NoError(t, err) + + i := i + pl.procs[i].Kill = func() { + killed[i] = true + } + } + + pl.Kill(1) + require.Len(t, pl.procs, 1) + + // Odds should have been killed + require.True(t, killed[1]) + require.False(t, killed[2]) + require.True(t, killed[3]) +} 
diff --git a/sql/session.go b/sql/session.go index 942ae985a..251467ba0 100644 --- a/sql/session.go +++ b/sql/session.go @@ -2,33 +2,220 @@ package sql import ( "context" + "fmt" "io" + "math" + "sync" "time" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/log" ) +type key uint + +const ( + // QueryKey to access query in the context. + QueryKey key = iota +) + +// Client holds session user information. +type Client struct { + // User of the session. + User string + // Address of the client. + Address string +} + // Session holds the session data. type Session interface { - // TODO: add config + // Address of the server. + Address() string + // User of the session. + Client() Client + // Set session configuration. + Set(key string, typ Type, value interface{}) + // Get session configuration. + Get(key string) (Type, interface{}) + // GetAll returns a copy of session configuration + GetAll() map[string]TypedValue + // ID returns the unique ID of the connection. + ID() uint32 + // Warn stores the warning in the session. + Warn(warn *Warning) + // Warnings returns a copy of session warnings (from the most recent). + Warnings() []*Warning + // ClearWarnings cleans up session warnings. + ClearWarnings() + // WarningCount returns a number of session warnings + WarningCount() uint16 } // BaseSession is the basic session type. type BaseSession struct { - // TODO: add config + id uint32 + addr string + client Client + mu sync.RWMutex + config map[string]TypedValue + warnings []*Warning + warncnt uint16 +} + +// Address returns the server address. +func (s *BaseSession) Address() string { return s.addr } + +// Client returns session's client information. +func (s *BaseSession) Client() Client { return s.client } + +// Set implements the Session interface. 
+func (s *BaseSession) Set(key string, typ Type, value interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + s.config[key] = TypedValue{typ, value} +} + +// Get implements the Session interface. +func (s *BaseSession) Get(key string) (Type, interface{}) { + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.config[key] + if !ok { + return Null, nil + } + + return v.Typ, v.Value +} + +// GetAll returns a copy of session configuration +func (s *BaseSession) GetAll() map[string]TypedValue { + m := make(map[string]TypedValue) + s.mu.RLock() + defer s.mu.RUnlock() + + for k, v := range s.config { + m[k] = v + } + return m +} + +// ID implements the Session interface. +func (s *BaseSession) ID() uint32 { return s.id } + +// Warn stores the warning in the session. +func (s *BaseSession) Warn(warn *Warning) { + s.mu.Lock() + defer s.mu.Unlock() + s.warnings = append(s.warnings, warn) +} + +// Warnings returns a copy of session warnings (from the most recent - the last one) +// The function implements sql.Session interface +func (s *BaseSession) Warnings() []*Warning { + s.mu.RLock() + defer s.mu.RUnlock() + + n := len(s.warnings) + warns := make([]*Warning, n) + for i := 0; i < n; i++ { + warns[i] = s.warnings[n-i-1] + } + + return warns +} + +// ClearWarnings cleans up session warnings +func (s *BaseSession) ClearWarnings() { + s.mu.Lock() + defer s.mu.Unlock() + + cnt := uint16(len(s.warnings)) + if s.warncnt == cnt { + if s.warnings != nil { + s.warnings = s.warnings[:0] + } + s.warncnt = 0 + } else { + s.warncnt = cnt + } +} + +// WarningCount returns a number of session warnings +func (s *BaseSession) WarningCount() uint16 { + s.mu.RLock() + defer s.mu.RUnlock() + return uint16(len(s.warnings)) +} + +type ( + // TypedValue is a value along with its type. + TypedValue struct { + Typ Type + Value interface{} + } + + // Warning stands for mySQL warning record. 
+ Warning struct { + Level string + Message string + Code int + } +) + +// DefaultSessionConfig returns default values for session variables +func DefaultSessionConfig() map[string]TypedValue { + return map[string]TypedValue{ + "auto_increment_increment": TypedValue{Int64, int64(1)}, + "time_zone": TypedValue{Text, time.Local.String()}, + "system_time_zone": TypedValue{Text, time.Local.String()}, + "max_allowed_packet": TypedValue{Int32, math.MaxInt32}, + "sql_mode": TypedValue{Text, ""}, + "gtid_mode": TypedValue{Int32, int32(0)}, + "collation_database": TypedValue{Text, "utf8_bin"}, + "ndbinfo_version": TypedValue{Text, ""}, + "sql_select_limit": TypedValue{Int32, math.MaxInt32}, + "transaction_isolation": TypedValue{Text, "READ UNCOMMITTED"}, + "version": TypedValue{Text, ""}, + "version_comment": TypedValue{Text, ""}, + } +} + +// HasDefaultValue checks if session variable value is the default one. +func HasDefaultValue(s Session, key string) (bool, interface{}) { + typ, val := s.Get(key) + if cfg, ok := DefaultSessionConfig()[key]; ok { + return (cfg.Typ == typ && cfg.Value == val), val + } + return false, val +} + +// NewSession creates a new session with data. +func NewSession(server, client, user string, id uint32) Session { + return &BaseSession{ + id: id, + addr: server, + client: Client{ + Address: client, + User: user, + }, + config: DefaultSessionConfig(), + } } -// NewBaseSession creates a new basic session. +// NewBaseSession creates a new empty session. func NewBaseSession() Session { - return &BaseSession{} + return &BaseSession{config: DefaultSessionConfig()} } // Context of the query execution. type Context struct { context.Context Session - tracer opentracing.Tracer + Memory *MemoryManager + pid uint64 + query string + tracer opentracing.Tracer + rootSpan opentracing.Span } // ContextOption is a function to configure the context. 
@@ -48,24 +235,63 @@ func WithTracer(t opentracing.Tracer) ContextOption { } } +// WithPid adds the given pid to the context. +func WithPid(pid uint64) ContextOption { + return func(ctx *Context) { + ctx.pid = pid + } +} + +// WithQuery adds the given query to the context. +func WithQuery(q string) ContextOption { + return func(ctx *Context) { + ctx.query = q + } +} + +// WithMemoryManager adds the given memory manager to the context. +func WithMemoryManager(m *MemoryManager) ContextOption { + return func(ctx *Context) { + ctx.Memory = m + } +} + +// WithRootSpan sets the root span of the context. +func WithRootSpan(s opentracing.Span) ContextOption { + return func(ctx *Context) { + ctx.rootSpan = s + } +} + // NewContext creates a new query context. Options can be passed to configure // the context. If some aspect of the context is not configure, the default // value will be used. -// By default, the context will have an empty base session and a noop tracer. +// By default, the context will have an empty base session, a noop tracer and +// a memory manager using the process reporter. func NewContext( ctx context.Context, opts ...ContextOption, ) *Context { - c := &Context{ctx, NewBaseSession(), opentracing.NoopTracer{}} + c := &Context{ctx, NewBaseSession(), nil, 0, "", opentracing.NoopTracer{}, nil} for _, opt := range opts { opt(c) } + + if c.Memory == nil { + c.Memory = NewMemoryManager(ProcessMemory) + } return c } // NewEmptyContext returns a default context with default values. func NewEmptyContext() *Context { return NewContext(context.TODO()) } +// Pid returns the process id associated with this context. +func (c *Context) Pid() uint64 { return c.pid } + +// Query returns the query string associated with this context. +func (c *Context) Query() string { return c.query } + // Span creates a new tracing span with the given context. // It will return the span and a new context that should be passed to all // childrens of this span. 
@@ -80,26 +306,71 @@ func (c *Context) Span( span := c.tracer.StartSpan(opName, opts...) ctx := opentracing.ContextWithSpan(c.Context, span) - return span, &Context{ctx, c.Session, c.tracer} + return span, &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} +} + +// WithContext returns a new context with the given underlying context. +func (c *Context) WithContext(ctx context.Context) *Context { + return &Context{ctx, c.Session, c.Memory, c.Pid(), c.Query(), c.tracer, c.rootSpan} +} + +// RootSpan returns the root span, if any. +func (c *Context) RootSpan() opentracing.Span { + return c.rootSpan +} + +// Error adds an error as warning to the session. +func (c *Context) Error(code int, msg string, args ...interface{}) { + c.Session.Warn(&Warning{ + Level: "Error", + Code: code, + Message: fmt.Sprintf(msg, args...), + }) +} + +// Warn adds a warning to the session. +func (c *Context) Warn(code int, msg string, args ...interface{}) { + c.Session.Warn(&Warning{ + Level: "Warning", + Code: code, + Message: fmt.Sprintf(msg, args...), + }) } // NewSpanIter creates a RowIter executed in the given span. 
func NewSpanIter(span opentracing.Span, iter RowIter) RowIter { - return &spanIter{span, iter, 0, false} + return &spanIter{ + span: span, + iter: iter, + } } type spanIter struct { span opentracing.Span iter RowIter count int + max time.Duration + min time.Duration + total time.Duration done bool } -func (i *spanIter) Next() (Row, error) { - if i.done { - return nil, io.EOF +func (i *spanIter) updateTimings(start time.Time) { + elapsed := time.Since(start) + if i.max < elapsed { + i.max = elapsed } + if i.min > elapsed || i.min == 0 { + i.min = elapsed + } + + i.total += elapsed +} + +func (i *spanIter) Next() (Row, error) { + start := time.Now() + row, err := i.iter.Next() if err == io.EOF { i.finish() @@ -112,15 +383,27 @@ func (i *spanIter) Next() (Row, error) { } i.count++ + i.updateTimings(start) return row, nil } func (i *spanIter) finish() { + var avg time.Duration + if i.count > 0 { + avg = i.total / time.Duration(i.count) + } + i.span.FinishWithOptions(opentracing.FinishOptions{ LogRecords: []opentracing.LogRecord{ { Timestamp: time.Now(), - Fields: []log.Field{log.Int("rows", i.count)}, + Fields: []log.Field{ + log.Int("rows", i.count), + log.String("total_time", i.total.String()), + log.String("max_time", i.max.String()), + log.String("min_time", i.min.String()), + log.String("avg_time", avg.String()), + }, }, }, }) diff --git a/sql/session_test.go b/sql/session_test.go index 062accb9c..315015ec5 100644 --- a/sql/session_test.go +++ b/sql/session_test.go @@ -8,29 +8,65 @@ import ( "github.com/stretchr/testify/require" ) -type testNode struct{} +func TestSessionConfig(t *testing.T) { + require := require.New(t) -func (t *testNode) Resolved() bool { - panic("not implemented") + sess := NewSession("foo", "baz", "bar", 1) + typ, v := sess.Get("foo") + require.Equal(Null, typ) + require.Equal(nil, v) + + sess.Set("foo", Int64, 1) + + typ, v = sess.Get("foo") + require.Equal(Int64, typ) + require.Equal(1, v) + + require.Equal(uint16(0), sess.WarningCount()) + 
+ sess.Warn(&Warning{Code: 1}) + sess.Warn(&Warning{Code: 2}) + sess.Warn(&Warning{Code: 3}) + + require.Equal(uint16(3), sess.WarningCount()) + + require.Equal(3, sess.Warnings()[0].Code) + require.Equal(2, sess.Warnings()[1].Code) + require.Equal(1, sess.Warnings()[2].Code) } -func (t *testNode) TransformUp(func(Node) Node) Node { - panic("not implemented") +func TestHasDefaultValue(t *testing.T) { + require := require.New(t) + sess := NewSession("foo", "baz", "bar", 1) + + for key := range DefaultSessionConfig() { + require.True(HasDefaultValue(sess, key)) + } + + sess.Set("auto_increment_increment", Int64, 123) + require.False(HasDefaultValue(sess, "auto_increment_increment")) + + require.False(HasDefaultValue(sess, "non_existing_key")) } -func (t *testNode) TransformExpressionsUp(func(Expression) Expression) Node { +type testNode struct{} + +func (*testNode) Resolved() bool { + panic("not implemented") +} +func (*testNode) WithChildren(...Node) (Node, error) { panic("not implemented") } -func (t *testNode) Schema() Schema { +func (*testNode) Schema() Schema { panic("not implemented") } -func (t *testNode) Children() []Node { +func (*testNode) Children() []Node { panic("not implemented") } -func (t *testNode) RowIter(ctx *Context) (RowIter, error) { +func (*testNode) RowIter(ctx *Context) (RowIter, error) { return newTestNodeIterator(ctx), nil } diff --git a/sql/type.go b/sql/type.go index 3168e336e..a93e7d83e 100644 --- a/sql/type.go +++ b/sql/type.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "math" "reflect" "strconv" "strings" @@ -11,8 +13,8 @@ import ( "github.com/spf13/cast" "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-vitess.v0/sqltypes" - "gopkg.in/src-d/go-vitess.v0/vt/proto/query" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) var ( @@ -25,10 +27,16 @@ var ( // ErrConvertingToTime is thrown when a value cannot be converted to a Time ErrConvertingToTime = errors.NewKind("value %q can't be converted 
to time.Time") + // ErrCharTruncation is thrown when a Char value is textually longer than the destination capacity + ErrCharTruncation = errors.NewKind("string value of %q is longer than destination capacity %d") + + // ErrVarCharTruncation is thrown when a VarChar value is textually longer than the destination capacity + ErrVarCharTruncation = errors.NewKind("string value of %q is longer than destination capacity %d") + // ErrValueNotNil is thrown when a value that was expected to be nil, is not ErrValueNotNil = errors.NewKind("value not nil: %#v") - // ErrNotTuple is retuned when the value is not a tuple. + // ErrNotTuple is returned when the value is not a tuple. ErrNotTuple = errors.NewKind("value of type %T is not a tuple") // ErrInvalidColumnNumber is returned when a tuple has an invalid number of @@ -37,6 +45,10 @@ var ( // ErrNotArray is returned when the value is not an array. ErrNotArray = errors.NewKind("value of type %T is not an array") + + // ErrConvertToSQL is returned when Convert failed. + // It makes an error less verbose comparingto what spf13/cast returns. + ErrConvertToSQL = errors.NewKind("incompatible conversion to SQL type: %s") ) // Schema is the definition of a table. @@ -64,13 +76,21 @@ func (s Schema) CheckRow(row Row) error { } // Contains returns whether the schema contains a column with the given name. -func (s Schema) Contains(column string) bool { - for _, col := range s { - if col.Name == column { - return true +func (s Schema) Contains(column string, source string) bool { + return s.IndexOf(column, source) >= 0 +} + +// IndexOf returns the index of the given column in the schema or -1 if it's +// not present. 
+func (s Schema) IndexOf(column, source string) int { + column = strings.ToLower(column) + source = strings.ToLower(source) + for i, col := range s { + if strings.ToLower(col.Name) == column && strings.ToLower(col.Source) == source { + return i } } - return false + return -1 } // Equals checks whether the given schema is equal to this one. @@ -104,6 +124,8 @@ type Column struct { Nullable bool // Source is the name of the table this column came from. Source string + // PrimaryKey is true if the column is part of the primary key for its table. + PrimaryKey bool } // Check ensures the value is correct for this column. @@ -135,7 +157,19 @@ type Type interface { // The result will be 0 if a==b, -1 if a < b, and +1 if a > b. Compare(interface{}, interface{}) (int, error) // SQL returns the sqltypes.Value for the given value. - SQL(interface{}) sqltypes.Value + SQL(interface{}) (sqltypes.Value, error) + fmt.Stringer +} + +var maxTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC) + +// ValidateTime receives a time and returns either that time or nil if it's +// not a valid time. +func ValidateTime(t time.Time) interface{} { + if t.After(maxTime) { + return nil + } + return t } var ( @@ -144,23 +178,37 @@ var ( // Numeric types + // Int8 is an integer of 8 bits + Int8 = numberT{t: sqltypes.Int8} + // Uint8 is an unsigned integer of 8 bits + Uint8 = numberT{t: sqltypes.Uint8} + // Int16 is an integer of 16 bits + Int16 = numberT{t: sqltypes.Int16} + // Uint16 is an unsigned integer of 16 bits + Uint16 = numberT{t: sqltypes.Uint16} + // Int24 is an integer of 24 bits. + Int24 = numberT{t: sqltypes.Int24} + // Uint24 is an unsigned integer of 24 bits. + Uint24 = numberT{t: sqltypes.Uint24} // Int32 is an integer of 32 bits. Int32 = numberT{t: sqltypes.Int32} + // Uint32 is an unsigned integer of 32 bits. + Uint32 = numberT{t: sqltypes.Uint32} // Int64 is an integer of 64 bytes. Int64 = numberT{t: sqltypes.Int64} - // Uint32 is an unsigned integer of 32 bytes. 
- Uint32 = numberT{t: sqltypes.Uint32} - // Uint64 is an unsigned integer of 64 bytes. + // Uint64 is an unsigned integer of 64 bits. Uint64 = numberT{t: sqltypes.Uint64} - // Float32 is a floating point number of 32 bytes. + // Float32 is a floating point number of 32 bits. Float32 = numberT{t: sqltypes.Float32} - // Float64 is a floating point number of 64 bytes. + // Float64 is a floating point number of 64 bits. Float64 = numberT{t: sqltypes.Float64} // Timestamp is an UNIX timestamp. Timestamp timestampT // Date is a date with day, month and year. Date dateT + // Datetime is a date and a time + Datetime datetimeT // Text is a string type. Text textT // Boolean is a boolean type. @@ -181,17 +229,39 @@ func Array(underlying Type) Type { return arrayT{underlying} } +// Char returns a new Char type of the given length. +func Char(length int) Type { + return charT{length: length} +} + +// VarChar returns a new VarChar type of the given length. +func VarChar(length int) Type { + return varCharT{length: length} +} + // MysqlTypeToType gets the column type using the mysql type func MysqlTypeToType(sql query.Type) (Type, error) { switch sql { case sqltypes.Null: return Null, nil + case sqltypes.Int8: + return Int8, nil + case sqltypes.Uint8: + return Uint8, nil + case sqltypes.Int16: + return Int16, nil + case sqltypes.Uint16: + return Uint16, nil + case sqltypes.Int24: + return Int24, nil + case sqltypes.Uint24: + return Uint24, nil case sqltypes.Int32: return Int32, nil - case sqltypes.Int64: - return Int64, nil case sqltypes.Uint32: return Uint32, nil + case sqltypes.Int64: + return Int64, nil case sqltypes.Uint64: return Uint64, nil case sqltypes.Float32: @@ -202,8 +272,18 @@ func MysqlTypeToType(sql query.Type) (Type, error) { return Timestamp, nil case sqltypes.Date: return Date, nil - case sqltypes.Text, sqltypes.VarChar: + case sqltypes.Text: + return Text, nil + case sqltypes.Char: + // Since we can't get the size of the sqltypes.Char to instantiate a + // 
specific Char(length) type we return a Text here return Text, nil + case sqltypes.VarChar: + // Since we can't get the size of the sqltypes.VarChar to instantiate a + // specific VarChar(length) type we return a Text here + return Text, nil + case sqltypes.Datetime: + return Datetime, nil case sqltypes.Bit: return Boolean, nil case sqltypes.TypeJSON: @@ -217,14 +297,16 @@ func MysqlTypeToType(sql query.Type) (Type, error) { type nullT struct{} +func (t nullT) String() string { return "NULL" } + // Type implements Type interface. func (t nullT) Type() query.Type { return sqltypes.Null } // SQL implements Type interface. -func (t nullT) SQL(interface{}) sqltypes.Value { - return sqltypes.NULL +func (t nullT) SQL(interface{}) (sqltypes.Value, error) { + return sqltypes.NULL, nil } // Convert implements Type interface. @@ -242,6 +324,11 @@ func (t nullT) Compare(a interface{}, b interface{}) (int, error) { return 0, nil } +// IsNull returns true if expression is nil or is Null Type, otherwise false. +func IsNull(ex Expression) bool { + return ex == nil || ex.Type() == Null +} + type numberT struct { t query.Type } @@ -252,17 +339,56 @@ func (t numberT) Type() query.Type { } // SQL implements Type interface. 
-func (t numberT) SQL(v interface{}) sqltypes.Value { - return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)) +func (t numberT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + switch t.t { + case sqltypes.Int8: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int16: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int32: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Int64: + return sqltypes.MakeTrusted(t.t, strconv.AppendInt(nil, cast.ToInt64(v), 10)), nil + case sqltypes.Uint8: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint16: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint32: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Uint64: + return sqltypes.MakeTrusted(t.t, strconv.AppendUint(nil, cast.ToUint64(v), 10)), nil + case sqltypes.Float32: + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil + case sqltypes.Float64: + return sqltypes.MakeTrusted(t.t, strconv.AppendFloat(nil, cast.ToFloat64(v), 'f', -1, 64)), nil + default: + return sqltypes.MakeTrusted(t.t, []byte{}), nil + } } // Convert implements Type interface. 
func (t numberT) Convert(v interface{}) (interface{}, error) { + if ti, ok := v.(time.Time); ok { + v = ti.Unix() + } + switch t.t { + case sqltypes.Int8: + return cast.ToInt8E(v) + case sqltypes.Int16: + return cast.ToInt16E(v) case sqltypes.Int32: return cast.ToInt32E(v) case sqltypes.Int64: return cast.ToInt64E(v) + case sqltypes.Uint8: + return cast.ToUint8E(v) + case sqltypes.Uint16: + return cast.ToUint16E(v) case sqltypes.Uint32: return cast.ToUint32E(v) case sqltypes.Uint64: @@ -274,19 +400,55 @@ func (t numberT) Convert(v interface{}) (interface{}, error) { default: return nil, ErrInvalidType.New(t.t) } - } // Compare implements Type interface. func (t numberT) Compare(a interface{}, b interface{}) (int, error) { if IsUnsigned(t) { - return compareUnsigned(a, b) + // only int types are unsigned + return compareUnsignedInts(a, b) + } + + switch t.t { + case sqltypes.Float64, sqltypes.Float32: + return compareFloats(a, b) + default: + return compareSignedInts(a, b) + } +} + +func (t numberT) String() string { return t.t.String() } + +func compareFloats(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + + ca, err := cast.ToFloat64E(a) + if err != nil { + return 0, err + } + cb, err := cast.ToFloat64E(b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + + if ca < cb { + return -1, nil } - return compareSigned(a, b) + return +1, nil } -func compareSigned(a interface{}, b interface{}) (int, error) { +func compareSignedInts(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + ca, err := cast.ToInt64E(a) if err != nil { return 0, err @@ -307,7 +469,11 @@ func compareSigned(a interface{}, b interface{}) (int, error) { return +1, nil } -func compareUnsigned(a interface{}, b interface{}) (int, error) { +func compareUnsignedInts(a interface{}, b interface{}) (int, error) { + if hasNulls, res := 
compareNulls(a, b); hasNulls { + return res, nil + } + ca, err := cast.ToUint64E(a) if err != nil { return 0, err @@ -330,6 +496,8 @@ func compareUnsigned(a interface{}, b interface{}) (int, error) { type timestampT struct{} +func (t timestampT) String() string { return "TIMESTAMP" } + // Type implements Type interface. func (t timestampT) Type() query.Type { return sqltypes.Timestamp @@ -339,13 +507,33 @@ func (t timestampT) Type() query.Type { // using the format of Go "time" package. const TimestampLayout = "2006-01-02 15:04:05" +// TimestampLayouts hold extra timestamps allowed for parsing. It does +// not have all the layouts supported by mysql. Missing are two digit year +// versions of common cases and dates that use non common separators. +// +// https://github.com/MariaDB/server/blob/mysql-5.5.36/sql-common/my_time.c#L124 +var TimestampLayouts = []string{ + "2006-01-02", + time.RFC3339, + "20060102150405", + "20060102", +} + // SQL implements Type interface. -func (t timestampT) SQL(v interface{}) sqltypes.Value { - time := MustConvert(t, v).(time.Time) +func (t timestampT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(TimestampLayout)), - ) + []byte(v.(time.Time).Format(TimestampLayout)), + ), nil } // Convert implements Type interface. 
@@ -356,7 +544,18 @@ func (t timestampT) Convert(v interface{}) (interface{}, error) { case string: t, err := time.Parse(TimestampLayout, value) if err != nil { - return nil, ErrConvertingToTime.Wrap(err, v) + failed := true + for _, fmt := range TimestampLayouts { + if t2, err2 := time.Parse(fmt, value); err2 == nil { + t = t2 + failed = false + break + } + } + + if failed { + return nil, ErrConvertingToTime.Wrap(err, v) + } } return t.UTC(), nil default: @@ -371,6 +570,10 @@ func (t timestampT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. func (t timestampT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + av := a.(time.Time) bv := b.(time.Time) if av.Before(bv) { @@ -391,16 +594,26 @@ func truncateDate(t time.Time) time.Time { return t.Truncate(24 * time.Hour) } +func (t dateT) String() string { return "DATE" } + func (t dateT) Type() query.Type { return sqltypes.Date } -func (t dateT) SQL(v interface{}) sqltypes.Value { - time := MustConvert(t, v).(time.Time) +func (t dateT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + return sqltypes.MakeTrusted( sqltypes.Timestamp, - []byte(time.Format(DateLayout)), - ) + []byte(v.(time.Time).Format(DateLayout)), + ), nil } func (t dateT) Convert(v interface{}) (interface{}, error) { @@ -424,6 +637,10 @@ func (t dateT) Convert(v interface{}) (interface{}, error) { } func (t dateT) Compare(a, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + av := truncateDate(a.(time.Time)) bv := truncateDate(b.(time.Time)) if av.Before(bv) { @@ -434,78 +651,288 @@ func (t dateT) Compare(a, b interface{}) (int, error) { return 0, nil } +type datetimeT struct{} + +// DatetimeLayout is the layout of the MySQL date format in the representation 
+// Go understands. +const DatetimeLayout = "2006-01-02 15:04:05" + +func (t datetimeT) String() string { return "DATETIME" } + +func (t datetimeT) Type() query.Type { + return sqltypes.Datetime +} + +func (t datetimeT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted( + sqltypes.Datetime, + []byte(v.(time.Time).Format(DatetimeLayout)), + ), nil +} + +func (t datetimeT) Convert(v interface{}) (interface{}, error) { + switch value := v.(type) { + case time.Time: + return value.UTC(), nil + case string: + t, err := time.Parse(DatetimeLayout, value) + if err != nil { + return nil, ErrConvertingToTime.Wrap(err, v) + } + return t.UTC(), nil + default: + ts, err := Int64.Convert(v) + if err != nil { + return nil, ErrInvalidType.New(reflect.TypeOf(v)) + } + + return time.Unix(ts.(int64), 0).UTC(), nil + } +} + +func (t datetimeT) Compare(a, b interface{}) (int, error) { + av := a.(time.Time) + bv := b.(time.Time) + if av.Before(bv) { + return -1, nil + } else if av.After(bv) { + return 1, nil + } + return 0, nil +} + +type charT struct { + length int +} + +func (t charT) Capacity() int { return t.length } + +func (t charT) String() string { return fmt.Sprintf("CHAR(%d)", t.length) } + +func (t charT) Type() query.Type { + return sqltypes.Char +} + +func (t charT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.MakeTrusted(sqltypes.Char, nil), nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.Char, []byte(v.(string))), nil +} + +// Converts any value that can be casted to a string +func (t charT) Convert(v interface{}) (interface{}, error) { + val, err := cast.ToStringE(v) + if err != nil { + return nil, ErrConvertToSQL.New(t) + } + + if len(val) > t.length { + return nil, ErrCharTruncation.New(val, t.length) + 
} + return val, nil +} + +// Compares two strings lexicographically +func (t charT) Compare(a interface{}, b interface{}) (int, error) { + return strings.Compare(a.(string), b.(string)), nil +} + +type varCharT struct { + length int +} + +func (t varCharT) Capacity() int { return t.length } + +func (t varCharT) String() string { return fmt.Sprintf("VARCHAR(%d)", t.length) } + +// Type implements Type interface +func (t varCharT) Type() query.Type { + return sqltypes.VarChar +} + +// SQL implements Type interface +func (t varCharT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.MakeTrusted(sqltypes.VarChar, nil), nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.VarChar, []byte(v.(string))), nil +} + +// Convert implements Type interface +func (t varCharT) Convert(v interface{}) (interface{}, error) { + val, err := cast.ToStringE(v) + if err != nil { + return nil, ErrConvertToSQL.New(t) + } + + if len(val) > t.length { + return nil, ErrVarCharTruncation.New(val, t.length) + } + return val, nil +} + +// Compare implements Type interface. +func (t varCharT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + return strings.Compare(a.(string), b.(string)), nil +} + type textT struct{} +func (t textT) String() string { return "TEXT" } + // Type implements Type interface. func (t textT) Type() query.Type { return sqltypes.Text } // SQL implements Type interface. 
-func (t textT) SQL(v interface{}) sqltypes.Value { - return sqltypes.MakeTrusted(sqltypes.Text, []byte(MustConvert(t, v).(string))) +func (t textT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.Text, []byte(v.(string))), nil } // Convert implements Type interface. func (t textT) Convert(v interface{}) (interface{}, error) { - return cast.ToStringE(v) + val, err := cast.ToStringE(v) + if err != nil { + return nil, ErrConvertToSQL.New(t) + } + return val, nil } // Compare implements Type interface. func (t textT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return strings.Compare(a.(string), b.(string)), nil } type booleanT struct{} +func (t booleanT) String() string { return "BOOLEAN" } + // Type implements Type interface. func (t booleanT) Type() query.Type { return sqltypes.Bit } // SQL implements Type interface. -func (t booleanT) SQL(v interface{}) sqltypes.Value { +func (t booleanT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + b := []byte{'0'} if cast.ToBool(v) { b[0] = '1' } - return sqltypes.MakeTrusted(sqltypes.Bit, b) + return sqltypes.MakeTrusted(sqltypes.Bit, b), nil } // Convert implements Type interface. 
func (t booleanT) Convert(v interface{}) (interface{}, error) { - return cast.ToBoolE(v) + switch b := v.(type) { + case bool: + return b, nil + case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: + return b != 0, nil + case time.Duration: + return int64(b) != 0, nil + case time.Time: + return b.UnixNano() != 0, nil + case float32, float64: + return int(math.Round(v.(float64))) != 0, nil + case string: + return false, nil + + case nil: + return nil, fmt.Errorf("unable to cast nil to bool") + + default: + return nil, fmt.Errorf("unable to cast %#v of type %T to bool", v, v) + } } // Compare implements Type interface. func (t booleanT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } + if a == b { return 0, nil } - if a.(bool) == false { + if a == false { return -1, nil } - return +1, nil + return 1, nil } type blobT struct{} +func (t blobT) String() string { return "BLOB" } + // Type implements Type interface. func (t blobT) Type() query.Type { return sqltypes.Blob } // SQL implements Type interface. -func (t blobT) SQL(v interface{}) sqltypes.Value { - return sqltypes.MakeTrusted(sqltypes.Blob, MustConvert(t, v).([]byte)) +func (t blobT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.Blob, v.([]byte)), nil } // Convert implements Type interface. func (t blobT) Convert(v interface{}) (interface{}, error) { switch value := v.(type) { + case nil: + return []byte(nil), nil case []byte: return value, nil case string: @@ -519,39 +946,73 @@ func (t blobT) Convert(v interface{}) (interface{}, error) { // Compare implements Type interface. 
func (t blobT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return bytes.Compare(a.([]byte), b.([]byte)), nil } type jsonT struct{} +func (t jsonT) String() string { return "JSON" } + // Type implements Type interface. func (t jsonT) Type() query.Type { return sqltypes.TypeJSON } // SQL implements Type interface. -func (t jsonT) SQL(v interface{}) sqltypes.Value { - return sqltypes.MakeTrusted(sqltypes.TypeJSON, MustConvert(t, v).([]byte)) +func (t jsonT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.TypeJSON, v.([]byte)), nil } // Convert implements Type interface. func (t jsonT) Convert(v interface{}) (interface{}, error) { - return json.Marshal(v) + switch v := v.(type) { + case string: + var doc interface{} + if err := json.Unmarshal([]byte(v), &doc); err != nil { + return json.Marshal(v) + } + return json.Marshal(doc) + default: + return json.Marshal(v) + } } // Compare implements Type interface. 
func (t jsonT) Compare(a interface{}, b interface{}) (int, error) { + if hasNulls, res := compareNulls(a, b); hasNulls { + return res, nil + } return bytes.Compare(a.([]byte), b.([]byte)), nil } type tupleT []Type +func (t tupleT) String() string { + var elems = make([]string, len(t)) + for i, el := range t { + elems[i] = el.String() + } + return fmt.Sprintf("TUPLE(%s)", strings.Join(elems, ", ")) +} + func (t tupleT) Type() query.Type { return sqltypes.Expression } -func (t tupleT) SQL(v interface{}) sqltypes.Value { - panic("unable to convert tuple type to SQL") +func (t tupleT) SQL(v interface{}) (sqltypes.Value, error) { + return sqltypes.Value{}, fmt.Errorf("unable to convert tuple type to SQL") } func (t tupleT) Convert(v interface{}) (interface{}, error) { @@ -605,28 +1066,69 @@ type arrayT struct { underlying Type } +func (t arrayT) String() string { return fmt.Sprintf("ARRAY(%s)", t.underlying) } + func (t arrayT) Type() query.Type { return sqltypes.TypeJSON } -func (t arrayT) SQL(v interface{}) sqltypes.Value { - return JSON.SQL(v) +func (t arrayT) SQL(v interface{}) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + + v, err := convertForJSON(t, v) + if err != nil { + return sqltypes.Value{}, err + } + + val, err := json.Marshal(v) + if err != nil { + return sqltypes.Value{}, err + } + + return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil } func (t arrayT) Convert(v interface{}) (interface{}, error) { - if vals, ok := v.([]interface{}); ok { - var result = make([]interface{}, len(vals)) - for i, v := range vals { + switch v := v.(type) { + case []interface{}: + var result = make([]interface{}, len(v)) + for i, v := range v { var err error result[i], err = t.underlying.Convert(v) if err != nil { return nil, err } } - return result, nil + case Generator: + var values []interface{} + for { + val, err := v.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + val, err = t.underlying.Convert(val) + 
if err != nil { + return nil, err + } + + values = append(values, val) + } + + if err := v.Close(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, ErrNotArray.New(v) } - return nil, ErrNotArray.New(v) } func (t arrayT) Compare(a, b interface{}) (int, error) { @@ -663,16 +1165,6 @@ func (t arrayT) Compare(a, b interface{}) (int, error) { return 0, nil } -// MustConvert calls the Convert function from a given Type, it err panics. -func MustConvert(t Type, v interface{}) interface{} { - c, err := t.Convert(v) - if err != nil { - panic(err) - } - - return c -} - // IsNumber checks if t is a number type func IsNumber(t Type) bool { return IsInteger(t) || IsDecimal(t) @@ -680,19 +1172,24 @@ func IsNumber(t Type) bool { // IsSigned checks if t is a signed type. func IsSigned(t Type) bool { - return t == Int32 || t == Int64 + return t == Int8 || t == Int16 || t == Int32 || t == Int64 } // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t Type) bool { - return t == Uint32 || t == Uint64 + return t == Uint8 || t == Uint16 || t == Uint32 || t == Uint64 } -// IsInteger check if t is a (U)Int32/64 type +// IsInteger checks if t is a (U)Int32/64 type. func IsInteger(t Type) bool { return IsSigned(t) || IsUnsigned(t) } +// IsTime checks if t is a timestamp, date or datetime +func IsTime(t Type) bool { + return t == Timestamp || t == Date || t == Datetime +} + // IsDecimal checks if t is decimal type. func IsDecimal(t Type) bool { return t == Float32 || t == Float64 @@ -700,7 +1197,19 @@ func IsDecimal(t Type) bool { // IsText checks if t is a text type. func IsText(t Type) bool { - return t == Text || t == Blob || t == JSON + return t == Text || t == Blob || t == JSON || IsVarChar(t) || IsChar(t) +} + +// IsChar checks if t is a Char type. +func IsChar(t Type) bool { + _, ok := t.(charT) + return ok +} + +// IsVarChar checks if t is a varchar type. 
+func IsVarChar(t Type) bool { + _, ok := t.(varCharT) + return ok } // IsTuple checks if t is a tuple type. @@ -726,3 +1235,139 @@ func NumColumns(t Type) int { } return len(v) } + +// MySQLTypeName returns the MySQL display name for the given type. +func MySQLTypeName(t Type) string { + switch t.Type() { + case sqltypes.Int8: + return "TINYINT" + case sqltypes.Uint8: + return "TINYINT UNSIGNED" + case sqltypes.Int16: + return "SMALLINT" + case sqltypes.Uint16: + return "SMALLINT UNSIGNED" + case sqltypes.Int32: + return "INTEGER" + case sqltypes.Int64: + return "BIGINT" + case sqltypes.Uint32: + return "INTEGER UNSIGNED" + case sqltypes.Uint64: + return "BIGINT UNSIGNED" + case sqltypes.Float32: + return "FLOAT" + case sqltypes.Float64: + return "DOUBLE" + case sqltypes.Timestamp: + return "TIMESTAMP" + case sqltypes.Datetime: + return "DATETIME" + case sqltypes.Date: + return "DATE" + case sqltypes.Char: + return fmt.Sprintf("CHAR(%v)", t.(charT).Capacity()) + case sqltypes.VarChar: + return fmt.Sprintf("VARCHAR(%v)", t.(varCharT).Capacity()) + case sqltypes.Text: + return "TEXT" + case sqltypes.Bit: + return "BIT" + case sqltypes.TypeJSON: + return "JSON" + case sqltypes.Blob: + return "BLOB" + default: + return "UNKNOWN" + } +} + +// UnderlyingType returns the underlying type of an array if the type is an +// array, or the type itself in any other case. 
+func UnderlyingType(t Type) Type { + a, ok := t.(arrayT) + if !ok { + return t + } + + return a.underlying +} + +func convertForJSON(t Type, v interface{}) (interface{}, error) { + switch t := t.(type) { + case jsonT: + val, err := t.Convert(v) + if err != nil { + return nil, err + } + + var doc interface{} + err = json.Unmarshal(val.([]byte), &doc) + if err != nil { + return nil, err + } + + return doc, nil + case arrayT: + return convertArrayForJSON(t, v) + default: + return t.Convert(v) + } +} + +func convertArrayForJSON(t arrayT, v interface{}) (interface{}, error) { + switch v := v.(type) { + case []interface{}: + var result = make([]interface{}, len(v)) + for i, v := range v { + var err error + result[i], err = convertForJSON(t.underlying, v) + if err != nil { + return nil, err + } + } + return result, nil + case Generator: + var values []interface{} + for { + val, err := v.Next() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + val, err = convertForJSON(t.underlying, val) + if err != nil { + return nil, err + } + + values = append(values, val) + } + + if err := v.Close(); err != nil { + return nil, err + } + + return values, nil + default: + return nil, ErrNotArray.New(v) + } +} + +// compareNulls compares two values, and returns true if either is null. +// The returned integer represents the ordering, with a rule that states nulls +// as being ordered before non-nulls. 
+func compareNulls(a interface{}, b interface{}) (bool, int) { + aIsNull := a == nil + bIsNull := b == nil + if aIsNull && bIsNull { + return true, 0 + } else if aIsNull && !bIsNull { + return true, -1 + } else if !aIsNull && bIsNull { + return true, 1 + } + return false, 0 +} diff --git a/sql/type_test.go b/sql/type_test.go index 797dda477..0dab60260 100644 --- a/sql/type_test.go +++ b/sql/type_test.go @@ -5,9 +5,18 @@ import ( "time" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-vitess.v0/sqltypes" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) +func TestIsNull(t *testing.T) { + require.True(t, IsNull(nil)) + + n := numberT{sqltypes.Uint64} + require.Equal(t, sqltypes.NULL, mustSQL(n.SQL(nil))) + require.Equal(t, sqltypes.NewUint64(0), mustSQL(n.SQL(uint64(0)))) +} + func TestText(t *testing.T) { convert(t, Text, "", "") convert(t, Text, 1, "1") @@ -15,18 +24,140 @@ func TestText(t *testing.T) { lt(t, Text, "a", "b") eq(t, Text, "a", "a") gt(t, Text, "b", "a") + + var3, err := VarChar(3).Convert("abc") + require.NoError(t, err) + convert(t, Text, var3, "abc") +} + +func TestBoolean(t *testing.T) { + convert(t, Boolean, "", false) + convert(t, Boolean, "true", false) + convert(t, Boolean, 0, false) + convert(t, Boolean, 1, true) + convert(t, Boolean, -1, true) + convert(t, Boolean, 0.0, false) + convert(t, Boolean, 0.4, false) + convert(t, Boolean, 0.5, true) + convert(t, Boolean, 1.0, true) + convert(t, Boolean, -1.0, true) + + eq(t, Boolean, true, true) + eq(t, Boolean, false, false) +} + +// Test conversion of all types of numbers to the specified signed integer type +// in typ, where minusOne, zero and one are the expected values with the +// same type as typ +func testSignedInt(t *testing.T, typ Type, minusOne, zero, one interface{}) { + t.Helper() + + convert(t, typ, -1, minusOne) + convert(t, typ, int8(-1), minusOne) + convert(t, typ, int16(-1), minusOne) + convert(t, typ, int32(-1), minusOne) + convert(t, typ, 
int64(-1), minusOne) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convert(t, typ, "-1", minusOne) + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, minusOne, one) + eq(t, Int8, zero, zero) + eq(t, Int8, minusOne, minusOne) + eq(t, Int8, one, one) + gt(t, Int8, one, minusOne) +} + +// Test conversion of all types of numbers to the specified unsigned integer +// type in typ, where zero and one are the expected values with the same type +// as typ. 
The expected errors when converting from negative numbers are also +// tested +func testUnsignedInt(t *testing.T, typ Type, zero, one interface{}) { + t.Helper() + + convertErr(t, typ, -1) + convertErr(t, typ, int8(-1)) + convertErr(t, typ, int16(-1)) + convertErr(t, typ, int32(-1)) + convertErr(t, typ, int64(-1)) + convert(t, typ, 0, zero) + convert(t, typ, int8(0), zero) + convert(t, typ, int16(0), zero) + convert(t, typ, int32(0), zero) + convert(t, typ, int64(0), zero) + convert(t, typ, uint8(0), zero) + convert(t, typ, uint16(0), zero) + convert(t, typ, uint32(0), zero) + convert(t, typ, uint64(0), zero) + convert(t, typ, 1, one) + convert(t, typ, int8(1), one) + convert(t, typ, int16(1), one) + convert(t, typ, int32(1), one) + convert(t, typ, int64(1), one) + convert(t, typ, uint8(1), one) + convert(t, typ, uint16(1), one) + convert(t, typ, uint32(1), one) + convert(t, typ, uint64(1), one) + convertErr(t, typ, "-1") + convert(t, typ, "0", zero) + convert(t, typ, "1", one) + convertErr(t, typ, "") + + lt(t, Int8, zero, one) + eq(t, Int8, zero, zero) + eq(t, Int8, one, one) + gt(t, Int8, one, zero) +} + +func TestInt8(t *testing.T) { + testSignedInt(t, Int8, int8(-1), int8(0), int8(1)) +} + +func TestInt16(t *testing.T) { + testSignedInt(t, Int16, int16(-1), int16(0), int16(1)) } func TestInt32(t *testing.T) { - convert(t, Int32, int32(1), int32(1)) - convert(t, Int32, 1, int32(1)) - convert(t, Int32, int64(1), int32(1)) - convert(t, Int32, "5", int32(5)) - convertErr(t, Int32, "") + testSignedInt(t, Int32, int32(-1), int32(0), int32(1)) +} + +func TestInt64(t *testing.T) { + testSignedInt(t, Int64, int64(-1), int64(0), int64(1)) +} + +func TestUint8(t *testing.T) { + testUnsignedInt(t, Uint8, uint8(0), uint8(1)) +} + +func TestUint16(t *testing.T) { + testUnsignedInt(t, Uint16, uint16(0), uint16(1)) +} - lt(t, Int32, int32(1), int32(2)) - eq(t, Int32, int32(1), int32(1)) - gt(t, Int32, int32(3), int32(2)) +func TestUint32(t *testing.T) { + testUnsignedInt(t, 
Uint32, uint32(0), uint32(1)) +} + +func TestUint64(t *testing.T) { + testUnsignedInt(t, Uint64, uint64(0), uint64(1)) } func TestNumberComparison(t *testing.T) { @@ -41,18 +172,88 @@ func TestNumberComparison(t *testing.T) { gt(t, Uint32, int64(5), uint32(1)) gt(t, Uint32, uint32(5), int64(1)) lt(t, Uint32, uint64(1), int32(5)) + + eq(t, Uint8, uint8(255), uint8(255)) + eq(t, Uint8, uint8(255), int32(255)) + eq(t, Uint8, uint8(255), int64(255)) + eq(t, Uint8, uint8(255), int64(255)) + gt(t, Uint8, uint8(255), int32(1)) + gt(t, Uint8, uint8(255), int64(1)) + lt(t, Uint8, uint8(255), int16(256)) + + // Exhaustive numeric type equality test + type typeAndValue struct { + t numberT + v interface{} + } + + allTypes := []typeAndValue{ + {Int8, int8(42)}, + {Uint8, uint8(42)}, + {Int16, int16(42)}, + {Uint16, uint16(42)}, + {Int24, int32(42)}, + {Uint24, uint32(42)}, + {Int32, int32(42)}, + {Uint32, uint32(42)}, + {Int64, int64(42)}, + {Uint64, uint64(42)}, + {Float32, float32(42)}, + {Float64, float64(42)}, + } + for _, a := range allTypes { + for _, b := range allTypes { + eq(t, a.t, a.v, b.v) + } + } + + // Float comparisons against other floats + greaterFloat := 7.5 + lesserFloat := 7.4 + gt(t, Float64, float64(greaterFloat), float64(lesserFloat)) + lt(t, Float64, float64(lesserFloat), float64(greaterFloat)) + eq(t, Float64, float64(greaterFloat), float64(greaterFloat)) + gt(t, Float64, float64(greaterFloat), float32(lesserFloat)) + lt(t, Float64, float64(lesserFloat), float32(greaterFloat)) + eq(t, Float64, float64(greaterFloat), float32(greaterFloat)) + gt(t, Float32, float32(greaterFloat), float32(lesserFloat)) + lt(t, Float32, float32(lesserFloat), float32(greaterFloat)) + eq(t, Float32, float32(greaterFloat), float32(greaterFloat)) + gt(t, Float32, float32(greaterFloat), float64(lesserFloat)) + lt(t, Float32, float32(lesserFloat), float64(greaterFloat)) + eq(t, Float32, float32(greaterFloat), float64(greaterFloat)) + + // Float comparisons against other types, 
testing comparison and truncation (when an int type is the left side of a + // comparison with a float type) + lessInt := 7 + floatComps := []typeAndValue{ + {Int8, int8(lessInt)}, + {Uint8, uint8(lessInt)}, + {Int16, int16(lessInt)}, + {Uint16, uint16(lessInt)}, + {Int32, int32(lessInt)}, + {Uint32, uint32(lessInt)}, + {Int64, int64(lessInt)}, + {Uint64, uint64(lessInt)}, + } + for _, a := range floatComps { + gt(t, Float64, float64(greaterFloat), a.v) + eq(t, a.t, float64(greaterFloat), a.v) + gt(t, Float32, float32(greaterFloat), a.v) + eq(t, a.t, float32(greaterFloat), a.v) + } } -func TestInt64(t *testing.T) { - convert(t, Int64, int32(1), int64(1)) - convert(t, Int64, 1, int64(1)) - convert(t, Int64, int64(1), int64(1)) - convertErr(t, Int64, "") - convert(t, Int64, "5", int64(5)) +func TestFloat64(t *testing.T) { + require := require.New(t) - lt(t, Int64, int64(1), int64(2)) - eq(t, Int64, int64(1), int64(1)) - gt(t, Int64, int64(3), int64(2)) + var f = numberT{ + t: query.Type_FLOAT64, + } + val, err := f.SQL(23.222) + require.NoError(err) + require.True(val.IsFloat()) + require.Equal(sqltypes.NewFloat64(23.222), val) } func TestTimestamp(t *testing.T) { @@ -60,24 +261,25 @@ func TestTimestamp(t *testing.T) { now := time.Now().UTC() v, err := Timestamp.Convert(now) - require.Nil(err) + require.NoError(err) require.Equal(now, v) v, err = Timestamp.Convert(now.Format(TimestampLayout)) - require.Nil(err) + require.NoError(err) require.Equal( now.Format(TimestampLayout), v.(time.Time).Format(TimestampLayout), ) v, err = Timestamp.Convert(now.Unix()) - require.Nil(err) + require.NoError(err) require.Equal( now.Format(TimestampLayout), v.(time.Time).Format(TimestampLayout), ) - sql := Timestamp.SQL(now) + sql, err := Timestamp.SQL(now) + require.NoError(err) require.Equal([]byte(now.Format(TimestampLayout)), sql.Raw()) after := now.Add(time.Second) @@ -86,46 +288,104 @@ func TestTimestamp(t *testing.T) { gt(t, Timestamp, after, now) } -func TestDate(t *testing.T) 
{ - require := require.New(t) +func TestExtraTimestamps(t *testing.T) { + tests := []struct { + date string + expected string + }{ + { + date: "2018-10-18T05:22:25Z", + expected: "2018-10-18 05:22:25", + }, + { + date: "2018-10-18T05:22:25+07:00", + expected: "2018-10-17 22:22:25", + }, + { + date: "20181018052225", + expected: "2018-10-18 05:22:25", + }, + { + date: "20181018", + expected: "2018-10-18 00:00:00", + }, + { + date: "2018-10-18", + expected: "2018-10-18 00:00:00", + }, + } + + for _, c := range tests { + t.Run(c.date, func(t *testing.T) { + require := require.New(t) + + p, err := Timestamp.Convert(c.date) + require.NoError(err) + + str := string([]byte(p.(time.Time).Format(TimestampLayout))) + require.Equal(c.expected, str) + }) + } +} - now := time.Now() - v, err := Date.Convert(now) - require.Nil(err) - require.Equal(now.Format(DateLayout), v.(time.Time).Format(DateLayout)) +// Generic tests for Date and Datetime. +// typ should be Date or Datetime +func commonTestsDatesTypes(typ Type, layout string, t *testing.T) { + require := require.New(t) + now := time.Now().UTC() + v, err := typ.Convert(now) + require.NoError(err) + require.Equal(now.Format(layout), v.(time.Time).Format(layout)) - v, err = Date.Convert(now.Format(DateLayout)) - require.Nil(err) + v, err = typ.Convert(now.Format(layout)) + require.NoError(err) require.Equal( - now.Format(DateLayout), - v.(time.Time).Format(DateLayout), + now.Format(layout), + v.(time.Time).Format(layout), ) - v, err = Date.Convert(now.Unix()) - require.Nil(err) + v, err = typ.Convert(now.Unix()) + require.NoError(err) require.Equal( - now.Format(DateLayout), - v.(time.Time).Format(DateLayout), + now.Format(layout), + v.(time.Time).Format(layout), ) - sql := Date.SQL(now) - require.Equal([]byte(now.Format(DateLayout)), sql.Raw()) + sql, err := typ.SQL(now) + require.NoError(err) + require.Equal([]byte(now.Format(layout)), sql.Raw()) + + after := now.Add(26 * time.Hour) + lt(t, typ, now, after) + eq(t, typ, now, 
now) + gt(t, typ, after, now) +} + +func TestDate(t *testing.T) { + commonTestsDatesTypes(Date, DateLayout, t) + now := time.Now().UTC() after := now.Add(time.Second) eq(t, Date, now, after) eq(t, Date, now, now) eq(t, Date, after, now) +} - after = now.Add(26 * time.Hour) - lt(t, Date, now, after) - eq(t, Date, now, now) - gt(t, Date, after, now) +func TestDatetime(t *testing.T) { + commonTestsDatesTypes(Datetime, DatetimeLayout, t) + + now := time.Now().UTC() + after := now.Add(time.Millisecond) + lt(t, Datetime, now, after) + eq(t, Datetime, now, now) + gt(t, Datetime, after, now) } func TestBlob(t *testing.T) { require := require.New(t) convert(t, Blob, "", []byte{}) + convert(t, Blob, nil, []byte(nil)) _, err := Blob.Convert(1) require.NotNil(err) @@ -139,6 +399,7 @@ func TestBlob(t *testing.T) { func TestJSON(t *testing.T) { convert(t, JSON, "", []byte(`""`)) convert(t, JSON, []int{1, 2}, []byte("[1,2]")) + convert(t, JSON, `{"a": true, "b": 3}`, []byte(`{"a":true,"b":3}`)) lt(t, JSON, []byte("A"), []byte("B")) eq(t, JSON, []byte("A"), []byte("A")) @@ -159,9 +420,8 @@ func TestTuple(t *testing.T) { convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int32(1), "2", int64(3)}) - require.Panics(func() { - typ.SQL(nil) - }) + _, err = typ.SQL(nil) + require.Error(err) require.Equal(sqltypes.Expression, typ.Type()) @@ -174,6 +434,48 @@ func TestTuple(t *testing.T) { gt(t, typ, []interface{}{1, 2, 4}, []interface{}{1, 2, 3}) } +// Generic test for Char and VarChar types. 
+// genType should be sql.Char or sql.VarChar +func testCharTypes(genType func(int) Type, checkType func(Type) bool, t *testing.T) { + typ := genType(3) + require.True(t, checkType(typ)) + require.True(t, IsText(typ)) + convert(t, typ, "foo", "foo") + fooByte := []byte{'f', 'o', 'o'} + convert(t, typ, fooByte, "foo") + + typ = genType(1) + convertErr(t, typ, "foo") + convertErr(t, typ, fooByte) + convertErr(t, typ, 123) + + typ = genType(10) + convert(t, typ, 123, "123") + convertErr(t, typ, 1234567890123) + + convert(t, typ, "", "") + convert(t, typ, 1, "1") + + lt(t, typ, "a", "b") + eq(t, typ, "a", "a") + gt(t, typ, "b", "a") + + text, err := Text.Convert("abc") + require.NoError(t, err) + + convert(t, typ, text, "abc") + typ1 := genType(1) + convertErr(t, typ1, text) +} + +func TestChar(t *testing.T) { + testCharTypes(Char, IsChar, t) +} + +func TestVarChar(t *testing.T) { + testCharTypes(VarChar, IsVarChar, t) +} + func TestArray(t *testing.T) { require := require.New(t) @@ -183,6 +485,12 @@ func TestArray(t *testing.T) { require.True(ErrNotArray.Is(err)) convert(t, typ, []interface{}{1, 2, 3}, []interface{}{int64(1), int64(2), int64(3)}) + convert( + t, + typ, + NewArrayGenerator([]interface{}{1, 2, 3}), + []interface{}{int64(1), int64(2), int64(3)}, + ) require.Equal(sqltypes.TypeJSON, typ.Type()) @@ -195,6 +503,77 @@ func TestArray(t *testing.T) { gt(t, typ, []interface{}{1, 3, 3}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{1, 2, 3}) gt(t, typ, []interface{}{1, 2, 4}, []interface{}{5, 6}) + + expected := []byte("[1,2,3]") + + v, err := Array(Int64).SQL([]interface{}{1, 2, 3}) + require.NoError(err) + require.Equal(expected, v.Raw()) + + v, err = Array(Int64).SQL(NewArrayGenerator([]interface{}{1, 2, 3})) + require.NoError(err) + require.Equal(expected, v.Raw()) +} + +func TestUnderlyingType(t *testing.T) { + require.Equal(t, Text, UnderlyingType(Array(Text))) + require.Equal(t, Text, UnderlyingType(Text)) +} + +type 
testJSONStruct struct { + A int + B string +} + +func TestJSONArraySQL(t *testing.T) { + require := require.New(t) + val, err := Array(JSON).SQL([]interface{}{ + testJSONStruct{1, "foo"}, + testJSONStruct{2, "bar"}, + }) + require.NoError(err) + expected := `[{"A":1,"B":"foo"},{"A":2,"B":"bar"}]` + require.Equal(expected, string(val.Raw())) +} + +func TestComparesWithNulls(t *testing.T) { + timeParse := func(layout string, value string) time.Time { + t, err := time.Parse(layout, value) + if err != nil { + panic(err) + } + return t + } + + var typeVals = []struct { + typ Type + val interface{} + }{ + {Int8, int8(0)}, + {Uint8, uint8(0)}, + {Int16, int16(0)}, + {Uint16, uint16(0)}, + {Int32, int32(0)}, + {Uint32, uint32(0)}, + {Int64, int64(0)}, + {Uint64, uint64(0)}, + {Float32, float32(0)}, + {Float64, float64(0)}, + {Timestamp, timeParse(TimestampLayout, "2132-04-05 12:51:36")}, + {Date, timeParse(DateLayout, "2231-11-07")}, + {Text, ""}, + {Boolean, false}, + {JSON, `{}`}, + {Blob, ""}, + } + + for _, typeVal := range typeVals { + t.Run(typeVal.typ.String(), func(t *testing.T) { + lt(t, typeVal.typ, nil, typeVal.val) + gt(t, typeVal.typ, typeVal.val, nil) + eq(t, typeVal.typ, nil, nil) + }) + } } func eq(t *testing.T, typ Type, a, b interface{}) { @@ -230,3 +609,10 @@ func convertErr(t *testing.T, typ Type, val interface{}) { _, err := typ.Convert(val) require.Error(t, err) } + +func mustSQL(v sqltypes.Value, err error) sqltypes.Value { + if err != nil { + panic(err) + } + return v +} diff --git a/sql/unresolved_database.go b/sql/unresolved_database.go index 62556bedd..86c59b617 100644 --- a/sql/unresolved_database.go +++ b/sql/unresolved_database.go @@ -1,14 +1,14 @@ package sql // UnresolvedDatabase is a database which has not been resolved yet. -type UnresolvedDatabase struct{} +type UnresolvedDatabase string -// Name returns the database name, which is always "unresolved_database". 
-func (d *UnresolvedDatabase) Name() string { - return "unresolved_database" +// Name returns the database name. +func (d UnresolvedDatabase) Name() string { + return string(d) } // Tables returns the tables in the database. -func (d *UnresolvedDatabase) Tables() map[string]Table { +func (UnresolvedDatabase) Tables() map[string]Table { return make(map[string]Table) } diff --git a/test/mem_tracer.go b/test/mem_tracer.go new file mode 100644 index 000000000..362e4321c --- /dev/null +++ b/test/mem_tracer.go @@ -0,0 +1,53 @@ +package test + +import ( + "sync" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" +) + +// MemTracer implements a simple tracer in memory for testing. +type MemTracer struct { + Spans []string + sync.Mutex +} + +type memSpan struct { + opName string +} + +// StartSpan implements opentracing.Tracer interface. +func (t *MemTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span { + t.Lock() + t.Spans = append(t.Spans, operationName) + t.Unlock() + return &memSpan{operationName} +} + +// Inject implements opentracing.Tracer interface. +func (t *MemTracer) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error { + panic("not implemented") +} + +// Extract implements opentracing.Tracer interface. 
+func (t *MemTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) { + panic("not implemented") +} + +func (m memSpan) Context() opentracing.SpanContext { return m } +func (m memSpan) SetBaggageItem(key, val string) opentracing.Span { return m } +func (m memSpan) BaggageItem(key string) string { return "" } +func (m memSpan) SetTag(key string, value interface{}) opentracing.Span { return m } +func (m memSpan) LogFields(fields ...log.Field) {} +func (m memSpan) LogKV(keyVals ...interface{}) {} +func (m memSpan) Finish() {} +func (m memSpan) FinishWithOptions(opts opentracing.FinishOptions) {} +func (m memSpan) SetOperationName(operationName string) opentracing.Span { + return &memSpan{operationName} +} +func (m memSpan) Tracer() opentracing.Tracer { return &MemTracer{} } +func (m memSpan) LogEvent(event string) {} +func (m memSpan) LogEventWithPayload(event string, payload interface{}) {} +func (m memSpan) Log(data opentracing.LogData) {} +func (m memSpan) ForeachBaggageItem(handler func(k, v string) bool) {}