2016-06-21 02:50:40 +02:00
package mock
import (
"fmt"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
"time"
"github.com/stretchr/objx"
"github.com/stretchr/testify/assert"
)
// TestingT is the subset of *testing.T used by this package to log messages
// and report failures. Any value with these three methods (for example a
// *testing.T or *testing.B) can be passed where a TestingT is expected.
type TestingT interface {
	Errorf(format string, args ...interface{})
	FailNow()
	Logf(format string, args ...interface{})
}
/*
Call
*/
// Call represents a method call and is used for setting expectations,
// as well as recording activity.
type Call struct {
	// Parent is the Mock this call belongs to. The builder methods on Call
	// (Return, Times, WaitUntil, Run) modify the fields below while holding
	// Parent's mutex.
	Parent *Mock

	// The name of the method that was or will be called.
	Method string

	// Holds the arguments of the method.
	Arguments Arguments

	// Holds the arguments that should be returned when
	// this method is called.
	ReturnArguments Arguments

	// The number of times to return the return arguments when setting
	// expectations. 0 means to always return the value.
	// When the last permitted call is consumed, Called sets it to -1; such
	// expectations are skipped when matching later calls.
	Repeatability int

	// Holds a channel that will be used to block the Return until it either
	// receives a message or is closed. nil means it returns immediately.
	WaitFor <-chan time.Time

	// Holds a handler used to manipulate arguments content that are passed by
	// reference. It's useful when mocking methods such as unmarshalers or
	// decoders.
	RunFn func(Arguments)
}
// newCall builds a Call for methodName that belongs to parent and records
// methodArguments as its arguments. The result has an empty (non-nil)
// ReturnArguments slice, a Repeatability of 0 (always return), no WaitFor
// channel and no RunFn.
func newCall(parent *Mock, methodName string, methodArguments ...interface{}) *Call {
	created := new(Call)
	created.Parent = parent
	created.Method = methodName
	created.Arguments = methodArguments
	created.ReturnArguments = make([]interface{}, 0)
	return created
}
// lock acquires the mutex of the Mock that owns this call.
func (c *Call) lock() {
	owner := c.Parent
	owner.mutex.Lock()
}

// unlock releases the mutex of the Mock that owns this call.
func (c *Call) unlock() {
	owner := c.Parent
	owner.mutex.Unlock()
}
// Return specifies the return arguments for the expectation and returns the
// same Call so further modifiers can be chained. The assignment happens while
// holding the owning Mock's mutex.
//
//	Mock.On("DoSomething").Return(errors.New("failed"))
func (c *Call) Return(returnArguments ...interface{}) *Call {
	c.lock()
	c.ReturnArguments = returnArguments
	c.unlock()
	return c
}
// Once indicates that the mock should only return the value once.
// It is shorthand for Times(1).
//
//	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
func (c *Call) Once() *Call {
	const once = 1
	return c.Times(once)
}

// Twice indicates that the mock should only return the value twice.
// It is shorthand for Times(2).
//
//	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
func (c *Call) Twice() *Call {
	const twice = 2
	return c.Times(twice)
}

// Times indicates that the mock should only return the indicated number
// of times. The count is stored in Repeatability while holding the owning
// Mock's mutex, and the same Call is returned for chaining.
//
//	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
func (c *Call) Times(i int) *Call {
	c.lock()
	c.Repeatability = i
	c.unlock()
	return c
}
// WaitUntil sets the channel that will block the mock's return until it is
// closed or a message is received. The same Call is returned for chaining.
//
//	Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
func (c *Call) WaitUntil(w <-chan time.Time) *Call {
	c.lock()
	c.WaitFor = w
	c.unlock()
	return c
}

// After sets how long to block until the call returns. The timer starts
// when After is called, not when the mocked method is invoked.
//
//	Mock.On("MyMethod", arg1, arg2).After(time.Second)
func (c *Call) After(d time.Duration) *Call {
	deadline := time.After(d)
	return c.WaitUntil(deadline)
}
// Run sets a handler to be called before returning. It can be used when
// mocking a method such as unmarshalers that takes a pointer to a struct and
// sets properties in such struct
//
2016-06-21 18:05:20 +02:00
// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(function(args Arguments) {
2016-06-21 02:50:40 +02:00
// arg := args.Get(0).(*map[string]interface{})
// arg["foo"] = "bar"
// })
func ( c * Call ) Run ( fn func ( Arguments ) ) * Call {
c . lock ( )
defer c . unlock ( )
c . RunFn = fn
return c
}
// On chains a new expectation description onto the mocked interface by
// delegating to the parent Mock's On. This allows syntax like:
//
//	Mock.
//		On("MyMethod", 1).Return(nil).
//		On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
func (c *Call) On(methodName string, arguments ...interface{}) *Call {
	parent := c.Parent
	return parent.On(methodName, arguments...)
}
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top
// of this document.
type Mock struct {
	// Represents the calls that are expected of
	// an object.
	ExpectedCalls []*Call

	// Holds the calls that were made to this mocked object.
	Calls []Call

	// testData holds any data that might be useful for testing. Testify ignores
	// this data completely allowing you to do whatever you like with it.
	// It is created lazily by the TestData method.
	testData objx.Map

	// mutex guards ExpectedCalls, Calls and the fields of the Calls that
	// ExpectedCalls points to (Call.lock/unlock use this mutex). Because Mock
	// contains a sync.Mutex, it must not be copied after first use.
	mutex sync.Mutex
}
// TestData holds any data that might be useful for testing. Testify ignores
// this data completely allowing you to do whatever you like with it.
// The map is allocated on first use and the same map is returned afterwards.
func (m *Mock) TestData() objx.Map {
	if m.testData != nil {
		return m.testData
	}
	m.testData = make(objx.Map)
	return m.testData
}
/*
Setting expectations
*/
// On starts a description of an expectation of the specified method
// being called, appends it to ExpectedCalls and returns it so Return, Times,
// etc. can be chained.
//
// Function values cannot be compared, so passing one as an expected argument
// panics; use mock.AnythingOfType with the function's type instead.
//
//	Mock.On("MyMethod", arg1, arg2)
func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
	for idx := range arguments {
		candidate := arguments[idx]
		if reflect.ValueOf(candidate).Kind() == reflect.Func {
			panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", candidate))
		}
	}

	expectation := newCall(m, methodName, arguments...)

	m.mutex.Lock()
	m.ExpectedCalls = append(m.ExpectedCalls, expectation)
	m.mutex.Unlock()

	return expectation
}
// /*
// Recording and responding to activity
// */
// findExpectedCall returns the index and pointer of the first expectation for
// method whose arguments match exactly (zero differences under
// Arguments.Diff) and which is not exhausted (Repeatability of -1).
// It returns -1, nil when no such expectation exists. The scan happens while
// holding the mock's mutex.
func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for idx := 0; idx < len(m.ExpectedCalls); idx++ {
		candidate := m.ExpectedCalls[idx]
		if candidate.Method != method || candidate.Repeatability <= -1 {
			continue
		}
		if _, mismatches := candidate.Arguments.Diff(arguments); mismatches == 0 {
			return idx, candidate
		}
	}
	return -1, nil
}
// findClosestCall returns the expectation for method whose arguments differ
// least (by Arguments.Diff count) from the given arguments. It is used to
// build a helpful panic message when no exact match exists. Exhausted
// expectations are considered too, so the result may have zero differences.
// On ties the earliest expectation wins. It returns false, nil when no
// expectation for method exists.
func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
	var (
		closestCall *Call
		fewestDiffs int
	)
	for _, call := range m.expectedCalls() {
		if call.Method != method {
			continue
		}
		_, diffCount := call.Arguments.Diff(arguments)
		// Use closestCall == nil as the "nothing chosen yet" sentinel. A diff
		// count of 0 is a legitimate best result, and must not be overwritten
		// by a later, worse candidate.
		if closestCall == nil || diffCount < fewestDiffs {
			closestCall = call
			fewestDiffs = diffCount
		}
	}
	return closestCall != nil, closestCall
}
// callString renders a method call as "Method(type1,type2)" using
// Arguments.String for the type list. When includeArgumentValues is true, it
// appends each argument on its own tab-indented line as "index: %#v".
func callString(method string, arguments Arguments, includeArgumentValues bool) string {
	suffix := ""
	if includeArgumentValues {
		lines := make([]string, len(arguments))
		for idx, value := range arguments {
			lines[idx] = fmt.Sprintf("%d: %#v", idx, value)
		}
		suffix = "\n\t\t" + strings.Join(lines, "\n\t\t")
	}
	return fmt.Sprintf("%s(%s)%s", method, arguments.String(), suffix)
}
// Called tells the mock object that a method has been called, and gets an array
2016-06-21 18:05:20 +02:00
// of arguments to return. Panics if the call is unexpected (i.e. not preceeded by
2016-06-21 02:50:40 +02:00
// appropriate .On .Return() calls)
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
func ( m * Mock ) Called ( arguments ... interface { } ) Arguments {
// get the calling function's name
pc , _ , _ , ok := runtime . Caller ( 1 )
if ! ok {
panic ( "Couldn't get the caller information" )
}
functionPath := runtime . FuncForPC ( pc ) . Name ( )
//Next four lines are required to use GCCGO function naming conventions.
//For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
//uses inteface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
//With GCCGO we need to remove interface information starting from pN<dd>.
re := regexp . MustCompile ( "\\.pN\\d+_" )
if re . MatchString ( functionPath ) {
functionPath = re . Split ( functionPath , - 1 ) [ 0 ]
}
parts := strings . Split ( functionPath , "." )
functionName := parts [ len ( parts ) - 1 ]
found , call := m . findExpectedCall ( functionName , arguments ... )
if found < 0 {
// we have to fail here - because we don't know what to do
// as the return arguments. This is because:
//
// a) this is a totally unexpected call to this method,
// b) the arguments are not what was expected, or
// c) the developer has forgotten to add an accompanying On...Return pair.
closestFound , closestCall := m . findClosestCall ( functionName , arguments ... )
if closestFound {
panic ( fmt . Sprintf ( "\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n" , callString ( functionName , arguments , true ) , callString ( functionName , closestCall . Arguments , true ) ) )
} else {
panic ( fmt . Sprintf ( "\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s" , functionName , functionName , callString ( functionName , arguments , true ) , assert . CallerInfo ( ) ) )
}
} else {
m . mutex . Lock ( )
switch {
case call . Repeatability == 1 :
call . Repeatability = - 1
case call . Repeatability > 1 :
call . Repeatability --
}
m . mutex . Unlock ( )
}
// add the call
m . mutex . Lock ( )
m . Calls = append ( m . Calls , * newCall ( m , functionName , arguments ... ) )
m . mutex . Unlock ( )
// block if specified
if call . WaitFor != nil {
<- call . WaitFor
}
if call . RunFn != nil {
call . RunFn ( arguments )
}
return call . ReturnArguments
}
/*
Assertions
*/
// AssertExpectationsForObjects asserts that everything specified with On and Return
// of the specified objects was in fact called as expected.
//
// Each object may be a Mock value, a *Mock, or any value with an
// AssertExpectations(TestingT) bool method (for example a pointer to a struct
// that embeds Mock). Every object is checked even after an earlier one fails,
// so all missing calls are reported. Any other value causes a panic.
//
// Calls may have occurred in any order.
func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
	success := true
	for _, obj := range testObjects {
		var passed bool
		switch m := obj.(type) {
		case Mock:
			// Kept for backward compatibility with callers passing Mock values.
			// m is a local (addressable) copy, so the pointer-receiver method
			// can be called on it.
			passed = m.AssertExpectations(t)
		case interface {
			AssertExpectations(TestingT) bool
		}:
			passed = m.AssertExpectations(t)
		default:
			panic(fmt.Sprintf("mock: AssertExpectationsForObjects: %T is not a Mock and has no AssertExpectations(TestingT) bool method", obj))
		}
		// Record the failure without short-circuiting: later objects must
		// still be checked so their missing calls are logged too.
		if !passed {
			success = false
		}
	}
	return success
}
// AssertExpectations asserts that everything specified with On and Return was
// in fact called as expected. Calls may have occurred in any order.
func ( m * Mock ) AssertExpectations ( t TestingT ) bool {
var somethingMissing bool
var failedExpectations int
// iterate through each expectation
expectedCalls := m . expectedCalls ( )
for _ , expectedCall := range expectedCalls {
2016-06-21 18:05:20 +02:00
if ! m . methodWasCalled ( expectedCall . Method , expectedCall . Arguments ) {
2016-06-21 02:50:40 +02:00
somethingMissing = true
failedExpectations ++
t . Logf ( "\u274C\t%s(%s)" , expectedCall . Method , expectedCall . Arguments . String ( ) )
} else {
m . mutex . Lock ( )
if expectedCall . Repeatability > 0 {
somethingMissing = true
failedExpectations ++
} else {
t . Logf ( "\u2705\t%s(%s)" , expectedCall . Method , expectedCall . Arguments . String ( ) )
}
m . mutex . Unlock ( )
}
}
if somethingMissing {
t . Errorf ( "FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s" , len ( expectedCalls ) - failedExpectations , len ( expectedCalls ) , failedExpectations , assert . CallerInfo ( ) )
}
return ! somethingMissing
}
// AssertNumberOfCalls asserts that the method was called expectedCalls times,
// counting every recorded call with that method name regardless of arguments.
func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
	recorded := m.calls()
	actualCalls := 0
	for i := range recorded {
		if recorded[i].Method != methodName {
			continue
		}
		actualCalls++
	}
	message := fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)
	return assert.Equal(t, expectedCalls, actualCalls, message)
}
// AssertCalled asserts that the method was called with arguments matching the
// given ones (compared with Arguments.Diff). On failure, the expectations are
// also logged.
func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
	wasCalled := m.methodWasCalled(methodName, arguments)
	message := fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))
	if assert.True(t, wasCalled, message) {
		return true
	}
	t.Logf("%v", m.expectedCalls())
	return false
}

// AssertNotCalled asserts that the method was not called with arguments
// matching the given ones (compared with Arguments.Diff). On failure, the
// expectations are also logged.
func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
	wasCalled := m.methodWasCalled(methodName, arguments)
	message := fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))
	if assert.False(t, wasCalled, message) {
		return true
	}
	t.Logf("%v", m.expectedCalls())
	return false
}
// methodWasCalled reports whether any recorded call to methodName has
// arguments with zero differences from expected under Arguments.Diff.
// expected is used as the "expected" side, so Anything, AnythingOfType and
// MatchedBy values in it act as wildcards.
func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
	want := Arguments(expected)
	for _, recorded := range m.calls() {
		if recorded.Method != methodName {
			continue
		}
		if _, mismatches := want.Diff(recorded.Arguments); mismatches == 0 {
			return true
		}
	}
	return false
}
// expectedCalls returns a copy of ExpectedCalls taken under the mock's mutex.
// The slice is fresh, but its elements still point to the live Call values.
func (m *Mock) expectedCalls() []*Call {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	snapshot := make([]*Call, len(m.ExpectedCalls))
	copy(snapshot, m.ExpectedCalls)
	return snapshot
}

// calls returns a copy of the recorded Calls taken under the mock's mutex.
func (m *Mock) calls() []Call {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	snapshot := make([]Call, len(m.Calls))
	copy(snapshot, m.Calls)
	return snapshot
}
/*
Arguments
*/
// Arguments holds an array of method arguments or return values.
// Typed accessors (Get, String, Int, Error, Bool) extract individual values.
type Arguments []interface{}

const (
	// Anything is used in Diff and Assert when the argument being tested
	// shouldn't be taken into consideration. Diff treats a position as a
	// match when either the expected or the actual value equals this string.
	Anything string = "mock.Anything"
)
// AnythingOfTypeArgument is a string that contains the type of an argument
// for use when type checking. Used in Diff and Assert.
type AnythingOfTypeArgument string

// AnythingOfType wraps a type name in an AnythingOfTypeArgument. Diff (and
// therefore Assert) accepts any argument whose reflect type name or string
// form equals that name, e.g. "int" or "*http.Request".
//
// For example:
//
//	Assert(t, AnythingOfType("string"), AnythingOfType("int"))
func AnythingOfType(t string) AnythingOfTypeArgument {
	wanted := AnythingOfTypeArgument(t)
	return wanted
}
// argumentMatcher performs custom argument matching, returning whether or
// not the argument is matched by the expectation fixture function.
type argumentMatcher struct {
	// fn is a function which accepts one argument, and returns a bool.
	fn reflect.Value
}

// Matches reports whether argument satisfies the matcher function.
//
// An argument whose dynamic type is not assignable to the function's
// parameter type does not match. A nil argument is passed to the function as
// the zero value of the parameter type when that type can hold nil (pointer,
// interface, map, slice, channel or func); otherwise it does not match.
func (f argumentMatcher) Matches(argument interface{}) bool {
	expectType := f.fn.Type().In(0)

	var arg reflect.Value
	switch argType := reflect.TypeOf(argument); {
	case argType == nil:
		// reflect.TypeOf(nil) is a nil Type; calling AssignableTo on it
		// would panic, so handle untyped nil explicitly.
		switch expectType.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
			arg = reflect.Zero(expectType)
		default:
			return false
		}
	case argType.AssignableTo(expectType):
		arg = reflect.ValueOf(argument)
	default:
		return false
	}

	result := f.fn.Call([]reflect.Value{arg})
	return result[0].Bool()
}

// String describes the matcher as "func(<param type>) bool". It uses the full
// type string so unnamed types such as *http.Request are shown (Type.Name
// would be empty for them).
func (f argumentMatcher) String() string {
	return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
}

// MatchedBy can be used to match a mock call based on only certain properties
// from a complex struct or some calculation. It takes a function that will be
// evaluated with the called argument and will return true when there's a match
// and false otherwise.
//
// Example:
//
//	m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
//
// |fn|, must be a function accepting a single argument (of the expected type)
// which returns a bool. If |fn| doesn't match the required signature,
// MatchedBy() panics.
func MatchedBy(fn interface{}) argumentMatcher {
	fnType := reflect.TypeOf(fn)

	// %T is used because fn is (supposed to be) a func, which %s cannot format.
	if fnType == nil || fnType.Kind() != reflect.Func {
		panic(fmt.Sprintf("assert: arguments: %T is not a func", fn))
	}
	if fnType.NumIn() != 1 {
		panic(fmt.Sprintf("assert: arguments: %T does not take exactly one argument", fn))
	}
	if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
		panic(fmt.Sprintf("assert: arguments: %T does not return a bool", fn))
	}

	return argumentMatcher{fn: reflect.ValueOf(fn)}
}
// Get returns the argument at the specified index. It panics with a
// descriptive message when index is past the end of the arguments.
func (args Arguments) Get(index int) interface{} {
	if len(args) < index+1 {
		panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
	}
	value := args[index]
	return value
}
// Is reports whether objects are identical, position by position, to the
// arguments (compared with ==, not deep equality). Argument lists of
// different lengths never match. Previously a shorter objects list caused an
// index-out-of-range panic, and a longer one was treated as a match.
// As with ==, comparing two values of the same uncomparable type (e.g.
// slices) panics.
func (args Arguments) Is(objects ...interface{}) bool {
	if len(args) != len(objects) {
		return false
	}
	for i, obj := range args {
		if obj != objects[i] {
			return false
		}
	}
	return true
}
// Diff gets a string describing the differences between the arguments
// and the specified objects.
//
// args holds the expected values and objects the actual ones. Each position
// is compared according to the kind of expected value:
//   - an argumentMatcher (from MatchedBy) calls its function with the actual value;
//   - an AnythingOfTypeArgument matches when the actual value's reflect type
//     Name or String equals it; a nil actual never matches;
//   - anything else matches when either side equals Anything or the two are
//     equal according to assert.ObjectsAreEqual.
//
// A position present on only one side is compared against the string "(Missing)".
//
// Returns the diff string and number of differences found. The string is
// "No differences." when the count is zero.
func (args Arguments) Diff(objects []interface{}) (string, int) {
	var output strings.Builder
	output.WriteString("\n")
	var differences int

	maxArgCount := len(args)
	if len(objects) > maxArgCount {
		maxArgCount = len(objects)
	}

	for i := 0; i < maxArgCount; i++ {
		var actual, expected interface{} = "(Missing)", "(Missing)"
		if i < len(objects) {
			actual = objects[i]
		}
		if i < len(args) {
			expected = args[i]
		}

		switch exp := expected.(type) {
		case argumentMatcher:
			if exp.Matches(actual) {
				fmt.Fprintf(&output, "\t%d: \u2705 %v matched by %s\n", i, actual, exp)
			} else {
				differences++
				// A failed match is a difference and is marked with a cross.
				fmt.Fprintf(&output, "\t%d: \u274C %v not matched by %s\n", i, actual, exp)
			}
		case AnythingOfTypeArgument:
			// type checking. reflect.TypeOf(nil) is nil, so guard it before
			// calling Name/String on it.
			actualType := reflect.TypeOf(actual)
			if actualType == nil || (actualType.Name() != string(exp) && actualType.String() != string(exp)) {
				differences++
				actualTypeName := "<nil>"
				if actualType != nil {
					// String (not Name) so unnamed types like *T are visible.
					actualTypeName = actualType.String()
				}
				fmt.Fprintf(&output, "\t%d: \u274C type %s != type %s - %v\n", i, exp, actualTypeName, actual)
			}
		default:
			// normal checking
			if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
				fmt.Fprintf(&output, "\t%d: \u2705 %v == %v\n", i, actual, expected)
			} else {
				differences++
				fmt.Fprintf(&output, "\t%d: \u274C %v != %v\n", i, actual, expected)
			}
		}
	}

	if differences == 0 {
		return "No differences.", differences
	}
	return output.String(), differences
}
// Assert compares the arguments with the specified objects and fails if
// they do not exactly match (zero differences under Diff). On mismatch, it
// logs the diff and reports an error through t. Returns true on a match.
func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
	// get the differences
	diff, diffCount := args.Diff(objects)
	if diffCount == 0 {
		return true
	}

	// there are differences... report them. The diff is passed as an argument
	// rather than as the format string, so '%' characters in argument values
	// are not interpreted as formatting verbs.
	t.Logf("%s", diff)
	t.Errorf("%sArguments do not match.", assert.CallerInfo())

	return false
}
// String gets the argument at the specified index. Panics if there is no argument, or
// if the argument is of the wrong type.
//
// If no index is provided, String() returns a complete string representation
// of the arguments: their types joined by commas (a nil argument renders as
// "<nil>"). Passing more than one index panics.
func (args Arguments) String(indexOrNil ...int) string {
	switch len(indexOrNil) {
	case 0:
		// normal String() method - return a string representation of the args.
		// %v (not %s) so a nil reflect.Type prints "<nil>" instead of "%!s(<nil>)".
		argsStr := make([]string, 0, len(args))
		for _, arg := range args {
			argsStr = append(argsStr, fmt.Sprintf("%v", reflect.TypeOf(arg)))
		}
		return strings.Join(argsStr, ",")
	case 1:
		// Index has been specified - get the argument at that index
		index := indexOrNil[0]
		s, ok := args.Get(index).(string)
		if !ok {
			// %v, consistent with Int/Error/Bool, so non-string values render readably.
			panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
		}
		return s
	}
	panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
}
// Int gets the argument at the specified index as an int. It panics if there
// is no argument at that index or if the argument is not an int.
func (args Arguments) Int(index int) int {
	if value, ok := args.Get(index).(int); ok {
		return value
	}
	panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
}

// Error gets the argument at the specified index as an error. A nil argument
// yields a nil error. It panics if there is no argument at that index or if a
// non-nil argument does not implement error.
func (args Arguments) Error(index int) error {
	raw := args.Get(index)
	if raw == nil {
		return nil
	}
	if value, ok := raw.(error); ok {
		return value
	}
	panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
}

// Bool gets the argument at the specified index as a bool. It panics if there
// is no argument at that index or if the argument is not a bool.
func (args Arguments) Bool(index int) bool {
	if value, ok := args.Get(index).(bool); ok {
		return value
	}
	panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
}